Repository: kvcache-ai/ktransformers
Branch: main
Commit: 8561a71dd11e
Files: 1146
Total size: 12.2 MB
Directory structure:
gitextract_0e22n38f/
├── .github/
│ ├── CODE_OF_CONDUCT.md
│ ├── CONTRIBUTING.md
│ ├── ISSUE_TEMPLATE/
│ │ ├── -bug-.yaml
│ │ ├── -feature-.yaml
│ │ └── config.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ ├── SECURITY.md
│ └── workflows/
│ ├── book-ci.yml
│ ├── deploy.yml
│ ├── docker-image.yml
│ ├── kt-kernel-tests.yml
│ ├── release-fake-tag.yml
│ ├── release-pypi.yml
│ ├── release-sglang-kt.yml
│ └── sync-sglang-submodule.yml
├── .gitignore
├── .gitmodules
├── LICENSE
├── MAINTAINERS.md
├── README.md
├── README_ZH.md
├── archive/
│ ├── .devcontainer/
│ │ ├── Dockerfile
│ │ └── devcontainer.json
│ ├── .flake8
│ ├── .gitmodules
│ ├── .pylintrc
│ ├── Dockerfile
│ ├── Dockerfile.xpu
│ ├── LICENSE
│ ├── MANIFEST.in
│ ├── Makefile
│ ├── README.md
│ ├── README_LEGACY.md
│ ├── README_ZH.md
│ ├── README_ZH_LEGACY.md
│ ├── SECURITY.md
│ ├── book.toml
│ ├── config.json
│ ├── csrc/
│ │ ├── balance_serve/
│ │ │ └── CMakeLists.txt
│ │ ├── custom_marlin/
│ │ │ ├── __init__.py
│ │ │ ├── binding.cpp
│ │ │ ├── gptq_marlin/
│ │ │ │ ├── gptq_marlin.cu
│ │ │ │ ├── gptq_marlin.cuh
│ │ │ │ ├── gptq_marlin_dtypes.cuh
│ │ │ │ ├── gptq_marlin_repack.cu
│ │ │ │ └── ops.h
│ │ │ ├── setup.py
│ │ │ ├── test_cuda_graph.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── format24.py
│ │ │ ├── marlin_24_perms.py
│ │ │ ├── marlin_perms.py
│ │ │ ├── marlin_utils.py
│ │ │ └── quant_utils.py
│ │ └── ktransformers_ext/
│ │ ├── CMakeLists.txt
│ │ ├── bench/
│ │ │ ├── bench_attention.py
│ │ │ ├── bench_attention_torch.py
│ │ │ ├── bench_linear.py
│ │ │ ├── bench_linear_torch.py
│ │ │ ├── bench_mlp.py
│ │ │ ├── bench_mlp_torch.py
│ │ │ ├── bench_moe.py
│ │ │ ├── bench_moe_amx.py
│ │ │ └── bench_moe_torch.py
│ │ ├── cmake/
│ │ │ └── FindSIMD.cmake
│ │ ├── cpu_backend/
│ │ │ ├── backend.cpp
│ │ │ ├── backend.h
│ │ │ ├── cpuinfer.h
│ │ │ ├── shared_mem_buffer.cpp
│ │ │ ├── shared_mem_buffer.h
│ │ │ ├── task_queue.cpp
│ │ │ ├── task_queue.h
│ │ │ └── vendors/
│ │ │ ├── README.md
│ │ │ ├── cuda.h
│ │ │ ├── hip.h
│ │ │ ├── musa.h
│ │ │ └── vendor.h
│ │ ├── cuda/
│ │ │ ├── binding.cpp
│ │ │ ├── custom_gguf/
│ │ │ │ ├── dequant.cu
│ │ │ │ └── ops.h
│ │ │ ├── gptq_marlin/
│ │ │ │ ├── gptq_marlin.cu
│ │ │ │ ├── gptq_marlin.cuh
│ │ │ │ ├── gptq_marlin_dtypes.cuh
│ │ │ │ └── ops.h
│ │ │ ├── setup.py
│ │ │ └── test_dequant.py
│ │ ├── examples/
│ │ │ ├── test_attention.py
│ │ │ ├── test_linear.py
│ │ │ ├── test_mlp.py
│ │ │ └── test_moe.py
│ │ ├── ext_bindings.cpp
│ │ ├── operators/
│ │ │ ├── amx/
│ │ │ │ ├── la/
│ │ │ │ │ ├── amx.hpp
│ │ │ │ │ └── utils.hpp
│ │ │ │ └── moe.hpp
│ │ │ ├── kvcache/
│ │ │ │ ├── kvcache.h
│ │ │ │ ├── kvcache_attn.cpp
│ │ │ │ ├── kvcache_load_dump.cpp
│ │ │ │ ├── kvcache_read_write.cpp
│ │ │ │ └── kvcache_utils.cpp
│ │ │ └── llamafile/
│ │ │ ├── conversion.h
│ │ │ ├── linear.cpp
│ │ │ ├── linear.h
│ │ │ ├── mlp.cpp
│ │ │ ├── mlp.h
│ │ │ ├── moe.cpp
│ │ │ └── moe.h
│ │ └── vendors/
│ │ ├── cuda.h
│ │ ├── hip.h
│ │ ├── musa.h
│ │ └── vendor.h
│ ├── install-with-cache.sh
│ ├── install.bat
│ ├── install.sh
│ ├── ktransformers/
│ │ ├── __init__.py
│ │ ├── configs/
│ │ │ ├── config.yaml
│ │ │ └── log_config.ini
│ │ ├── ktransformers_ext/
│ │ │ ├── operators/
│ │ │ │ └── custom_marlin/
│ │ │ │ └── quantize/
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── format_24.py
│ │ │ │ ├── marlin_24_perms.py
│ │ │ │ ├── marlin_perms.py
│ │ │ │ ├── marlin_utils.py
│ │ │ │ └── quant_utils.py
│ │ │ └── triton/
│ │ │ └── fp8gemm.py
│ │ ├── local_chat.py
│ │ ├── local_chat_test.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── ascend/
│ │ │ │ ├── custom_ascend_modeling_deepseek_v3.py
│ │ │ │ └── custom_ascend_modeling_qwen3.py
│ │ │ ├── configuration_deepseek.py
│ │ │ ├── configuration_deepseek_v3.py
│ │ │ ├── configuration_glm4_moe.py
│ │ │ ├── configuration_llama.py
│ │ │ ├── configuration_qwen2_moe.py
│ │ │ ├── configuration_qwen3_moe.py
│ │ │ ├── configuration_qwen3_next.py
│ │ │ ├── configuration_smallthinker.py
│ │ │ ├── custom_cache.py
│ │ │ ├── custom_modeling_deepseek_v2.py
│ │ │ ├── custom_modeling_deepseek_v3.py
│ │ │ ├── custom_modeling_glm4_moe.py
│ │ │ ├── custom_modeling_qwen2_moe.py
│ │ │ ├── custom_modeling_qwen3_moe.py
│ │ │ ├── custom_modeling_qwen3_next.py
│ │ │ ├── custom_modeling_smallthinker.py
│ │ │ ├── modeling_deepseek.py
│ │ │ ├── modeling_deepseek_v3.py
│ │ │ ├── modeling_glm4_moe.py
│ │ │ ├── modeling_llama.py
│ │ │ ├── modeling_mixtral.py
│ │ │ ├── modeling_qwen2_moe.py
│ │ │ ├── modeling_qwen3_moe.py
│ │ │ ├── modeling_qwen3_next.py
│ │ │ └── modeling_smallthinker.py
│ │ ├── operators/
│ │ │ ├── RoPE.py
│ │ │ ├── __init__.py
│ │ │ ├── ascend/
│ │ │ │ ├── ascend_attention.py
│ │ │ │ ├── ascend_experts.py
│ │ │ │ ├── ascend_gate.py
│ │ │ │ ├── ascend_layernorm.py
│ │ │ │ ├── ascend_linear.py
│ │ │ │ └── ascend_mlp.py
│ │ │ ├── attention.py
│ │ │ ├── balance_serve_attention.py
│ │ │ ├── base_operator.py
│ │ │ ├── cpuinfer.py
│ │ │ ├── dynamic_attention.py
│ │ │ ├── experts.py
│ │ │ ├── flashinfer_batch_prefill_wrapper.py
│ │ │ ├── flashinfer_wrapper.py
│ │ │ ├── gate.py
│ │ │ ├── layernorm.py
│ │ │ ├── linear.py
│ │ │ ├── mlp.py
│ │ │ ├── models.py
│ │ │ ├── triton_attention.py
│ │ │ └── triton_attention_prefill.py
│ │ ├── optimize/
│ │ │ ├── optimize.py
│ │ │ └── optimize_rules/
│ │ │ ├── DeepSeek-V2-Chat-multi-gpu-4.yaml
│ │ │ ├── DeepSeek-V2-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V2-Chat.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-gpu-cpu.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat.yaml
│ │ │ ├── DeepSeek-V3-Chat-amx.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-4.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-8.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-marlin.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V3-Chat-npu.yaml
│ │ │ ├── DeepSeek-V3-Chat-serve.yaml
│ │ │ ├── DeepSeek-V3-Chat.yaml
│ │ │ ├── Glm4Moe-serve.yaml
│ │ │ ├── Internlm2_5-7b-Chat-1m.yaml
│ │ │ ├── Mixtral.yaml
│ │ │ ├── Moonlight-16B-A3B-serve.yaml
│ │ │ ├── Moonlight-16B-A3B.yaml
│ │ │ ├── Qwen2-57B-A14B-Instruct-multi-gpu.yaml
│ │ │ ├── Qwen2-57B-A14B-Instruct.yaml
│ │ │ ├── Qwen2-serve-amx.yaml
│ │ │ ├── Qwen2-serve.yaml
│ │ │ ├── Qwen3Moe-serve-amx.yaml
│ │ │ ├── Qwen3Moe-serve.yaml
│ │ │ ├── Qwen3Next-serve.yaml
│ │ │ ├── Smallthinker-serve.yaml
│ │ │ ├── npu/
│ │ │ │ ├── DeepSeek-V3-Chat-300IA2-npu-serve.yaml
│ │ │ │ ├── DeepSeek-V3-Chat-300IA2-npu.yaml
│ │ │ │ └── Qwen3-Chat-300IA2-npu-serve.yaml
│ │ │ ├── rocm/
│ │ │ │ └── DeepSeek-V3-Chat.yaml
│ │ │ └── xpu/
│ │ │ ├── DeepSeek-V2-Chat.yaml
│ │ │ ├── DeepSeek-V3-Chat.yaml
│ │ │ └── Qwen3Moe-Chat.yaml
│ │ ├── server/
│ │ │ ├── __init__.py
│ │ │ ├── api/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── ollama/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── completions.py
│ │ │ │ ├── openai/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assistants/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── assistants.py
│ │ │ │ │ │ ├── messages.py
│ │ │ │ │ │ ├── runs.py
│ │ │ │ │ │ └── threads.py
│ │ │ │ │ ├── endpoints/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── chat.py
│ │ │ │ │ └── legacy/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── completions.py
│ │ │ │ └── web/
│ │ │ │ ├── __init__.py
│ │ │ │ └── system.py
│ │ │ ├── args.py
│ │ │ ├── backend/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── args.py
│ │ │ │ ├── base.py
│ │ │ │ ├── context_manager.py
│ │ │ │ └── interfaces/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── balance_serve.py
│ │ │ │ ├── exllamav2.py
│ │ │ │ ├── ktransformers.py
│ │ │ │ └── transformers.py
│ │ │ ├── balance_serve/
│ │ │ │ ├── inference/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── config.py
│ │ │ │ │ ├── distributed/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── communication_op.py
│ │ │ │ │ │ ├── cuda_wrapper.py
│ │ │ │ │ │ ├── custom_all_reduce.py
│ │ │ │ │ │ ├── custom_all_reduce_utils.py
│ │ │ │ │ │ ├── parallel_state.py
│ │ │ │ │ │ ├── pynccl.py
│ │ │ │ │ │ ├── pynccl_wrapper.py
│ │ │ │ │ │ └── utils.py
│ │ │ │ │ ├── forward_batch.py
│ │ │ │ │ ├── model_runner.py
│ │ │ │ │ ├── query_manager.py
│ │ │ │ │ └── sampling/
│ │ │ │ │ ├── penaltylib/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── orchestrator.py
│ │ │ │ │ │ └── penalizers/
│ │ │ │ │ │ ├── frequency_penalty.py
│ │ │ │ │ │ ├── min_new_tokens.py
│ │ │ │ │ │ ├── presence_penalty.py
│ │ │ │ │ │ └── repetition_penalty.py
│ │ │ │ │ └── sampler.py
│ │ │ │ ├── sched_rpc.py
│ │ │ │ └── settings.py
│ │ │ ├── config/
│ │ │ │ ├── config.py
│ │ │ │ ├── log.py
│ │ │ │ └── singleton.py
│ │ │ ├── crud/
│ │ │ │ ├── __init__.py
│ │ │ │ └── assistants/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants.py
│ │ │ │ ├── messages.py
│ │ │ │ ├── runs.py
│ │ │ │ └── threads.py
│ │ │ ├── exceptions.py
│ │ │ ├── main.py
│ │ │ ├── models/
│ │ │ │ ├── __init__.py
│ │ │ │ └── assistants/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants.py
│ │ │ │ ├── messages.py
│ │ │ │ ├── run_steps.py
│ │ │ │ ├── runs.py
│ │ │ │ └── threads.py
│ │ │ ├── requirements.txt
│ │ │ ├── schemas/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assistants.py
│ │ │ │ │ ├── messages.py
│ │ │ │ │ ├── runs.py
│ │ │ │ │ ├── streaming.py
│ │ │ │ │ ├── threads.py
│ │ │ │ │ └── tool.py
│ │ │ │ ├── base.py
│ │ │ │ ├── conversation.py
│ │ │ │ ├── endpoints/
│ │ │ │ │ └── chat.py
│ │ │ │ └── legacy/
│ │ │ │ ├── __init__.py
│ │ │ │ └── completions.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── create_interface.py
│ │ │ ├── multi_timer.py
│ │ │ ├── serve_profiling.py
│ │ │ └── sql_utils.py
│ │ ├── tests/
│ │ │ ├── .gitignore
│ │ │ ├── AIME_2024/
│ │ │ │ ├── eval_api.py
│ │ │ │ ├── evaluation.py
│ │ │ │ └── prompts.py
│ │ │ ├── UT/
│ │ │ │ ├── test_kdeepseek_attention_w8a8a2serve_npu.py
│ │ │ │ └── test_kdeepseek_ln_npu.py
│ │ │ ├── dequant_gpu.py
│ │ │ ├── dequant_gpu_t.py
│ │ │ ├── function_call_test.py
│ │ │ ├── humaneval/
│ │ │ │ ├── eval_api.py
│ │ │ │ ├── evaluation.py
│ │ │ │ └── prompts.py
│ │ │ ├── mmlu_pro_test.py
│ │ │ ├── mmlu_test.py
│ │ │ ├── mmlu_test_multi.py
│ │ │ ├── parse_cover_info.py
│ │ │ ├── score.py
│ │ │ ├── test_client.py
│ │ │ ├── test_prefix.py
│ │ │ ├── test_pytorch_q8.py
│ │ │ ├── test_speed.py
│ │ │ └── triton_fp8gemm_test.py
│ │ ├── util/
│ │ │ ├── ascend/
│ │ │ │ └── ascend_utils.py
│ │ │ ├── cuda_graph_runner.py
│ │ │ ├── custom_gguf.py
│ │ │ ├── custom_loader.py
│ │ │ ├── modeling_rope_utils.py
│ │ │ ├── npu_graph_runner.py
│ │ │ ├── textstream.py
│ │ │ ├── utils.py
│ │ │ ├── vendors.py
│ │ │ └── weight_loader.py
│ │ └── website/
│ │ ├── .browserslistrc
│ │ ├── .eslintrc.js
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── config.d.ts
│ │ ├── jest.config.js
│ │ ├── package.json
│ │ ├── public/
│ │ │ ├── config.js
│ │ │ ├── css/
│ │ │ │ └── reset.css
│ │ │ └── index.html
│ │ ├── src/
│ │ │ ├── App.vue
│ │ │ ├── api/
│ │ │ │ ├── api-client.ts
│ │ │ │ ├── assistant.ts
│ │ │ │ ├── message.ts
│ │ │ │ ├── run.ts
│ │ │ │ └── thread.ts
│ │ │ ├── assets/
│ │ │ │ ├── css/
│ │ │ │ │ └── mixins.styl
│ │ │ │ └── iconfont/
│ │ │ │ ├── demo.css
│ │ │ │ ├── demo_index.html
│ │ │ │ ├── iconfont.css
│ │ │ │ ├── iconfont.js
│ │ │ │ └── iconfont.json
│ │ │ ├── components/
│ │ │ │ └── chat/
│ │ │ │ └── index.vue
│ │ │ ├── conf/
│ │ │ │ └── config.ts
│ │ │ ├── locals/
│ │ │ │ ├── en.js
│ │ │ │ ├── index.js
│ │ │ │ └── zh.js
│ │ │ ├── main.ts
│ │ │ ├── router/
│ │ │ │ └── index.ts
│ │ │ ├── shims-vue.d.ts
│ │ │ ├── store/
│ │ │ │ └── index.ts
│ │ │ ├── utils/
│ │ │ │ ├── copy.ts
│ │ │ │ └── types.ts
│ │ │ └── views/
│ │ │ └── home.vue
│ │ ├── tests/
│ │ │ └── unit/
│ │ │ └── example.spec.ts
│ │ ├── tsconfig.json
│ │ └── vue.config.js
│ ├── merge_tensors/
│ │ ├── merge_safetensor_gguf.py
│ │ └── merge_safetensor_gguf_for_qwen3.py
│ ├── pyproject.toml
│ ├── requirements-local_chat.txt
│ ├── setup.py
│ └── third_party/
│ ├── llamafile/
│ │ ├── README.md
│ │ ├── bench.h
│ │ ├── flags.cpp
│ │ ├── flags.h
│ │ ├── iqk_mul_mat.inc
│ │ ├── iqk_mul_mat_amd_avx2.cpp
│ │ ├── iqk_mul_mat_amd_zen4.cpp
│ │ ├── iqk_mul_mat_arm.inc
│ │ ├── iqk_mul_mat_arm82.cpp
│ │ ├── iqk_mul_mat_x86.inc
│ │ ├── macros.h
│ │ ├── micros.h
│ │ ├── numba.h
│ │ ├── sgemm.cpp
│ │ ├── sgemm.h
│ │ ├── sgemm_arm.cpp
│ │ ├── sgemm_x86.cpp
│ │ ├── tinyblas_cpu.h
│ │ ├── tinyblas_cpu_mixmul.inc
│ │ ├── tinyblas_cpu_mixmul_amd_avx.cpp
│ │ ├── tinyblas_cpu_mixmul_amd_avx2.cpp
│ │ ├── tinyblas_cpu_mixmul_amd_avx512f.cpp
│ │ ├── tinyblas_cpu_mixmul_amd_avxvnni.cpp
│ │ ├── tinyblas_cpu_mixmul_amd_fma.cpp
│ │ ├── tinyblas_cpu_mixmul_amd_zen4.cpp
│ │ ├── tinyblas_cpu_mixmul_arm80.cpp
│ │ ├── tinyblas_cpu_mixmul_arm82.cpp
│ │ ├── tinyblas_cpu_sgemm.inc
│ │ ├── tinyblas_cpu_sgemm_amd_avx.cpp
│ │ ├── tinyblas_cpu_sgemm_amd_avx2.cpp
│ │ ├── tinyblas_cpu_sgemm_amd_avx512f.cpp
│ │ ├── tinyblas_cpu_sgemm_amd_avxvnni.cpp
│ │ ├── tinyblas_cpu_sgemm_amd_fma.cpp
│ │ ├── tinyblas_cpu_sgemm_amd_zen4.cpp
│ │ ├── tinyblas_cpu_sgemm_arm.inc
│ │ ├── tinyblas_cpu_sgemm_arm80.cpp
│ │ ├── tinyblas_cpu_sgemm_arm82.cpp
│ │ ├── tinyblas_cpu_sgemm_x86.inc
│ │ └── tinyblas_cpu_unsupported.cpp
│ └── nlohmann/
│ ├── json.hpp
│ └── json_fwd.hpp
├── book.toml
├── doc/
│ ├── SUMMARY.md
│ ├── basic/
│ │ ├── note1.md
│ │ └── note2.md
│ ├── en/
│ │ ├── AMX.md
│ │ ├── DeepseekR1_V3_tutorial.md
│ │ ├── Docker.md
│ │ ├── Docker_xpu.md
│ │ ├── FAQ.md
│ │ ├── Kimi-K2-Thinking.md
│ │ ├── Kimi-K2.5.md
│ │ ├── Kimi-K2.md
│ │ ├── Kllama_tutorial_DeepSeekV2Lite.ipynb
│ │ ├── MiniMax-M2.5.md
│ │ ├── Qwen3-Next.md
│ │ ├── Qwen3.5.md
│ │ ├── ROCm.md
│ │ ├── SFT/
│ │ │ ├── DPO_tutorial.md
│ │ │ ├── KTransformers-Fine-Tuning_Developer-Technical-Notes.md
│ │ │ ├── KTransformers-Fine-Tuning_User-Guide.md
│ │ │ ├── README.md
│ │ │ └── injection_tutorial.md
│ │ ├── SFT_Installation_Guide_KimiK2.5.md
│ │ ├── SFT_Installation_Guide_KimiK2.md
│ │ ├── SmallThinker_and_Glm4moe.md
│ │ ├── V3-success.md
│ │ ├── api/
│ │ │ └── server/
│ │ │ ├── api.md
│ │ │ ├── server.md
│ │ │ ├── tabby.md
│ │ │ └── website.md
│ │ ├── balance-serve.md
│ │ ├── benchmark.md
│ │ ├── deepseek-v2-injection.md
│ │ ├── fp8_kernel.md
│ │ ├── install.md
│ │ ├── kt-kernel/
│ │ │ ├── GLM-5-Tutorial.md
│ │ │ ├── Kimi-K2-Thinking-Native.md
│ │ │ ├── MiniMax-M2.1-Tutorial.md
│ │ │ ├── Native-Precision-Tutorial.md
│ │ │ ├── Qwen3-Coder-Next-Tutorial.md
│ │ │ ├── README.md
│ │ │ ├── amd_blis.md
│ │ │ ├── deepseek-v3.2-sglang-tutorial.md
│ │ │ ├── experts-sched-Tutorial.md
│ │ │ └── kt-cli.md
│ │ ├── llama4.md
│ │ ├── long_context_introduction.md
│ │ ├── long_context_tutorial.md
│ │ ├── makefile_usage.md
│ │ ├── multi-gpu-tutorial.md
│ │ ├── operators/
│ │ │ └── llamafile.md
│ │ ├── prefix_cache.md
│ │ └── xpu.md
│ └── zh/
│ ├── DeepseekR1_V3_tutorial_zh.md
│ ├── DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md
│ ├── KTransformers-Fine-Tuning_Developer-Technical-Notes_zh.md
│ ├── KTransformers-Fine-Tuning_User-Guide_zh.md
│ ├── Qwen3-MoE_tutorial_zh_for_Ascend_NPU.md
│ ├── api/
│ │ └── server/
│ │ ├── api.md
│ │ ├── server.md
│ │ ├── tabby.md
│ │ └── website.md
│ └── clawdbot_integration_guide.md
├── docker/
│ ├── Dockerfile
│ ├── README-packaging.md
│ ├── docker-utils.sh
│ └── push-to-dockerhub.sh
├── install.sh
├── kt-kernel/
│ ├── .clang-format
│ ├── .githooks/
│ │ ├── commit-msg
│ │ └── pre-commit
│ ├── .gitignore
│ ├── .gitmodules
│ ├── CMakeLists.txt
│ ├── CMakePresets.json
│ ├── MANIFEST.in
│ ├── README.md
│ ├── README_zh.md
│ ├── bench/
│ │ ├── .gitignore
│ │ ├── Makefile
│ │ ├── bench_attention.py
│ │ ├── bench_attention_torch.py
│ │ ├── bench_bf16_moe.py
│ │ ├── bench_fp8_moe.py
│ │ ├── bench_fp8_perchannel_moe.py
│ │ ├── bench_k2_moe_amx.py
│ │ ├── bench_k2_write_buffer.py
│ │ ├── bench_linear.py
│ │ ├── bench_linear_torch.py
│ │ ├── bench_mla.py
│ │ ├── bench_mlp.py
│ │ ├── bench_mlp_torch.py
│ │ ├── bench_moe.py
│ │ ├── bench_moe_amx.py
│ │ ├── bench_moe_amx_k.py
│ │ ├── bench_moe_kernel.py
│ │ ├── bench_moe_kernel_tiling.py
│ │ ├── bench_moe_kml.py
│ │ ├── bench_moe_torch.py
│ │ ├── bench_write_buffer.py
│ │ ├── compare_moe_performance.py
│ │ ├── multi_bench_moe.py
│ │ └── upload-bench-json.py
│ ├── cmake/
│ │ ├── DetectCPU.cmake
│ │ └── FindSIMD.cmake
│ ├── cpu_backend/
│ │ ├── cpuinfer.h
│ │ ├── shared_mem_buffer.cpp
│ │ ├── shared_mem_buffer.h
│ │ ├── task_queue.cpp
│ │ ├── task_queue.h
│ │ ├── vendors/
│ │ │ ├── README.md
│ │ │ ├── cuda.h
│ │ │ ├── hip.h
│ │ │ ├── musa.h
│ │ │ └── vendor.h
│ │ ├── worker_pool.cpp
│ │ └── worker_pool.h
│ ├── cuda/
│ │ ├── binding.cpp
│ │ ├── custom_gguf/
│ │ │ ├── dequant.cu
│ │ │ └── ops.h
│ │ ├── gptq_marlin/
│ │ │ ├── gptq_marlin.cu
│ │ │ ├── gptq_marlin.cuh
│ │ │ ├── gptq_marlin_dtypes.cuh
│ │ │ └── ops.h
│ │ ├── moe/
│ │ │ ├── moe_topk_softmax_kernels.cu
│ │ │ ├── ops.h
│ │ │ └── utils.h
│ │ ├── setup.py
│ │ └── test_dequant.py
│ ├── demo/
│ │ ├── .gitignore
│ │ ├── Makefile
│ │ ├── bench_reorder_bandwidth.cpp
│ │ ├── bf16-test.cpp
│ │ ├── fp16-test.cpp
│ │ ├── plot.py
│ │ ├── simple_test.cpp
│ │ ├── simple_test_aocl.cpp
│ │ └── tflops.py
│ ├── examples/
│ │ ├── .gitignore
│ │ ├── bench_moe_amx_int8.py
│ │ ├── configuration_deepseek_v3.py
│ │ ├── modeling_deepseek_v3.py
│ │ ├── repro_llamafile_re.py
│ │ ├── test-debug.py
│ │ ├── test_apply_rope.py
│ │ ├── test_attention.py
│ │ ├── test_awq_moe_amx.py
│ │ ├── test_bf16_moe.py
│ │ ├── test_deepseekv3.py
│ │ ├── test_deepseekv3_prefill.py
│ │ ├── test_deepseekv3_prefill_speed.py
│ │ ├── test_fp8_moe.py
│ │ ├── test_fp8_perchannel_moe.py
│ │ ├── test_gate.py
│ │ ├── test_k2_moe_amx.py
│ │ ├── test_k2_write_buffer.py
│ │ ├── test_linear.py
│ │ ├── test_mla.py
│ │ ├── test_mla_qlen.py
│ │ ├── test_mla_quant.py
│ │ ├── test_mla_simple.py
│ │ ├── test_mla_torch.py
│ │ ├── test_mlp.py
│ │ ├── test_moe.py
│ │ ├── test_moe_amx.py
│ │ ├── test_moe_kernel.py
│ │ ├── test_moe_kml.py
│ │ ├── test_rope.cpp
│ │ ├── test_rope.py
│ │ ├── test_softmax.py
│ │ ├── test_write_buffer.py
│ │ └── torch_attention.py
│ ├── ext_bindings.cpp
│ ├── install.sh
│ ├── operators/
│ │ ├── amx/
│ │ │ ├── awq-moe.hpp
│ │ │ ├── bf16-moe.hpp
│ │ │ ├── fp8-moe.hpp
│ │ │ ├── fp8-perchannel-moe.hpp
│ │ │ ├── k2-moe.hpp
│ │ │ ├── la/
│ │ │ │ ├── amx-example.cpp
│ │ │ │ ├── amx.hpp
│ │ │ │ ├── amx_buffers.hpp
│ │ │ │ ├── amx_config.hpp
│ │ │ │ ├── amx_kernels.hpp
│ │ │ │ ├── amx_quantization.hpp
│ │ │ │ ├── amx_raw_buffers.hpp
│ │ │ │ ├── amx_raw_kernels.hpp
│ │ │ │ ├── amx_utils.hpp
│ │ │ │ ├── pack.hpp
│ │ │ │ └── utils.hpp
│ │ │ ├── moe.hpp
│ │ │ ├── moe_base.hpp
│ │ │ └── test/
│ │ │ ├── amx-bkgroup-test.cpp
│ │ │ ├── amx-c-reduce-test.cpp
│ │ │ ├── amx-kgroup-test.cpp
│ │ │ ├── amx-test.cpp
│ │ │ ├── analyze-error.cpp
│ │ │ ├── avx-test.cpp
│ │ │ ├── debug-kgroup-details.cpp
│ │ │ ├── debug-kgroup.cpp
│ │ │ ├── debug-specific-dims.cpp
│ │ │ ├── mat-test.hpp
│ │ │ ├── mmq-test.cpp
│ │ │ ├── mmq.cpp
│ │ │ ├── mmq.h
│ │ │ ├── test-kgroup-128.cpp
│ │ │ ├── test-kgroup-kernel.cpp
│ │ │ ├── test-specific-dims.cpp
│ │ │ ├── thread_test.sh
│ │ │ ├── timer.hh
│ │ │ └── verify-kgroup.cpp
│ │ ├── common.hpp
│ │ ├── kvcache/
│ │ │ ├── kvcache.h
│ │ │ ├── kvcache_attn.cpp
│ │ │ ├── kvcache_load_dump.cpp
│ │ │ ├── kvcache_read_write.cpp
│ │ │ └── kvcache_utils.cpp
│ │ ├── llamafile/
│ │ │ ├── conversion.h
│ │ │ ├── linear.cpp
│ │ │ ├── linear.h
│ │ │ ├── mla.hpp
│ │ │ ├── mlp.cpp
│ │ │ ├── mlp.h
│ │ │ └── moe.hpp
│ │ ├── mla-tp.hpp
│ │ ├── moe-tp.hpp
│ │ ├── moe_kernel/
│ │ │ ├── api/
│ │ │ │ ├── common.h
│ │ │ │ └── mat_kernel.h
│ │ │ ├── la/
│ │ │ │ ├── kernel.hpp
│ │ │ │ ├── mat_kernel.cpp
│ │ │ │ └── utils.hpp
│ │ │ ├── mat_kernel/
│ │ │ │ ├── aocl_kernel/
│ │ │ │ │ └── kernel.cpp
│ │ │ │ └── batch_gemm_api.hpp
│ │ │ ├── moe.hpp
│ │ │ └── test/
│ │ │ ├── convert-test.cpp
│ │ │ ├── debug.hpp
│ │ │ ├── int4_mul-test.cpp
│ │ │ ├── mat_test.cpp
│ │ │ └── utils_test.cpp
│ │ ├── reduce.hpp
│ │ ├── rms-norm.hpp
│ │ ├── rope.hpp
│ │ ├── softmax.hpp
│ │ └── tp.hpp
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── python/
│ │ ├── __init__.py
│ │ ├── _cpu_detect.py
│ │ ├── cli/
│ │ │ ├── __init__.py
│ │ │ ├── commands/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── bench.py
│ │ │ │ ├── chat.py
│ │ │ │ ├── config.py
│ │ │ │ ├── doctor.py
│ │ │ │ ├── model.py
│ │ │ │ ├── quant.py
│ │ │ │ ├── run.py
│ │ │ │ ├── sft.py
│ │ │ │ └── version.py
│ │ │ ├── completions/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── _kt
│ │ │ │ ├── kt-completion.bash
│ │ │ │ └── kt.fish
│ │ │ ├── config/
│ │ │ │ ├── __init__.py
│ │ │ │ └── settings.py
│ │ │ ├── i18n.py
│ │ │ ├── main.py
│ │ │ ├── requirements/
│ │ │ │ ├── inference.txt
│ │ │ │ └── sft.txt
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── analyze_moe_model.py
│ │ │ ├── console.py
│ │ │ ├── debug_configs.py
│ │ │ ├── download_helper.py
│ │ │ ├── environment.py
│ │ │ ├── input_validators.py
│ │ │ ├── kv_cache_calculator.py
│ │ │ ├── model_discovery.py
│ │ │ ├── model_registry.py
│ │ │ ├── model_scanner.py
│ │ │ ├── model_table_builder.py
│ │ │ ├── model_verifier.py
│ │ │ ├── port_checker.py
│ │ │ ├── quant_interactive.py
│ │ │ ├── repo_detector.py
│ │ │ ├── run_configs.py
│ │ │ ├── run_interactive.py
│ │ │ ├── sglang_checker.py
│ │ │ ├── tuna_engine.py
│ │ │ └── user_model_registry.py
│ │ ├── experts.py
│ │ ├── experts_base.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── amx.py
│ │ ├── llamafile.py
│ │ ├── loader.py
│ │ └── moe_kernel.py
│ ├── requirements.txt
│ ├── scripts/
│ │ ├── README.md
│ │ ├── check.py
│ │ ├── check_cpu_features.py
│ │ ├── compare_weights.py
│ │ ├── convert_cpu_weights.py
│ │ ├── convert_gpu_weights.py
│ │ ├── convert_kimi_k2_fp8_to_bf16_cpu.py
│ │ ├── convert_moe_to_bf16.py
│ │ └── install-git-hooks.sh
│ ├── setup.py
│ └── test/
│ ├── __init__.py
│ ├── ci/
│ │ ├── __init__.py
│ │ ├── ci_register.py
│ │ └── ci_utils.py
│ ├── per_commit/
│ │ ├── __init__.py
│ │ ├── test_amd_placeholder.py
│ │ ├── test_basic_cpu.py
│ │ ├── test_cuda_placeholder.py
│ │ ├── test_moe_amx_accuracy_int4.py
│ │ ├── test_moe_amx_accuracy_int4_1.py
│ │ ├── test_moe_amx_accuracy_int4_1k.py
│ │ ├── test_moe_amx_accuracy_int8.py
│ │ ├── test_moe_amx_bench_int4.py
│ │ ├── test_moe_amx_bench_int4_1.py
│ │ ├── test_moe_amx_bench_int4_1k.py
│ │ └── test_moe_amx_bench_int8.py
│ ├── run_suite.py
│ └── test_generate_gpu_experts_masks.py
├── kt-sft/
│ ├── .flake8
│ ├── .gitignore
│ ├── .gitmodules
│ ├── .pylintrc
│ ├── Dockerfile
│ ├── Dockerfile.xpu
│ ├── LICENSE
│ ├── MANIFEST.in
│ ├── Makefile
│ ├── README.md
│ ├── SECURITY.md
│ ├── autosetup.sh
│ ├── book.toml
│ ├── csrc/
│ │ ├── custom_marlin/
│ │ │ ├── __init__.py
│ │ │ ├── binding.cpp
│ │ │ ├── gptq_marlin/
│ │ │ │ ├── gptq_marlin.cu
│ │ │ │ ├── gptq_marlin.cuh
│ │ │ │ ├── gptq_marlin_dtypes.cuh
│ │ │ │ ├── gptq_marlin_repack.cu
│ │ │ │ └── ops.h
│ │ │ ├── setup.py
│ │ │ ├── test_cuda_graph.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── format24.py
│ │ │ ├── marlin_24_perms.py
│ │ │ ├── marlin_perms.py
│ │ │ ├── marlin_utils.py
│ │ │ └── quant_utils.py
│ │ └── ktransformers_ext/
│ │ ├── CMakeLists.txt
│ │ ├── bench/
│ │ │ ├── bench_attention.py
│ │ │ ├── bench_attention_torch.py
│ │ │ ├── bench_linear.py
│ │ │ ├── bench_linear_torch.py
│ │ │ ├── bench_mlp.py
│ │ │ ├── bench_mlp_torch.py
│ │ │ ├── bench_moe.py
│ │ │ ├── bench_moe_amx.py
│ │ │ └── bench_moe_torch.py
│ │ ├── cmake/
│ │ │ └── FindSIMD.cmake
│ │ ├── cpu_backend/
│ │ │ ├── backend.cpp
│ │ │ ├── backend.h
│ │ │ ├── cpuinfer.h
│ │ │ ├── shared_mem_buffer.cpp
│ │ │ ├── shared_mem_buffer.h
│ │ │ ├── task_queue.cpp
│ │ │ ├── task_queue.h
│ │ │ └── vendors/
│ │ │ ├── README.md
│ │ │ ├── cuda.h
│ │ │ ├── hip.h
│ │ │ ├── musa.h
│ │ │ └── vendor.h
│ │ ├── cuda/
│ │ │ ├── binding.cpp
│ │ │ ├── custom_gguf/
│ │ │ │ ├── dequant.cu
│ │ │ │ └── ops.h
│ │ │ ├── gptq_marlin/
│ │ │ │ ├── gptq_marlin.cu
│ │ │ │ ├── gptq_marlin.cuh
│ │ │ │ ├── gptq_marlin_dtypes.cuh
│ │ │ │ └── ops.h
│ │ │ ├── setup.py
│ │ │ └── test_dequant.py
│ │ ├── examples/
│ │ │ ├── test_attention.py
│ │ │ ├── test_linear.py
│ │ │ ├── test_mlp.py
│ │ │ ├── test_moe.py
│ │ │ ├── test_sft_amx_moe.py
│ │ │ └── test_sft_moe.py
│ │ ├── ext_bindings.cpp
│ │ ├── operators/
│ │ │ ├── amx/
│ │ │ │ ├── debug_sft_moe.hpp
│ │ │ │ ├── debug_tools_sft_moe.hpp
│ │ │ │ ├── la/
│ │ │ │ │ ├── amx.hpp
│ │ │ │ │ └── utils.hpp
│ │ │ │ ├── moe.hpp
│ │ │ │ └── sft_moe.hpp
│ │ │ ├── kvcache/
│ │ │ │ ├── kvcache.h
│ │ │ │ ├── kvcache_attn.cpp
│ │ │ │ ├── kvcache_load_dump.cpp
│ │ │ │ ├── kvcache_read_write.cpp
│ │ │ │ └── kvcache_utils.cpp
│ │ │ └── llamafile/
│ │ │ ├── conversion.h
│ │ │ ├── linear.cpp
│ │ │ ├── linear.h
│ │ │ ├── mlp.cpp
│ │ │ ├── mlp.h
│ │ │ ├── moe.cpp
│ │ │ ├── moe.h
│ │ │ ├── sft_moe.cpp
│ │ │ ├── sft_moe.h
│ │ │ └── sft_moe_forward_cache.h
│ │ └── vendors/
│ │ ├── cuda.h
│ │ ├── hip.h
│ │ ├── musa.h
│ │ └── vendor.h
│ ├── install-with-cache.sh
│ ├── install.bat
│ ├── install.sh
│ ├── ktransformers/
│ │ ├── __init__.py
│ │ ├── configs/
│ │ │ ├── config.yaml
│ │ │ ├── log_config.ini
│ │ │ └── model_config/
│ │ │ ├── config.json
│ │ │ └── configuration_deepseek.py
│ │ ├── ktransformers_ext/
│ │ │ ├── operators/
│ │ │ │ └── custom_marlin/
│ │ │ │ └── quantize/
│ │ │ │ └── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── format_24.py
│ │ │ │ ├── marlin_24_perms.py
│ │ │ │ ├── marlin_perms.py
│ │ │ │ ├── marlin_utils.py
│ │ │ │ └── quant_utils.py
│ │ │ └── triton/
│ │ │ └── fp8gemm.py
│ │ ├── local_chat.py
│ │ ├── local_chat.sh
│ │ ├── lora_test_module.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── configuration_deepseek.py
│ │ │ ├── configuration_deepseek_v3.py
│ │ │ ├── configuration_llama.py
│ │ │ ├── configuration_qwen2_moe.py
│ │ │ ├── configuration_qwen3_moe.py
│ │ │ ├── custom_cache.py
│ │ │ ├── custom_modeling_deepseek_v2.py
│ │ │ ├── custom_modeling_deepseek_v3.py
│ │ │ ├── custom_modeling_qwen2_moe.py
│ │ │ ├── custom_modeling_qwen3_moe.py
│ │ │ ├── modeling_deepseek.py
│ │ │ ├── modeling_deepseek_v3.py
│ │ │ ├── modeling_llama.py
│ │ │ ├── modeling_mixtral.py
│ │ │ ├── modeling_qwen2_moe.py
│ │ │ └── modeling_qwen3_moe.py
│ │ ├── moe_test_module.py
│ │ ├── moe_test_module_old.py
│ │ ├── operators/
│ │ │ ├── RoPE.py
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── balance_serve_attention.py
│ │ │ ├── base_operator.py
│ │ │ ├── cpuinfer.py
│ │ │ ├── dynamic_attention.py
│ │ │ ├── experts.py
│ │ │ ├── flashinfer_batch_prefill_wrapper.py
│ │ │ ├── flashinfer_wrapper.py
│ │ │ ├── gate.py
│ │ │ ├── layernorm.py
│ │ │ ├── linear.py
│ │ │ ├── mlp.py
│ │ │ ├── models.py
│ │ │ ├── triton_attention.py
│ │ │ └── triton_attention_prefill.py
│ │ ├── optimize/
│ │ │ ├── optimize.py
│ │ │ └── optimize_rules/
│ │ │ ├── DeepSeek-V2-Chat-multi-gpu-4.yaml
│ │ │ ├── DeepSeek-V2-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V2-Chat-sft-amx.yaml
│ │ │ ├── DeepSeek-V2-Chat.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-sft-amx-multi-gpu.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-sft-amx.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-sft.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat-use-adapter.yaml
│ │ │ ├── DeepSeek-V2-Lite-Chat.yaml
│ │ │ ├── DeepSeek-V3-Chat-amx.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml
│ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-4.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-8.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu-marlin.yaml
│ │ │ ├── DeepSeek-V3-Chat-multi-gpu.yaml
│ │ │ ├── DeepSeek-V3-Chat-serve.yaml
│ │ │ ├── DeepSeek-V3-Chat-sft-amx-multi-gpu-4.yaml
│ │ │ ├── DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
│ │ │ ├── DeepSeek-V3-Chat-sft-amx.yaml
│ │ │ ├── DeepSeek-V3-Chat.yaml
│ │ │ ├── Internlm2_5-7b-Chat-1m.yaml
│ │ │ ├── Mixtral.yaml
│ │ │ ├── Moonlight-16B-A3B-serve.yaml
│ │ │ ├── Moonlight-16B-A3B.yaml
│ │ │ ├── Qwen2-57B-A14B-Instruct-multi-gpu.yaml
│ │ │ ├── Qwen2-57B-A14B-Instruct.yaml
│ │ │ ├── Qwen2-serve-amx.yaml
│ │ │ ├── Qwen2-serve.yaml
│ │ │ ├── Qwen3Moe-serve-amx.yaml
│ │ │ ├── Qwen3Moe-serve.yaml
│ │ │ ├── Qwen3Moe-sft-amx.yaml
│ │ │ ├── rocm/
│ │ │ │ └── DeepSeek-V3-Chat.yaml
│ │ │ └── xpu/
│ │ │ ├── DeepSeek-V2-Chat.yaml
│ │ │ ├── DeepSeek-V3-Chat.yaml
│ │ │ └── Qwen3Moe-Chat.yaml
│ │ ├── server/
│ │ │ ├── __init__.py
│ │ │ ├── api/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── ollama/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── completions.py
│ │ │ │ ├── openai/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assistants/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── assistants.py
│ │ │ │ │ │ ├── messages.py
│ │ │ │ │ │ ├── runs.py
│ │ │ │ │ │ └── threads.py
│ │ │ │ │ ├── endpoints/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── chat.py
│ │ │ │ │ └── legacy/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── completions.py
│ │ │ │ └── web/
│ │ │ │ ├── __init__.py
│ │ │ │ └── system.py
│ │ │ ├── args.py
│ │ │ ├── backend/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── args.py
│ │ │ │ ├── base.py
│ │ │ │ ├── context_manager.py
│ │ │ │ └── interfaces/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── balance_serve.py
│ │ │ │ ├── exllamav2.py
│ │ │ │ ├── ktransformers.py
│ │ │ │ └── transformers.py
│ │ │ ├── balance_serve/
│ │ │ │ ├── inference/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── config.py
│ │ │ │ │ ├── distributed/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── communication_op.py
│ │ │ │ │ │ ├── cuda_wrapper.py
│ │ │ │ │ │ ├── custom_all_reduce.py
│ │ │ │ │ │ ├── custom_all_reduce_utils.py
│ │ │ │ │ │ ├── parallel_state.py
│ │ │ │ │ │ ├── pynccl.py
│ │ │ │ │ │ ├── pynccl_wrapper.py
│ │ │ │ │ │ └── utils.py
│ │ │ │ │ ├── forward_batch.py
│ │ │ │ │ ├── model_runner.py
│ │ │ │ │ ├── query_manager.py
│ │ │ │ │ └── sampling/
│ │ │ │ │ ├── penaltylib/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── orchestrator.py
│ │ │ │ │ │ └── penalizers/
│ │ │ │ │ │ ├── frequency_penalty.py
│ │ │ │ │ │ ├── min_new_tokens.py
│ │ │ │ │ │ ├── presence_penalty.py
│ │ │ │ │ │ └── repetition_penalty.py
│ │ │ │ │ └── sampler.py
│ │ │ │ ├── sched_rpc.py
│ │ │ │ └── settings.py
│ │ │ ├── config/
│ │ │ │ ├── config.py
│ │ │ │ ├── log.py
│ │ │ │ └── singleton.py
│ │ │ ├── crud/
│ │ │ │ ├── __init__.py
│ │ │ │ └── assistants/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants.py
│ │ │ │ ├── messages.py
│ │ │ │ ├── runs.py
│ │ │ │ └── threads.py
│ │ │ ├── exceptions.py
│ │ │ ├── main.py
│ │ │ ├── models/
│ │ │ │ ├── __init__.py
│ │ │ │ └── assistants/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants.py
│ │ │ │ ├── messages.py
│ │ │ │ ├── run_steps.py
│ │ │ │ ├── runs.py
│ │ │ │ └── threads.py
│ │ │ ├── schemas/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── assistants/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── assistants.py
│ │ │ │ │ ├── messages.py
│ │ │ │ │ ├── runs.py
│ │ │ │ │ ├── streaming.py
│ │ │ │ │ ├── threads.py
│ │ │ │ │ └── tool.py
│ │ │ │ ├── base.py
│ │ │ │ ├── conversation.py
│ │ │ │ ├── endpoints/
│ │ │ │ │ └── chat.py
│ │ │ │ └── legacy/
│ │ │ │ ├── __init__.py
│ │ │ │ └── completions.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── create_interface.py
│ │ │ ├── multi_timer.py
│ │ │ └── sql_utils.py
│ │ ├── sft/
│ │ │ ├── __init__.py
│ │ │ ├── flops_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── custom_profile.py
│ │ │ │ └── lora_test_utils.py
│ │ │ ├── lora.py
│ │ │ ├── metrics.py
│ │ │ ├── metrics_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── constants.py
│ │ │ │ ├── env.py
│ │ │ │ ├── logging.py
│ │ │ │ ├── misc.py
│ │ │ │ ├── packages.py
│ │ │ │ └── ploting.py
│ │ │ ├── monkey_patch_torch_module.py
│ │ │ ├── peft_utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── lora_layer.py
│ │ │ │ ├── lora_model.py
│ │ │ │ ├── mapping.py
│ │ │ │ └── peft_model.py
│ │ │ └── torchviz_test.py
│ │ ├── tests/
│ │ │ ├── .gitignore
│ │ │ ├── AIME_2024/
│ │ │ │ ├── eval_api.py
│ │ │ │ ├── evaluation.py
│ │ │ │ └── prompts.py
│ │ │ ├── dequant_gpu.py
│ │ │ ├── dequant_gpu_t.py
│ │ │ ├── function_call_test.py
│ │ │ ├── humaneval/
│ │ │ │ ├── eval_api.py
│ │ │ │ ├── evaluation.py
│ │ │ │ └── prompts.py
│ │ │ ├── mmlu_pro_test.py
│ │ │ ├── mmlu_test.py
│ │ │ ├── mmlu_test_multi.py
│ │ │ ├── score.py
│ │ │ ├── test_client.py
│ │ │ ├── test_pytorch_q8.py
│ │ │ ├── test_speed.py
│ │ │ └── triton_fp8gemm_test.py
│ │ ├── util/
│ │ │ ├── cuda_graph_runner.py
│ │ │ ├── custom_gguf.py
│ │ │ ├── custom_loader.py
│ │ │ ├── globals.py
│ │ │ ├── grad_wrapper.py
│ │ │ ├── inference_state.py
│ │ │ ├── modeling_rope_utils.py
│ │ │ ├── textstream.py
│ │ │ ├── utils.py
│ │ │ ├── vendors.py
│ │ │ └── weight_loader.py
│ │ └── website/
│ │ ├── .browserslistrc
│ │ ├── .eslintrc.js
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── config.d.ts
│ │ ├── jest.config.js
│ │ ├── package.json
│ │ ├── public/
│ │ │ ├── config.js
│ │ │ ├── css/
│ │ │ │ └── reset.css
│ │ │ └── index.html
│ │ ├── src/
│ │ │ ├── App.vue
│ │ │ ├── api/
│ │ │ │ ├── api-client.ts
│ │ │ │ ├── assistant.ts
│ │ │ │ ├── message.ts
│ │ │ │ ├── run.ts
│ │ │ │ └── thread.ts
│ │ │ ├── assets/
│ │ │ │ ├── css/
│ │ │ │ │ └── mixins.styl
│ │ │ │ └── iconfont/
│ │ │ │ ├── demo.css
│ │ │ │ ├── demo_index.html
│ │ │ │ ├── iconfont.css
│ │ │ │ ├── iconfont.js
│ │ │ │ └── iconfont.json
│ │ │ ├── components/
│ │ │ │ └── chat/
│ │ │ │ └── index.vue
│ │ │ ├── conf/
│ │ │ │ └── config.ts
│ │ │ ├── locals/
│ │ │ │ ├── en.js
│ │ │ │ ├── index.js
│ │ │ │ └── zh.js
│ │ │ ├── main.ts
│ │ │ ├── router/
│ │ │ │ └── index.ts
│ │ │ ├── shims-vue.d.ts
│ │ │ ├── store/
│ │ │ │ └── index.ts
│ │ │ ├── utils/
│ │ │ │ ├── copy.ts
│ │ │ │ └── types.ts
│ │ │ └── views/
│ │ │ └── home.vue
│ │ ├── tests/
│ │ │ └── unit/
│ │ │ └── example.spec.ts
│ │ ├── tsconfig.json
│ │ └── vue.config.js
│ ├── merge_tensors/
│ │ └── merge_safetensor_gguf.py
│ ├── pyproject.toml
│ ├── requirements-sft.txt
│ ├── setup.py
│ ├── test_adapter/
│ │ ├── data_transfer.py
│ │ ├── infer_with_adapter.py
│ │ ├── inspect_adapter.py
│ │ ├── pred2metrics.py
│ │ ├── test_grad.py
│ │ └── time_test_lora_train.py
│ └── withoutKT_PEFT.py
├── pyproject.toml
├── setup.py
├── third_party/
│ └── llamafile/
│ ├── README.md
│ ├── bench.h
│ ├── flags.cpp
│ ├── flags.h
│ ├── iqk_mul_mat.inc
│ ├── iqk_mul_mat_amd_avx2.cpp
│ ├── iqk_mul_mat_amd_zen4.cpp
│ ├── iqk_mul_mat_arm.inc
│ ├── iqk_mul_mat_arm82.cpp
│ ├── macros.h
│ ├── micros.h
│ ├── numba.h
│ ├── sgemm.cpp
│ ├── sgemm.h
│ ├── tinyblas_cpu.h
│ ├── tinyblas_cpu_mixmul.inc
│ ├── tinyblas_cpu_mixmul_amd_avx.cpp
│ ├── tinyblas_cpu_mixmul_amd_avx2.cpp
│ ├── tinyblas_cpu_mixmul_amd_avx512f.cpp
│ ├── tinyblas_cpu_mixmul_amd_avxvnni.cpp
│ ├── tinyblas_cpu_mixmul_amd_fma.cpp
│ ├── tinyblas_cpu_mixmul_amd_zen4.cpp
│ ├── tinyblas_cpu_mixmul_arm80.cpp
│ ├── tinyblas_cpu_mixmul_arm82.cpp
│ ├── tinyblas_cpu_sgemm.inc
│ ├── tinyblas_cpu_sgemm_amd_avx.cpp
│ ├── tinyblas_cpu_sgemm_amd_avx2.cpp
│ ├── tinyblas_cpu_sgemm_amd_avx512f.cpp
│ ├── tinyblas_cpu_sgemm_amd_avxvnni.cpp
│ ├── tinyblas_cpu_sgemm_amd_fma.cpp
│ ├── tinyblas_cpu_sgemm_amd_zen4.cpp
│ ├── tinyblas_cpu_sgemm_arm80.cpp
│ ├── tinyblas_cpu_sgemm_arm82.cpp
│ └── tinyblas_cpu_unsupported.cpp
└── version.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of actions.
**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0].
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
================================================
FILE: .github/CONTRIBUTING.md
================================================
## Before Commit!
Your commit message must follow Conventional Commits (https://www.conventionalcommits.org/) and your code should be formatted. The Git hooks will do most of the work automatically:
### Tool Requirements
You need a recent `clang-format` (>= 18). In a conda environment you can install:
```shell
conda install -c conda-forge clang-format=18
```
If you previously configured with an older version, remove the build directory and reconfigure:
```shell
rm -rf kt-kernel/build
```
Install `black` for Python formatting:
```shell
conda install black
```
### Install hook:
```shell
bash kt-kernel/scripts/install-git-hooks.sh
# or just configure kt-kernel with CMake
cmake -S kt-kernel -B kt-kernel/build
```
To run the formatter manually, use:
```shell
cmake -S kt-kernel -B kt-kernel/build
cmake --build kt-kernel/build --target format
```
## Developer Note
Formatting and commit message rules are enforced by Git hooks. After installing `clang-format` and `black`, just commit normally—the hooks will run formatting for you.
> [!NOTE]
> If formatting modifies files, the commit is aborted after staging those changes. Review them and run `git commit` again. Repeat until no further formatting changes appear.
---
### Conventional Commit Regex (Reference)
The commit-msg hook enforces this pattern:
```text
regex='^\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\](\([^\)]+\))?(!)?: .+'
```
Meaning (English):
* `[type]` required — one of feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip
* Optional scope: `(scope)` — any chars except `)`
* Optional breaking change marker: `!` right after type or scope
* Separator: `: ` (colon + space)
* Subject: free text (at least one character)
Examples:
```text
[feat]: add adaptive batching
[fix](parser): handle empty token list
[docs]!: update API section for breaking rename
```
You can bypass locally (not recommended) with:
```shell
git commit --no-verify
```
## 提交前提醒
提交信息必须满足 Conventional Commits 规范 (https://www.conventionalcommits.org/),代码需要符合格式要求。Git 钩子已经集成了大部分工作:
### 软件要求
需要较新的 `clang-format` (>= 18),在 conda 环境中安装:
```shell
conda install -c conda-forge clang-format=18
```
如果之前用老版本配置过,请删除构建目录重新配置:
```shell
rm -rf kt-kernel/build
```
安装 `black` 以进行 Python 文件格式化:
```shell
conda install black
```
### 安装钩子
```shell
bash kt-kernel/scripts/install-git-hooks.sh
# or just configure kt-kernel with CMake
cmake -S kt-kernel -B kt-kernel/build
```
如果你需要手动格式化:
```shell
cmake -S kt-kernel -B kt-kernel/build
cmake --build kt-kernel/build --target format
```
## 开发者说明
本仓库通过 Git hooks 自动执行代码格式化与提交信息规范检查。只需安装好 `clang-format` 与 `black` 后正常执行提交即可,钩子会自动格式化。
> [!NOTE]
> 如果格式化修改了文件,钩子会终止提交并已暂存这些改动。请查看修改后再次执行 `git commit`,重复直到没有新的格式化变更。
### 提交信息正则(参考)
钩子使用如下正则检查提交信息:
```text
regex='^\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\](\([^\)]+\))?(!)?: .+'
```
含义:
* `[type]` 必填:feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip
* 作用域可选:`(scope)`,不能包含右括号
* 可选的破坏性标记:`!`
* 分隔符:冒号+空格 `: `
* 描述:至少一个字符
示例:
```text
[feat]: 增加自适应 batch 功能
[fix](tokenizer): 修复空 token 列表处理
[docs]!: 更新接口文档(存在破坏性修改)
```
跳过钩子(不推荐,仅紧急时):
```shell
git commit --no-verify
```
================================================
FILE: .github/ISSUE_TEMPLATE/-bug-.yaml
================================================
name: "\U0001F41B Bug / Help"
description: Create a report to help us improve the ktransformers project
labels: ["pending"]
body:
- type: markdown
attributes:
value: |
Issues included in **[FAQs](https://github.com/kvcache-ai/ktransformers/issues/1608)** or those with **insufficient** information may be closed without a response.
已经包含在 **[常见问题](https://github.com/kvcache-ai/ktransformers/issues/1608)** 内或提供信息**不完整**的 issues 可能不会被回复。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **lscpu**, **nvidia-smi** etc. and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **lscpu**, **nvidia-smi** 等命令,并将其输出复制到该文本框中。
placeholder: ktransformers version, sglang version, platform, python version, cpu info, GPU/NPU info ...
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide entry arguments, error messages and stack traces that reproduce the problem.
请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
value: |
```text
Put your message here.
```
- type: textarea
id: others
validations:
required: false
attributes:
label: Others
================================================
FILE: .github/ISSUE_TEMPLATE/-feature-.yaml
================================================
name: "\U0001F680 Feature request"
description: Submit a request for a new feature
labels: ["enhancement", "pending"]
body:
- type: markdown
attributes:
value: |
Please do not create issues that are not related to new features under this category.
请勿在此分类下创建和新特性无关的 issues。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues.
请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: description
validations:
required: true
attributes:
label: Description
description: |
A clear and concise description of the feature proposal.
请详细描述您希望加入的新功能特性。
- type: textarea
id: contribution
validations:
required: false
attributes:
label: Pull Request
description: |
Have you already created the relevant PR and submitted the code?
您是否已经创建了相关 PR 并提交了代码?
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false
contact_links:
- name: 📚 FAQs | 常见问题
url: https://github.com/kvcache-ai/ktransformers/issues/1608
about: Reading in advance is recommended | 建议提前阅读
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
# What does this PR do?
Fixes # (issue)
## Before submitting
- [ ] Did you read the [contributor guideline](https://github.com/kvcache-ai/ktransformers/blob/main/.github/CONTRIBUTING.md)?
- [ ] Did you write any new necessary tests?
================================================
FILE: .github/SECURITY.md
================================================
# Reporting Security Issues
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/kvcache-ai/ktransformers/security/advisories/new) tab.
We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
Report security bugs in third-party modules to the person or team maintaining the module.
================================================
FILE: .github/workflows/book-ci.yml
================================================
name: Book-CI
on:
push:
branches:
- main
# - server_support
pull_request:
branches:
- main
# - server_support
jobs:
test:
name: test
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Setup mdBook
uses: peaceiris/actions-mdbook@v2
with:
mdbook-version: "latest"
# - name: Run tests
# run: mdbook test
================================================
FILE: .github/workflows/deploy.yml
================================================
name: Deploy
on:
push:
branches:
- main
# - server_support
pull_request:
branches:
- main
# - server_support
defaults:
run:
shell: bash
permissions:
contents: write
jobs:
deploy:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Setup mdBook
uses: peaceiris/actions-mdbook@v2
with:
mdbook-version: "latest"
- run: mdbook build
# - name: Copy Assets
# run: |
# chmod +x ci/copy-assets.sh
# ci/copy-assets.sh ${{ matrix.os }}
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
# or || github.ref == 'refs/heads/server_support'
if: ${{ github.ref == 'refs/heads/main' }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./book
================================================
FILE: .github/workflows/docker-image.yml
================================================
name: DockerHub CI
on:
release:
types: [published]
workflow_dispatch:
inputs:
push_to_dockerhub:
description: 'Push image to DockerHub? (true/false)'
required: true
default: 'false'
type: boolean
cuda_version:
description: 'CUDA version (e.g., 12.8.1)'
required: false
default: '12.8.1'
type: string
push_simplified_tag:
description: 'Also push simplified tag? (true/false)'
required: false
default: 'true'
type: boolean
ubuntu_mirror:
description: 'Use Tsinghua Ubuntu mirror? (0/1)'
required: false
default: '0'
type: string
# push:
# branches:
# - main
env:
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Run tests
run: |
if [ -f docker-compose.test.yml ]; then
docker-compose --file docker-compose.test.yml build
docker-compose --file docker-compose.test.yml run sut
else
docker build . --file docker/Dockerfile
fi
build-and-push:
needs: test
name: Build and Push Multi-Variant Docker Image
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Move Docker data directory
run: |
sudo systemctl stop docker
sudo mkdir -p /mnt/docker
sudo rsync -avz /var/lib/docker/ /mnt/docker
sudo rm -rf /var/lib/docker
sudo ln -s /mnt/docker /var/lib/docker
sudo systemctl start docker
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Determine build parameters
id: params
run: |
# Determine if we should push
if [ "${{ github.event_name }}" = "release" ]; then
echo "should_push=true" >> $GITHUB_OUTPUT
echo "push_simplified=true" >> $GITHUB_OUTPUT
elif [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
echo "should_push=${{ inputs.push_to_dockerhub }}" >> $GITHUB_OUTPUT
echo "push_simplified=${{ inputs.push_simplified_tag }}" >> $GITHUB_OUTPUT
else
echo "should_push=false" >> $GITHUB_OUTPUT
echo "push_simplified=false" >> $GITHUB_OUTPUT
fi
# Determine CUDA version
if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ -n "${{ inputs.cuda_version }}" ]; then
echo "cuda_version=${{ inputs.cuda_version }}" >> $GITHUB_OUTPUT
else
echo "cuda_version=12.8.1" >> $GITHUB_OUTPUT
fi
# Determine Ubuntu mirror setting
if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ -n "${{ inputs.ubuntu_mirror }}" ]; then
echo "ubuntu_mirror=${{ inputs.ubuntu_mirror }}" >> $GITHUB_OUTPUT
else
echo "ubuntu_mirror=0" >> $GITHUB_OUTPUT
fi
- name: Build and push Docker image
run: |
cd docker
# Build command arguments
BUILD_ARGS=(
--cuda-version "${{ steps.params.outputs.cuda_version }}"
--ubuntu-mirror "${{ steps.params.outputs.ubuntu_mirror }}"
--repository "${{ env.DOCKERHUB_REPO }}"
)
# Add simplified tag option if enabled
if [ "${{ steps.params.outputs.push_simplified }}" = "true" ]; then
BUILD_ARGS+=(--also-push-simplified)
fi
# Add HTTP proxy if available
if [ -n "${{ secrets.HTTP_PROXY }}" ]; then
BUILD_ARGS+=(--http-proxy "${{ secrets.HTTP_PROXY }}")
fi
# Add HTTPS proxy if available
if [ -n "${{ secrets.HTTPS_PROXY }}" ]; then
BUILD_ARGS+=(--https-proxy "${{ secrets.HTTPS_PROXY }}")
fi
# Dry run if not pushing
if [ "${{ steps.params.outputs.should_push }}" != "true" ]; then
BUILD_ARGS+=(--dry-run)
fi
# Execute build script
./push-to-dockerhub.sh "${BUILD_ARGS[@]}"
- name: Display image information
if: steps.params.outputs.should_push == 'true'
run: |
echo "::notice title=Docker Image::Image pushed successfully to ${{ env.DOCKERHUB_REPO }}"
echo "Pull command: docker pull ${{ env.DOCKERHUB_REPO }}:v\$(VERSION)-cu\$(CUDA_SHORT)"
================================================
FILE: .github/workflows/kt-kernel-tests.yml
================================================
name: PR KT-Kernel Test
on:
pull_request:
branches:
- main
- develop
types: [synchronize, labeled]
workflow_dispatch:
concurrency:
group: pr-kt-kernel-test-${{ github.ref }}
cancel-in-progress: true
jobs:
# =============================================== check changes ====================================================
check-changes:
runs-on: ubuntu-latest
outputs:
kt_kernel: ${{ steps.filter.outputs.kt_kernel }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Fail if the PR does not have the 'run-ci' label
if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'run-ci')
run: |
echo "This pull request does not have the 'run-ci' label. Failing the workflow."
exit 1
- name: Fail if the PR is a draft
if: github.event_name == 'pull_request' && github.event.pull_request.draft == true
run: |
echo "This pull request is a draft. Failing the workflow."
exit 1
- name: Detect file changes
id: filter
uses: dorny/paths-filter@v3
with:
filters: |
kt_kernel:
- "kt-kernel/**"
- ".github/workflows/kt-kernel-tests.yml"
# =============================================== KT-Kernel tests ====================================================
per-commit-kt-kernel-cpu:
needs: [check-changes]
if: always() && !failure() && !cancelled() &&
(needs.check-changes.outputs.kt_kernel == 'true' || github.event_name == 'workflow_dispatch')
runs-on: kt-cpu
continue-on-error: false
steps:
- name: Cleanup
run: |
sudo rm -rf $GITHUB_WORKSPACE/* || true
- name: Checkout code
uses: actions/checkout@v4
with:
submodules: recursive
- name: Install KT-Kernel
run: |
cd kt-kernel
bash install.sh build
- name: Run KT-Kernel CPU tests
timeout-minutes: 60
run: |
cd kt-kernel/test
python3 run_suite.py --hw cpu --suite default
# =============================================== finish ====================================================
pr-test-kt-kernel-finish:
needs: [check-changes, per-commit-kt-kernel-cpu]
if: always()
runs-on: ubuntu-latest
steps:
- name: Check all dependent job statuses
run: |
# Convert the 'needs' context to a JSON string
json_needs='${{ toJson(needs) }}'
# Get a list of all job names from the JSON keys
job_names=$(echo "$json_needs" | jq -r 'keys_unsorted[]')
for job in $job_names; do
# For each job, extract its result
result=$(echo "$json_needs" | jq -r --arg j "$job" '.[$j].result')
# Print the job name and its result
echo "$job: $result"
# Check for failure or cancellation and exit if found
if [[ "$result" == "failure" || "$result" == "cancelled" ]]; then
echo "The above jobs failed."
exit 1
fi
done
# If the loop completes, all jobs were successful
echo "All jobs completed successfully"
exit 0
================================================
FILE: .github/workflows/release-fake-tag.yml
================================================
name: Release Fake Tag
on:
push:
branches:
- main
paths:
- "version.py"
workflow_dispatch:
permissions:
contents: write
jobs:
publish:
if: github.repository == 'kvcache-ai/ktransformers'
runs-on: ubuntu-latest
environment: 'prod'
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Get version
id: get_version
run: |
version=$(cat version.py | grep '__version__' | cut -d'"' -f2)
echo "TAG=v$version" >> $GITHUB_OUTPUT
- name: Create and push tag
run: |
git config user.name "ktransformers-bot"
git config user.email "ktransformers-bot@users.noreply.github.com"
git tag ${{ steps.get_version.outputs.TAG }}
git push origin ${{ steps.get_version.outputs.TAG }}
================================================
FILE: .github/workflows/release-pypi.yml
================================================
name: Release to PyPI
on:
push:
branches:
- main
paths:
- "version.py"
workflow_dispatch:
inputs:
test_pypi:
description: 'Publish to TestPyPI instead of PyPI (for testing)'
required: false
default: 'false'
type: choice
options:
- 'true'
- 'false'
permissions:
contents: read
jobs:
# ── sglang-kt (must be on PyPI before users can pip install kt-kernel) ──
build-and-publish-sglang-kt:
name: Build & publish sglang-kt
runs-on: [self-hosted, linux, x64]
if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'
environment: prod
permissions:
id-token: write
contents: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install build tools
run: |
python -m pip install --upgrade pip
pip install build wheel setuptools twine
- name: Build sglang-kt wheel
working-directory: third_party/sglang/python
run: |
KT_VERSION=$(python3 -c "exec(open('${{ github.workspace }}/version.py').read()); print(__version__)")
export SGLANG_KT_VERSION="$KT_VERSION"
echo "Building sglang-kt v${KT_VERSION} wheel..."
python -m build --wheel -v
ls dist/ | grep -q "sglang_kt" || (echo "ERROR: Wheel name does not contain sglang_kt" && exit 1)
- name: Publish sglang-kt to PyPI
if: github.event.inputs.test_pypi != 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload --skip-existing --verbose third_party/sglang/python/dist/*.whl
- name: Publish sglang-kt to TestPyPI (if requested)
if: github.event.inputs.test_pypi == 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
run: |
python -m twine upload --repository testpypi --skip-existing --verbose third_party/sglang/python/dist/*.whl
# ── kt-kernel ──
build-kt-kernel:
name: Build kt-kernel (Python ${{ matrix.python-version }})
runs-on: [self-hosted, linux, x64, gpu]
strategy:
fail-fast: false
matrix:
python-version: ['3.11', '3.12']
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Verify CUDA availability
run: |
nvidia-smi || (echo "ERROR: GPU not available" && exit 1)
nvcc --version || (echo "ERROR: CUDA toolkit not found" && exit 1)
- name: Install dependencies
run: |
apt-get update && apt-get install -y cmake libhwloc-dev pkg-config libnuma-dev
python -m pip install --upgrade pip
pip install build wheel setuptools torch --index-url https://download.pytorch.org/whl/cu118
- name: Build kt-kernel wheel
working-directory: kt-kernel
env:
CPUINFER_BUILD_ALL_VARIANTS: '1'
CPUINFER_USE_CUDA: '1'
CPUINFER_CUDA_ARCHS: '80;86;89;90'
CPUINFER_CUDA_STATIC_RUNTIME: '1'
CPUINFER_BUILD_TYPE: 'Release'
CPUINFER_PARALLEL: '4'
CPUINFER_FORCE_REBUILD: '1'
CUDA_HOME: '/usr/local/cuda-11.8'
run: |
echo "Building kt-kernel with:"
echo " - CUDA support (SM 80, 86, 89, 90)"
echo " - CPU multi-variant (AMX, AVX512, AVX2)"
python -m build --wheel -v
- name: Verify wheel
working-directory: kt-kernel
run: |
echo "Generated wheel:"
ls -lh dist/
# Install and test
pip install dist/*.whl
python -c "import kt_kernel; print(f'✓ Version: {kt_kernel.__version__}')"
python -c "import kt_kernel; print(f'✓ CPU variant: {kt_kernel.__cpu_variant__}')"
# Verify CUDA support
python -c "
from kt_kernel import kt_kernel_ext
cpu_infer = kt_kernel_ext.CPUInfer(4)
methods = dir(cpu_infer)
has_cuda = 'submit_with_cuda_stream' in methods
print(f'✓ CUDA support: {has_cuda}')
"
# Verify CPU multi-variant support
echo "Checking CPU variants in wheel..."
python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_" || echo "Warning: No variant .so files found"
python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_amx.cpython" && echo "✓ AMX variant found" || echo "Note: AMX variant missing"
python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_avx512" && echo "✓ AVX512 variants found" || echo "Note: AVX512 variants missing"
python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_avx2.cpython" && echo "✓ AVX2 variant found" || echo "Note: AVX2 variant missing"
# Verify static linking (should NOT depend on libcudart.so)
rm -rf /tmp/check
unzip -q dist/*.whl -d /tmp/check
if ldd /tmp/check/kt_kernel/*.so 2>/dev/null | grep -q "libcudart.so"; then
echo "ERROR: Dynamic cudart found, should be statically linked"
exit 1
else
echo "✓ CUDA runtime statically linked"
fi
- name: Repair wheel for manylinux
working-directory: kt-kernel
run: |
pip install auditwheel patchelf
mkdir -p wheelhouse
for wheel in dist/*.whl; do
auditwheel repair "$wheel" --plat manylinux_2_17_x86_64 --exclude libcuda.so.1 -w wheelhouse/ || \
cp "$wheel" wheelhouse/$(basename "$wheel" | sed 's/linux_x86_64/manylinux_2_17_x86_64/')
done
rm -f dist/*.whl && cp wheelhouse/*.whl dist/
- name: Upload artifact
uses: actions/upload-artifact@v4
with:
name: kt-kernel-wheels-py${{ matrix.python-version }}
path: kt-kernel/dist/*.whl
retention-days: 7
publish-pypi:
name: Publish kt-kernel to PyPI
needs: [build-and-publish-sglang-kt, build-kt-kernel]
runs-on: [self-hosted, linux, x64]
if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'
environment: prod
permissions:
id-token: write # For trusted publishing (OIDC)
contents: read
steps:
- name: Download all wheel artifacts
uses: actions/download-artifact@v4
with:
path: artifacts/
- name: Organize wheels into dist/
run: |
mkdir -p dist/
find artifacts/ -name "*.whl" -exec cp {} dist/ \;
echo "Wheels to publish:"
ls -lh dist/
- name: Get version from wheel
id: get_version
run: |
# Extract version from first wheel filename
wheel_name=$(ls dist/*.whl | head -1 | xargs basename)
# Extract version (format: kt_kernel-X.Y.Z-...)
version=$(echo "$wheel_name" | sed 's/kt_kernel-\([0-9.]*\)-.*/\1/')
echo "VERSION=$version" >> $GITHUB_OUTPUT
echo "Publishing version: $version"
- name: Install twine
run: |
python -m pip install --upgrade pip
pip install twine
- name: Publish to TestPyPI (if requested)
if: github.event.inputs.test_pypi == 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
run: |
python -m twine upload \
--repository testpypi \
--skip-existing \
--verbose \
dist/*.whl
- name: Publish to PyPI
if: github.event.inputs.test_pypi != 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload \
--skip-existing \
--verbose \
dist/*.whl
- name: Create release summary
run: |
echo "## 🎉 kt-kernel v${{ steps.get_version.outputs.VERSION }} Published to PyPI" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Installation" >> $GITHUB_STEP_SUMMARY
echo '```bash' >> $GITHUB_STEP_SUMMARY
echo "pip install kt-kernel==${{ steps.get_version.outputs.VERSION }}" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Published Wheels" >> $GITHUB_STEP_SUMMARY
echo "Total: $(ls -1 dist/*.whl | wc -l) wheels (Python 3.10, 3.11, 3.12)" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Features" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**CPU Multi-Variant Support:**" >> $GITHUB_STEP_SUMMARY
echo "- ✅ AMX (Intel Sapphire Rapids+, 2023)" >> $GITHUB_STEP_SUMMARY
echo "- ✅ AVX512 Base/VNNI/VBMI/BF16 (Intel Skylake-X/Ice Lake/Cascade Lake, 2017+)" >> $GITHUB_STEP_SUMMARY
echo "- ✅ AVX2 (Maximum compatibility, 2013+)" >> $GITHUB_STEP_SUMMARY
echo "- 🔧 Runtime CPU detection: Automatically selects optimal variant" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**CUDA Support:**" >> $GITHUB_STEP_SUMMARY
echo "- ✅ SM 80 (Ampere: A100, RTX 3000 series)" >> $GITHUB_STEP_SUMMARY
echo "- ✅ SM 86 (Ampere: RTX 3060-3090)" >> $GITHUB_STEP_SUMMARY
echo "- ✅ SM 89 (Ada Lovelace: RTX 4000 series)" >> $GITHUB_STEP_SUMMARY
echo "- ✅ SM 90 (Hopper: H100)" >> $GITHUB_STEP_SUMMARY
echo "- 🔧 Static CUDA runtime: Compatible with CUDA 11.8+ and 12.x drivers" >> $GITHUB_STEP_SUMMARY
echo "- 🔧 Works on CPU-only systems (CUDA features disabled gracefully)" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Requirements:**" >> $GITHUB_STEP_SUMMARY
echo "- Python 3.10, 3.11, or 3.12" >> $GITHUB_STEP_SUMMARY
echo "- Linux x86-64 (manylinux_2_17 compatible)" >> $GITHUB_STEP_SUMMARY
echo "- For CUDA features: NVIDIA driver with CUDA 11.8+ or 12.x support" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "PyPI link: https://pypi.org/project/kt-kernel/${{ steps.get_version.outputs.VERSION }}/" >> $GITHUB_STEP_SUMMARY
================================================
FILE: .github/workflows/release-sglang-kt.yml
================================================
name: Release sglang-kt to PyPI
on:
push:
branches:
- main
paths:
- "third_party/sglang"
- "version.py"
workflow_dispatch:
inputs:
test_pypi:
description: 'Publish to TestPyPI instead of PyPI (for testing)'
required: false
default: 'false'
type: choice
options:
- 'true'
- 'false'
permissions:
contents: read
jobs:
build-sglang-kt:
name: Build sglang-kt wheel
runs-on: [self-hosted, linux, x64]
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install build tools
run: |
python -m pip install --upgrade pip
pip install build wheel setuptools
- name: Build sglang-kt wheel
working-directory: third_party/sglang/python
run: |
# Read version from ktransformers version.py
KT_VERSION=$(python3 -c "exec(open('${{ github.workspace }}/version.py').read()); print(__version__)")
export SGLANG_KT_VERSION="$KT_VERSION"
echo "Building sglang-kt v${KT_VERSION} wheel..."
python -m build --wheel -v
- name: Verify wheel
working-directory: third_party/sglang/python
run: |
echo "Generated wheel:"
ls -lh dist/
# Verify the wheel has the correct package name
ls dist/ | grep -q "sglang_kt" || (echo "ERROR: Wheel name does not contain sglang_kt" && exit 1)
echo "Wheel name verified."
- name: Upload artifact
uses: actions/upload-artifact@v4
with:
name: sglang-kt-wheel
path: third_party/sglang/python/dist/*.whl
retention-days: 7
publish-pypi:
name: Publish sglang-kt to PyPI
needs: [build-sglang-kt]
runs-on: [self-hosted, linux, x64]
if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'
environment: prod
permissions:
id-token: write
contents: read
steps:
- name: Download wheel artifact
uses: actions/download-artifact@v4
with:
name: sglang-kt-wheel
path: dist/
- name: Display wheels
run: |
echo "Wheels to publish:"
ls -lh dist/
- name: Install twine
run: |
python -m pip install --upgrade pip
pip install twine
- name: Publish to TestPyPI (if requested)
if: github.event.inputs.test_pypi == 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
run: |
python -m twine upload \
--repository testpypi \
--skip-existing \
--verbose \
dist/*.whl
- name: Publish to PyPI
if: github.event.inputs.test_pypi != 'true'
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload \
--skip-existing \
--verbose \
dist/*.whl
- name: Create release summary
run: |
echo "## sglang-kt Published to PyPI" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Installation" >> $GITHUB_STEP_SUMMARY
echo '```bash' >> $GITHUB_STEP_SUMMARY
echo "pip install sglang-kt" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "This is the kvcache-ai fork of SGLang with kt-kernel support." >> $GITHUB_STEP_SUMMARY
echo "PyPI link: https://pypi.org/project/sglang-kt/" >> $GITHUB_STEP_SUMMARY
================================================
FILE: .github/workflows/sync-sglang-submodule.yml
================================================
name: Sync sglang submodule
on:
schedule:
# Run daily at 08:00 UTC
- cron: "0 8 * * *"
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
sync:
name: Check for sglang-kt updates
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: true
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Update sglang submodule to latest main
id: update
run: |
OLD_SHA=$(git -C third_party/sglang rev-parse HEAD)
git submodule update --remote third_party/sglang
NEW_SHA=$(git -C third_party/sglang rev-parse HEAD)
echo "old_sha=$OLD_SHA" >> "$GITHUB_OUTPUT"
echo "new_sha=$NEW_SHA" >> "$GITHUB_OUTPUT"
if [ "$OLD_SHA" = "$NEW_SHA" ]; then
echo "changed=false" >> "$GITHUB_OUTPUT"
echo "sglang submodule is already up to date ($OLD_SHA)"
else
echo "changed=true" >> "$GITHUB_OUTPUT"
# Collect commit log between old and new
COMMITS=$(git -C third_party/sglang log --oneline "$OLD_SHA..$NEW_SHA" | head -20)
echo "commits<<EOF" >> "$GITHUB_OUTPUT"
echo "$COMMITS" >> "$GITHUB_OUTPUT"
echo "EOF" >> "$GITHUB_OUTPUT"
# sglang-kt version = ktransformers version (from version.py)
VERSION=$(python3 -c "exec(open('version.py').read()); print(__version__)" 2>/dev/null || echo "unknown")
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
echo "sglang submodule updated: $OLD_SHA -> $NEW_SHA (v$VERSION)"
fi
- name: Create pull request
if: steps.update.outputs.changed == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: |
[build]: sync sglang submodule to ${{ steps.update.outputs.new_sha }}
branch: auto/sync-sglang
delete-branch: true
title: "[build] Sync sglang-kt submodule (v${{ steps.update.outputs.version }})"
body: |
Automated sync of `third_party/sglang` submodule to latest `main`.
**Old ref:** `${{ steps.update.outputs.old_sha }}`
**New ref:** `${{ steps.update.outputs.new_sha }}`
**sglang-kt version:** `${{ steps.update.outputs.version }}`
### Commits included
```
${{ steps.update.outputs.commits }}
```
---
*This PR was created automatically by the [sync-sglang-submodule](${{ github.server_url }}/${{ github.repository }}/actions/workflows/sync-sglang-submodule.yml) workflow.*
labels: |
dependencies
automated
================================================
FILE: .gitignore
================================================
__pycache__
build
.vscode
*.so
*.cache
server.db
logs
node_modules
*.nsys-rep
.vs/
*pycache*
*build/
.DS_Store
compile_commands.json
*.egg-info*
*dist/
ktransformers/server/local_store/
ktransformers/server_test1.db
*.patch
img/
tmp*.txt
test.txt
book
ktransformers/tests/chat_txt.txt
mmlu_result*
ktransformers/ktransformers_ext/cuda_musa/
test_prompt.txt
csrc/demo
build*
CMakeFiles/
kvc2/
sched/
*.png
================================================
FILE: .gitmodules
================================================
[submodule "third_party/llama.cpp"]
path = third_party/llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third_party/custom_flashinfer"]
path = third_party/custom_flashinfer
url = https://github.com/kvcache-ai/custom_flashinfer.git
branch = fix-precision-mla-merge-main
[submodule "third_party/sglang"]
path = third_party/sglang
url = https://github.com/kvcache-ai/sglang.git
branch = main
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MAINTAINERS.md
================================================
# Maintainers
This document lists the current maintainers and outlines their responsibilities.
## Current Maintainers
| Name | GitHub | Role | Affiliation | Email |
|------|--------|------|-------------|-------|
| Weiyu Xie | [@ErvinXie](https://github.com/ErvinXie) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | xwy21@mails.tsinghua.edu.cn |
| Hongtao Chen | [@chenht2022](https://github.com/chenht2022) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | cht22@mails.tsinghua.edu.cn |
| Jianwei Dong | [@ovowei](https://github.com/ovowei) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | dongjw24@mails.tsinghua.edu.cn |
| Ziwei Yuan | [@KMSorSMS](https://github.com/KMSorSMS) | Maintainer | [Approaching.AI](http://approaching.ai/) | 2022090910005@std.uestc.edu.cn |
| Qingliang Ou | [@ouqingliang](https://github.com/ouqingliang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | oql@bupt.edu.cn |
| Jiaqi Liao | [@SkqLiao](https://github.com/SkqLiao) | Maintainer | [Approaching.AI](http://approaching.ai/) | jiaqi.liao@bit.edu.cn |
| Peilin Li | [@JimmyPeilinLi](https://github.com/JimmyPeilinLi) | Maintainer | [Approaching.AI](http://approaching.ai/) | lipeilin@mail.nwpu.edu.cn |
| Xingxing Hao | [@mrhaoxx](https://github.com/mrhaoxx) | Maintainer | [Approaching.AI](http://approaching.ai/) | mr.haoxx@gmail.com |
| Boxin Zhang | [@Atream](https://github.com/Atream) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | zhangbx24@mails.tsinghua.edu.cn |
| Jingqi Tang | [@Azure-Tang](https://github.com/Azure-Tang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | tangjq25@mails.tsinghua.edu.cn |
| Jiahao Wang | [@qiyuxinlin](https://github.com/qiyuxinlin) | Maintainer | [Approaching.AI](http://approaching.ai/) | 202241050020@hdu.edu.cn |
## Responsibilities
Maintainers steward the project and keep it healthy for users and contributors.
- Review and approve pull requests; ensure changes meet quality, testing, and documentation standards.
- Triage issues, keep labels organized, and respond to questions in a timely manner.
- Uphold the project’s code of conduct and report violations when needed.
- Maintain CI reliability and address regressions promptly.
- Oversee releases and keep compatibility with supported dependency versions.
- Protect project security and follow the security disclosure process.
## Becoming a Maintainer
We welcome contributors who show sustained, high-quality contributions and collaborative behavior. If you are interested, please contact an existing maintainer and share your recent contributions and areas of focus.
================================================
FILE: README.md
================================================
## 🎯 Overview
KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel/) and [kt-sft](https://github.com/kvcache-ai/ktransformers/tree/main/kt-sft).
## 🔥 Updates
* **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md))
* **Feb 12, 2026**: GLM-5 Day0 Support! ([Tutorial](./doc/en/kt-kernel/GLM-5-Tutorial.md))
* **Jan 27, 2026**: Kimi-K2.5 Day0 Support! ([Tutorial](./doc/en/Kimi-K2.5.md)) ([SFT Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.5.md))
* **Jan 22, 2026**: Support [CPU-GPU Expert Scheduling](./doc/en/kt-kernel/experts-sched-Tutorial.md), [Native BF16 and FP8 per channel Precision](./doc/en/kt-kernel/Native-Precision-Tutorial.md) and [AutoDL unified fine-tuning and inference](./doc/zh/【云端低价训推】%20KTransformers%2BAutoDL%2BLlamaFactory:随用随租的低成本超大模型「微调%2B推理」一体化流程.pdf)
* **Dec 24, 2025**: Support Native MiniMax-M2.1 inference. ([Tutorial](./doc/en/kt-kernel/MiniMax-M2.1-Tutorial.md))
* **Dec 22, 2025**: Support RL-DPO fine-tuning with LLaMA-Factory. ([Tutorial](./doc/en/SFT/DPO_tutorial.md))
* **Dec 5, 2025**: Support Native Kimi-K2-Thinking inference ([Tutorial](./doc/en/kt-kernel/Kimi-K2-Thinking-Native.md))
* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md))
* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md))
* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))
* **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/))
* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))
* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))
* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))
* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))
* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.
* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).
* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support DeepSeek-R1 and V3 on a single GPU (24GB VRAM) or multiple GPUs with 382GB DRAM, up to 3~28x speedup. For a detailed showcase and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU.
* **Aug 14, 2024**: Support llamafile as linear backend.
* **Aug 12, 2024**: Support multiple GPUs; support new models: Mixtral 8\*7B and 8\*22B; support q2k, q3k, q5k dequantization on GPU.
* **Aug 9, 2024**: Support Windows native.
---
## 📦 Core Modules
### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels
CPU-optimized kernel operations for heterogeneous LLM inference.
**Key Features:**
- **AMX/AVX Acceleration**: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference
- **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management
- **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support
- **Easy Integration**: Clean Python API for SGLang and other frameworks
**Quick Start:**
```bash
cd kt-kernel
pip install .
```
**Use Cases:**
- CPU-GPU hybrid inference for large MoE models
- Integration with SGLang for production serving
- Heterogeneous expert placement (hot experts on GPU, cold experts on CPU)
**Performance Examples:**
| Model | Hardware Configuration | Total Throughput | Output Throughput |
|-------|------------------------|------------------|-------------------|
| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) |
👉 **[Full Documentation →](./kt-kernel/README.md)**
---
### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework
KTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning.

**Key Features:**
- **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM
- **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration
- **LLaMA-Factory Integration**: Seamless integration with the popular LLaMA-Factory fine-tuning framework
- **Production Ready**: Chat, batch inference, and metrics evaluation
**Performance Examples:**
| Model | Configuration | Throughput | GPU Memory |
|-------|--------------|------------|------------|
| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) |
| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |
**Quick Start:**
```bash
cd kt-sft
# Install environment following kt-sft/README.md
USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml
```
👉 **[Full Documentation →](./kt-sft/README.md)**
---
## 🔥 Citation
If you use KTransformers in your research, please cite our paper:
```bibtex
@inproceedings{10.1145/3731569.3764843,
title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
year = {2025}
}
```
## 👥 Contributors & Team
Developed and maintained by:
- [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University
- [Approaching.AI](http://approaching.ai/)
- [9#AISoft](https://github.com/aisoft9)
- Community contributors
We welcome contributions! Please feel free to submit issues and pull requests.
## 💬 Community & Support
- **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues)
- **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png)
## 📦 KT original Code
The original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability.
For the original documentation with full quick-start guides and examples, see:
- [archive/README.md](./archive/README.md) (English)
- [archive/README_ZH.md](./archive/README_ZH.md) (中文)
================================================
FILE: README_ZH.md
================================================
## 🎯 概览
KTransformers 是一个专注于通过 CPU-GPU 异构计算实现大语言模型高效推理和微调的研究项目。该项目已发展为**两个核心模块**:[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。
## 🔥 更新
* **2025 年 12 月 5 日**:支持原生 Kimi-K2-Thinking 推理([教程](./doc/en/Kimi-K2-Thinking-Native.md))
* **2025 年 11 月 6 日**:支持 Kimi-K2-Thinking 推理([教程](./doc/en/Kimi-K2-Thinking.md))和微调([教程](./doc/en/SFT_Installation_Guide_KimiK2.md))
* **2025 年 11 月 4 日**:KTransformers 微调 × LLaMA-Factory 集成([教程](./doc/en/KTransformers-Fine-Tuning_User-Guide.md))
* **2025 年 10 月 27 日**:支持昇腾 NPU([教程](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))
* **2025 年 10 月 10 日**:集成到 SGLang([路线图](https://github.com/sgl-project/sglang/issues/11425),[博客](https://lmsys.org/blog/2025-10-22-KTransformers/))
* **2025 年 9 月 11 日**:支持 Qwen3-Next([教程](./doc/en/Qwen3-Next.md))
* **2025 年 9 月 5 日**:支持 Kimi-K2-0905([教程](./doc/en/Kimi-K2.md))
* **2025 年 7 月 26 日**:支持 SmallThinker 和 GLM4-MoE([教程](./doc/en/SmallThinker_and_Glm4moe.md))
* **2025 年 7 月 11 日**:支持 Kimi-K2([教程](./doc/en/Kimi-K2.md))
* **2025 年 6 月 30 日**:支持 3 层(GPU-CPU-磁盘)[前缀缓存](./doc/en/prefix_cache.md)复用
* **2025 年 5 月 14 日**:支持 Intel Arc GPU([教程](./doc/en/xpu.md))
* **2025 年 4 月 29 日**:支持 AMX-Int8、AMX-BF16 和 Qwen3MoE([教程](./doc/en/AMX.md))
* **2025 年 4 月 9 日**:实验性支持 LLaMA 4 模型([教程](./doc/en/llama4.md))
* **2025 年 4 月 2 日**:支持多并发([教程](./doc/en/balance-serve.md))
* **2025 年 3 月 15 日**:支持 AMD GPU 上的 ROCm([教程](./doc/en/ROCm.md))
* **2025 年 3 月 5 日**:支持 unsloth 1.58/2.51 位权重和 [IQ1_S/FP8 混合](./doc/en/fp8_kernel.md)权重。在 24GB VRAM 中支持 DeepSeek-V3 和 R1 的 139K [更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel)
* **2025 年 2 月 25 日**:为 DeepSeek-V3 和 R1 支持 [FP8 GPU 内核](./doc/en/fp8_kernel.md);[更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)
* **2025 年 2 月 15 日**:更长上下文(24GB VRAM 从 4K 到 8K)& 速度稍快(+15%,最高 16 Tokens/s),更新[文档](./doc/en/DeepseekR1_V3_tutorial.md)和[在线手册](https://kvcache-ai.github.io/ktransformers/)
* **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单 GPU(24GB VRAM)/多 GPU 和 382GB DRAM 上运行,速度提升高达 3~28 倍。详细案例展示和复现教程请参见[这里](./doc/en/DeepseekR1_V3_tutorial.md)
* **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21GB 降低到 11GB
* **2024 年 8 月 15 日**:更新了关于注入和多 GPU 的详细[教程](doc/en/injection_tutorial.md)
* **2024 年 8 月 14 日**:支持 llamafile 作为线性后端
* **2024 年 8 月 12 日**:支持多 GPU;支持新模型:mixtral 8\*7B 和 8\*22B;支持 GPU 上的 q2k、q3k、q5k 去量化
* **2024 年 8 月 9 日**:支持 Windows 原生环境
---
## 📦 核心模块
### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核
用于异构 LLM 推理的 CPU 优化内核操作。

**主要特性:**
- **AMX/AVX 加速**:Intel AMX 和 AVX512/AVX2 优化的内核,用于 INT4/INT8 量化推理
- **MoE 优化**:高效的专家混合推理,具有 NUMA 感知内存管理
- **量化支持**:CPU 端 INT4/INT8 量化权重,GPU 端 GPTQ 支持
- **易于集成**:为 SGLang 和其他框架提供简洁的 Python API
**快速开始:**
```bash
cd kt-kernel
pip install .
```
**使用场景:**
- 大型 MoE 模型的 CPU-GPU 混合推理
- 与 SGLang 集成用于生产服务
- 异构专家放置(热专家在 GPU 上,冷专家在 CPU 上)
**性能示例:**
| 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 |
|-------|------------------------|------------------|-------------------|
| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s(8 路并发)|
👉 **[完整文档 →](./kt-kernel/README.md)**
---
### 🎓 [kt-sft](./kt-sft/) - 微调框架
KTransformers × LLaMA-Factory 集成,用于超大型 MoE 模型微调。

**主要特性:**
- **资源高效**:仅需 **70GB GPU 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3
- **LoRA 支持**:完整的 LoRA 微调,带有异构加速
- **LLaMA-Factory 集成**:与流行的微调框架无缝集成
- **生产就绪**:聊天、批量推理和指标评估
**性能示例:**
| 模型 | 配置 | 吞吐量 | GPU 显存 |
|-------|--------------|------------|--------------|
| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB(多 GPU)|
| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |
**快速开始:**
```bash
cd kt-sft
# 按照 kt-sft/README.md 安装环境
USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml
```
👉 **[完整文档 →](./kt-sft/README.md)**
---
## 🔥 引用
如果您在研究中使用了 KTransformers,请引用我们的论文:
```bibtex
@inproceedings{10.1145/3731569.3764843,
title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
year = {2025}
}
```
## 👥 贡献者与团队
由以下团队开发和维护:
- 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/)
- [Approaching.AI](http://approaching.ai/)
- 社区贡献者
我们欢迎贡献!请随时提交问题和拉取请求。
## 💬 社区与支持
- **GitHub Issues**:[报告问题或请求功能](https://github.com/kvcache-ai/ktransformers/issues)
- **微信群**:请参见 [archive/WeChatGroup.png](./archive/WeChatGroup.png)
## 📦 KT原仓库
原始的集成 KTransformers 框架已归档到 [`archive/`](./archive/) 目录以供参考。该项目现在专注于上述两个核心模块,以获得更好的模块化和可维护性。
有关原始文档以及完整的快速入门指南和示例,请参见:
- [archive/README.md](./archive/README.md)(英文)
- [archive/README_ZH.md](./archive/README_ZH.md)(中文)
================================================
FILE: archive/.devcontainer/Dockerfile
================================================
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server
WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda
RUN <> ~/.bashrc && \
echo "conda activate ktransformers" >> ~/.bashrc
WORKDIR /ktransformers/
CMD ["bash"]
================================================
FILE: archive/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: archive/MANIFEST.in
================================================
graft third_party
graft ktransformers
graft local_chat.py
graft csrc
include LICENSE README.md
prune ktransformers/website
prune ktransformers/logs
prune ktransformers.egg-info
prune third_party/llama.cpp/models
graft ktransformers/website/dist
global-exclude __pycache__
include KTransformersOps.*.so
include cpuinfer_ext.*.so
================================================
FILE: archive/Makefile
================================================
flake_find:
cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' -
format:
@cd ktransformers && black .
@black setup.py
dev_install:
# clear build dirs
rm -rf build
rm -rf *.egg-info
rm -rf ktransformers/ktransformers_ext/build
rm -rf ktransformers/ktransformers_ext/cuda/build
rm -rf ktransformers/ktransformers_ext/cuda/dist
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
# install ktransformers
echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation
echo "Installation completed successfully"
clean:
rm -rf build
rm -rf *.egg-info
rm -rf ktransformers/ktransformers_ext/build
rm -rf ktransformers/ktransformers_ext/cuda/build
rm -rf ktransformers/ktransformers_ext/cuda/dist
rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info
install_numa:
USE_NUMA=1 make dev_install
install_no_numa:
env -u USE_NUMA make dev_install
================================================
FILE: archive/README.md
================================================
High-Performance CPU-GPU Hybrid Inference for Large Language Models
## 🎯 Overview
KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](./kt-kernel/) and [kt-sft](./kt-sft/).
## 🔥 Updates
* **Nov 6, 2025**: Support Kimi-K2-Thinking inference and fine-tune
* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration
* **Oct 27, 2025**: Support Ascend NPU
* **Oct 10, 2025**: Integrating into SGLang ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/))
* **Sept 11, 2025**: Support Qwen3-Next
* **Sept 05, 2025**: Support Kimi-K2-0905
* **July 26, 2025**: Support SmallThinker and GLM4-MoE
* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) prefix cache reuse
* **May 14, 2025**: Support Intel Arc GPU
* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE
* **Apr 9, 2025**: Experimental support for LLaMA 4 models
* **Apr 2, 2025**: Support Multi-concurrency
* **Mar 15, 2025**: Support ROCm on AMD GPU
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and IQ1_S/FP8 hybrid weights; 139K longer context for DeepSeek-V3/R1
* **Feb 25, 2025**: Support FP8 GPU kernel for DeepSeek-V3 and R1
* **Feb 10, 2025**: Support Deepseek-R1 and V3, up to 3~28x speedup
---
## 📦 Core Modules
### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels
CPU-optimized kernel operations for heterogeneous LLM inference.

**Key Features:**
- **AMX/AVX Acceleration**: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference
- **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management
- **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support
- **Easy Integration**: Clean Python API for SGLang and other frameworks
**Quick Start:**
```bash
cd kt-kernel
pip install .
```
**Use Cases:**
- CPU-GPU hybrid inference for large MoE models
- Integration with SGLang for production serving
- Heterogeneous expert placement (hot experts on GPU, cold experts on CPU)
**Performance Examples:**
| Model | Hardware Configuration | Total Throughput | Output Throughput |
|-------|------------------------|------------------|-------------------|
| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) |
👉 **[Full Documentation →](./kt-kernel/README.md)**
---
### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework
KTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning.

**Key Features:**
- **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM
- **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration
- **LLaMA-Factory Integration**: Seamless integration with popular fine-tuning framework
- **Production Ready**: Chat, batch inference, and metrics evaluation
**Performance Examples:**
| Model | Configuration | Throughput | GPU Memory |
|-------|--------------|------------|------------|
| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) |
| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |
**Quick Start:**
```bash
cd kt-sft
# Install environment following kt-sft/README.md
USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml
```
👉 **[Full Documentation →](./kt-sft/README.md)**
---
## 🔥 Citation
If you use KTransformers in your research, please cite our paper:
```bibtex
@inproceedings{10.1145/3731569.3764843,
title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
year = {2025}
}
```
## 👥 Contributors & Team
Developed and maintained by:
- [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University
- [Approaching.AI](http://approaching.ai/)
- Community contributors
We welcome contributions! Please feel free to submit issues and pull requests.
## 💬 Community & Support
- **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues)
- **GitHub Discussions**: [Ask questions and share ideas](https://github.com/kvcache-ai/ktransformers/discussions)
- **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png)
## 📦 Legacy Code
The original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability.
For the original documentation with full quick-start guides and examples, see:
- [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English)
- [archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文)
================================================
FILE: archive/README_LEGACY.md
================================================
🎉 Introduction
KTransformers, pronounced as Quick Transformers, is designed to enhance your 🤗 Transformers experience with advanced kernel optimizations and placement/parallelism strategies.
KTransformers is a flexible, Python-centric framework designed with extensibility at its core.
By implementing and injecting an optimized module with a single line of code, users gain access to a Transformers-compatible
interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified ChatGPT-like web UI.
Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.
🔥 Updates
* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md))
* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md))
* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))
* **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425))
* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))
* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))
* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))
* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))
* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.
* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).
* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))
https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2
* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).
* **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)).
https://github.com/user-attachments/assets/faa3bda2-928b-45a7-b44f-21e12ec84b8a
* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).
* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM.
* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).
* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).
* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For a detailed showcase and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).
* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.
* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU.
* **Aug 14, 2024**: Support llamafile as linear backend.
* **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu.
* **Aug 9, 2024**: Support Windows native.
🌟 Show Cases
GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM
https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)).
- Prefill Speed (tokens/s):
- KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)
- Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.
- Decode Speed (tokens/s):
- KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
- Upcoming Open Source Release:
- AMX optimizations and selective expert activation will be open-sourced in V0.3.
- Currently available only in preview binary distribution, which can be downloaded [here](./doc/en/DeepseekR1_V3_tutorial.md).
- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).
- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).
- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.
https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
More advanced features will be coming soon, so stay tuned!
🚀 Quick Start
Getting started with KTransformers is simple! Follow the steps below to set up and start using it.
We already support the following vendors:
- Metax
- Sanechips (ZhuFeng V1.0)
- Intel
- Ascend
- Kunpeng
- AMD
### 📥 Installation
To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
📃 Brief Injection Tutorial
At the heart of KTransformers is a user-friendly, template-based injection framework.
This allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects.
Given that vLLM already serves as a great framework for large-scale deployment optimizations, KTransformers is particularly focused on local deployments that are constrained by limited resources. We pay special attention to heterogeneous computing opportunities, such as GPU/CPU offloading of quantized models. For example, we support the efficient Llamafile and Marlin kernels for CPU and GPU, respectively. More details can be found here.
Example Usage
To utilize the provided kernels, users only need to create a YAML-based injection template and add the call to `optimize_and_load_gguf` before using the Transformers model.
```python
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
...
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)
```
In this example, the AutoModel is first initialized on the meta device to avoid occupying any memory resources. Then, `optimize_and_load_gguf` iterates through all sub-modules of the model, matches rules specified in your YAML rule file, and replaces them with advanced modules as specified.
After injection, the original `generate` interface is available, but we also provide a compatible `prefill_and_generate` method, which enables further optimizations like CUDAGraph to improve generation speed.
How to customize your model
A detailed tutorial of the injection and multi-GPU using DeepSeek-V2 as an example is given [here](doc/en/injection_tutorial.md).
Below is an example of a YAML template for replacing all original Linear modules with Marlin, an advanced 4-bit quantization kernel.
```yaml
- match:
name: "^model\\.layers\\..*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
device: "cpu" # which devices to load this module when initializing
kwargs:
generate_device: "cuda"
generate_linear_type: "QuantizedLinearMarlin"
```
Each rule in the YAML file has two parts: `match` and `replace`. The `match` part specifies which module should be replaced, and the `replace` part specifies the module to be injected into the model along with the initialization keywords.
You can find example rule templates for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models, in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory. These templates are used to power the `local_chat.py` demo.
If you are interested in our design principles and the implementation of the injection framework, please refer to the [design document](doc/en/deepseek-v2-injection.md).
🔥 Citation
If you use KTransformers for your research, please cite our [paper](https://madsys.cs.tsinghua.edu.cn/publication/ktransformers-unleashing-the-full-potential-of-cpu/gpu-hybrid-inference-for-moe-models/):
```
@inproceedings{10.1145/3731569.3764843,
title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
year = {2025}
}
```
Acknowledgment and Contributors
The development of KTransformers is based on the flexible and versatile framework provided by Transformers. We also benefit from advanced kernels such as GGUF/GGML, Llamafile, Marlin, sglang and flashinfer. We are planning to contribute back to the community by upstreaming our modifications.
KTransformers is actively maintained and developed by contributors from the MADSys group at Tsinghua University and members from Approaching.AI. We welcome new contributors to join us in making KTransformers faster and easier to use.
Discussion
If you have any questions, feel free to open an issue. Alternatively, you can join our WeChat group for further discussion. QR Code: [WeChat Group](WeChatGroup.png)
🙋 FAQ
Some common questions are answered in the [FAQ](doc/en/FAQ.md).
================================================
FILE: archive/README_ZH.md
================================================
高性能 CPU-GPU 异构大语言模型推理
## 🎯 项目概述
KTransformers 是一个专注于大语言模型高效推理和微调的研究项目,通过 CPU-GPU 异构计算实现资源受限环境下的模型部署。项目已演进为**两个核心模块**:[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。
## 🔥 更新
* **2025年11月6日**:支持 Kimi-K2-Thinking 推理和微调
* **2025年11月4日**:KTransformers 微调 × LLaMA-Factory 集成
* **2025年10月27日**:支持 Ascend NPU
* **2025年10月10日**:集成到 SGLang ([路线图](https://github.com/sgl-project/sglang/issues/11425), [博客](https://lmsys.org/blog/2025-10-22-KTransformers/))
* **2025年9月11日**:支持 Qwen3-Next
* **2025年9月5日**:支持 Kimi-K2-0905
* **2025年7月26日**:支持 SmallThinker 和 GLM4-MoE
* **2025年6月30日**:支持 3层(GPU-CPU-磁盘)前缀缓存复用
* **2025年5月14日**:支持 Intel Arc GPU
* **2025年4月29日**:支持 AMX-Int8、AMX-BF16 和 Qwen3MoE
* **2025年4月9日**:实验性支持 LLaMA 4 模型
* **2025年4月2日**:支持多并发
* **2025年3月15日**:支持 AMD GPU 的 ROCm
* **2025年3月5日**:支持 unsloth 1.58/2.51 bits 权重和 IQ1_S/FP8 混合权重;DeepSeek-V3/R1 支持 139K 长上下文
* **2025年2月25日**:支持 DeepSeek-V3 和 R1 的 FP8 GPU 内核
* **2025年2月10日**:支持 Deepseek-R1 和 V3,速度提升最高达 3~28 倍
---
## 📦 核心模块
### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核
面向异构 LLM 推理的 CPU 优化内核操作库。

**核心特性:**
- **AMX/AVX 加速**:Intel AMX 和 AVX512/AVX2 优化内核,支持 INT4/INT8 量化推理
- **MoE 优化**:高效的专家混合推理,支持 NUMA 感知内存管理
- **量化支持**:CPU 端 INT4/INT8 量化权重,GPU 端 GPTQ 支持
- **易于集成**:简洁的 Python API,可集成到 SGLang 等框架
**快速开始:**
```bash
cd kt-kernel
pip install .
```
**应用场景:**
- 大型 MoE 模型的 CPU-GPU 混合推理
- 与 SGLang 集成用于生产服务
- 异构专家放置(热门专家在 GPU,冷门专家在 CPU)
**性能示例:**
| 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 |
|------|---------|---------|-----------|
| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s(8路并发)|
👉 **[完整文档 →](./kt-kernel/README.md)**
---
### 🎓 [kt-sft](./kt-sft/) - 微调框架
KTransformers × LLaMA-Factory 集成,支持超大 MoE 模型微调。

**核心特性:**
- **资源高效**:仅需 **70GB 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3
- **LoRA 支持**:完整的 LoRA 微调与异构加速
- **LLaMA-Factory 集成**:与流行微调框架无缝集成
- **生产就绪**:支持对话、批量推理和指标评估
**性能示例:**
| 模型 | 配置 | 吞吐量 | GPU 显存 |
|------|------|--------|----------|
| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (多卡) |
| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |
**快速开始:**
```bash
cd kt-sft
# 按照 kt-sft/README.md 安装环境
USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml
```
👉 **[完整文档 →](./kt-sft/README.md)**
---
## 🔥 引用
如果您在研究中使用了 KTransformers,请引用我们的论文:
```bibtex
@inproceedings{10.1145/3731569.3764843,
title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},
author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},
booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},
year = {2025}
}
```
## 👥 贡献者与团队
由以下团队开发和维护:
- 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/)
- [Approaching.AI](http://approaching.ai/)
- 社区贡献者
我们欢迎贡献!请随时提交 issues 和 pull requests。
## 💬 社区与支持
- **GitHub Issues**:[报告 bug 或请求功能](https://github.com/kvcache-ai/ktransformers/issues)
- **GitHub Discussions**:[提问和分享想法](https://github.com/kvcache-ai/ktransformers/discussions)
- **微信群**:查看 [archive/WeChatGroup.png](./archive/WeChatGroup.png)
## 📦 历史代码
原完整的 KTransformers 框架代码已归档至 [`archive/`](./archive/) 目录供参考。项目现专注于上述两个核心模块,以实现更好的模块化和可维护性。
关于原始完整文档(包含快速入门指南和示例),请查看:
- [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English)
- [archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文)
================================================
FILE: archive/README_ZH_LEGACY.md
================================================
🎉 介绍
KTransformers(发音为 Quick Transformers)旨在通过先进的内核优化和放置/并行策略来增强您对 🤗 [Transformers](https://github.com/huggingface/transformers) 的体验。
KTransformers 是一个以 Python 为中心的灵活框架,其核心是可扩展性。通过用一行代码实现并注入优化模块,用户可以获得与 Transformers 兼容的接口、符合 OpenAI 和 Ollama 的 RESTful API,甚至是一个简化的类似 ChatGPT 的 Web 界面。
我们对 KTransformers 的愿景是成为一个用于实验创新 LLM 推理优化的灵活平台。如果您需要任何其他功能,请告诉我们。
🔥 更新
* **2025 年 2 月 25 日**:为 DeepSeek-V3/R1 支持 [FP8 GPU 内核](./doc/en/fp8_kernel.md);支持更长的上下文([教程](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context))。
* **2025 年 2 月 15 日**:长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。
* **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单个(24GB VRAM)/多 GPU 和 382G DRAM 上运行,速度提升高达 3~28 倍。详细教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)。
* **2024 年 8 月 28 日**:支持 InternLM2.5-7B-Chat-1M 模型下的 1M 上下文,使用 24GB 的 VRAM 和 150GB 的 DRAM。详细教程请参见 [这里](./doc/en/long_context_tutorial.md)。
* **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21G 降低到 11G。
* **2024 年 8 月 15 日**:更新了详细的 [教程](doc/en/injection_tutorial.md),介绍注入和多 GPU 的使用。
* **2024 年 8 月 14 日**:支持 llamafile 作为线性后端。
* **2024 年 8 月 12 日**:支持多 GPU;支持新模型:mixtral 8\*7B 和 8\*22B;支持 q2k、q3k、q5k 在 GPU 上的去量化。
* **2024 年 8 月 9 日**:支持 Windows。
🌟 案例展示
在仅 24GB VRAM 的桌面上运行 GPT-4/o1 级别的本地 VSCode Copilot
https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
- **[NEW!!!] 本地 671B DeepSeek-Coder-V3/R1**:使用其 Q4_K_M 版本,仅需 14GB VRAM 和 382GB DRAM 即可运行(教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md))。
- 预填充速度(tokens/s):
- KTransformers:54.21(32 核)→ 74.362(双插槽,2×32 核)→ 255.26(优化的 AMX 基 MoE 内核,仅 V0.3)→ 286.55(选择性使用 6 个专家,仅 V0.3)
- 与 llama.cpp 在 2×32 核下相比,达到 **27.79× 速度提升**。
- 解码速度(tokens/s):
- KTransformers:8.73(32 核)→ 11.26(双插槽,2×32 核)→ 13.69(选择性使用 6 个专家,仅 V0.3)
- 与 llama.cpp 在 2×32 核下相比,达到 **3.03× 速度提升**。
- 即将开源发布:
- AMX 优化和选择性专家激活将在 V0.3 中开源。
- 目前仅在预览二进制分发中可用,可从 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 下载。
- **本地 236B DeepSeek-Coder-V2**:使用其 Q4_K_M 版本,仅需 21GB VRAM 和 136GB DRAM 即可运行,甚至在 [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench) 中得分超过 GPT4-0613。
- **更快的速度**:通过 MoE 卸载和注入来自 [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) 和 [Marlin](https://github.com/IST-DASLab/marlin) 的高级内核,实现了 2K 提示预填充 126 tokens/s 和生成 13.6 tokens/s 的速度。
- **VSCode 集成**:封装成符合 OpenAI 和 Ollama 的 API,可无缝集成到 [Tabby](https://github.com/TabbyML/tabby) 和其他前端的后端。
https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c
更多高级功能即将推出,敬请期待!
🚀 快速入门
KTransformers 的入门非常简单!请参考我们的[安装指南](https://kvcache-ai.github.io/ktransformers/)进行安装。
📃 简要注入教程
KTransformers 的核心是一个用户友好的、基于模板的注入框架。这使得研究人员可以轻松地将原始 torch 模块替换为优化的变体。它还简化了多种优化的组合过程,允许探索它们的协同效应。
鉴于 vLLM 已经是一个用于大规模部署优化的优秀框架,KTransformers 特别关注受资源限制的本地部署。我们特别关注异构计算的机会,例如量化模型的 GPU/CPU 卸载。例如,我们支持高效的 Llamafile 和 Marlin 内核,分别用于 CPU 和 GPU。更多详细信息可以在这里找到。
示例用法
要使用提供的内核,用户只需创建一个基于 YAML 的注入模板,并在使用 Transformers 模型之前添加对 `optimize_and_load_gguf` 的调用。
```python
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
...
generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)
```
在这个示例中,首先在 meta 设备上初始化 AutoModel,以避免占用任何内存资源。然后,`optimize_and_load_gguf` 遍历模型的所有子模块,匹配您的 YAML 规则文件中指定的规则,并将它们替换为指定的高级模块。
注入后,原始的 `generate` 接口仍然可用,但我们还提供了一个兼容的 `prefill_and_generate` 方法,这使得可以进一步优化,例如使用 CUDAGraph 提高生成速度。
如何自定义您的模型
一个详细的使用 DeepSeek-V2 作为示例的注入和 multi-GPU 教程在 [这里](doc/en/injection_tutorial.md)。
以下是一个将所有原始 Linear 模块替换为 Marlin 的 YAML 模板示例,Marlin 是一个高级的 4 位量化内核。
```yaml
- match:
name: "^model\\.layers\\..*$" # 正则表达式
class: torch.nn.Linear # 仅匹配同时符合名称和类的模块
replace:
class: ktransformers.operators.linear.KTransformerLinear # 量化数据类型的优化内核
device: "cpu" # 初始化时加载该模块的 device
kwargs:
generate_device: "cuda"
generate_linear_type: "QuantizedLinearMarlin"
```
YAML 文件中的每个规则都有两部分:`match` 和 `replace`。`match` 部分指定应替换的模块,`replace` 部分指定要注入到模型中的模块以及初始化关键字。
您可以在 [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) 目录中找到用于优化 DeepSeek-V2 和 Qwen2-57B-A14 的示例规则模板。这些模板用于为 `local_chat.py` 示例提供支持。
如果您对我们的设计原则和注入框架的实现感兴趣,请参考 [设计文档](doc/en/deepseek-v2-injection.md)。
致谢和贡献者
KTransformers 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile、Marlin、sglang 和 flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。
KTransformers 由清华大学 MADSys 小组的成员以及 Approaching.AI 的成员积极维护和开发。我们欢迎新的贡献者加入我们,使 KTransformers 更快、更易于使用。
讨论
如果您有任何问题,欢迎随时提出 issue。或者,您可以加入我们的微信群进行进一步讨论。二维码: [微信群](WeChatGroup.png)
🙋 常见问题
一些常见问题的答案可以在 [FAQ](doc/en/FAQ.md) 中找到。
================================================
FILE: archive/SECURITY.md
================================================
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ------- | ------------------ |
| 5.1.x | :white_check_mark: |
| 5.0.x | :x: |
| 4.0.x | :white_check_mark: |
| < 4.0 | :x: |
## Reporting a Vulnerability
Use this section to tell people how to report a vulnerability.
Tell them where to go, how often they can expect to get an update on a
reported vulnerability, what to expect if the vulnerability is accepted or
declined, etc.
================================================
FILE: archive/book.toml
================================================
[book]
authors = ["kvcache-ai"]
language = "zh-CN"
title = "Ktransformers"
src = "doc"
[output.html]
git-repository-url = "https://github.com/kvcache-ai/ktransformers"
edit-url-template = "https://github.com/kvcache-ai/ktransformers/edit/main/{path}"
[output.html.playground]
editable = true
copy-js = true
# line-numbers = true
[output.html.fold]
enable = true
level = 0
================================================
FILE: archive/config.json
================================================
================================================
FILE: archive/csrc/balance_serve/CMakeLists.txt
================================================
# Build script for the balance_serve project. It configures the compiler and
# third-party dependencies, locates the installed PyTorch, and then adds the
# kvc2 and sched subdirectories.

# Must be the first command so that policy settings apply to everything below.
cmake_minimum_required(VERSION 3.21)

option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)

# Locate the installed PyTorch package first: both the optional NPU setup and
# the Torch lookup further down depend on TORCH_INSTALL_PREFIX.
execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0])"
    OUTPUT_VARIABLE TORCH_INSTALL_PREFIX
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(NOT TORCH_INSTALL_PREFIX)
    message(FATAL_ERROR "Could not locate PyTorch: `python3 -c \"import torch\"` produced no path")
endif()
message(STATUS "Found PyTorch at: ${TORCH_INSTALL_PREFIX}")
# set(TORCH_INSTALL_PREFIX "/home/xwy/.conda/envs/kvc/lib/python3.12/site-packages/torch")

if(KTRANSFORMERS_USE_NPU)
    add_definitions(-DKTRANSFORMERS_USE_NPU=1)

    set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}")
    message(STATUS "ASCEND_HOME_PATH is ${ASCEND_HOME_PATH}")
    include_directories(${ASCEND_HOME_PATH}/include)
    # TORCH_INSTALL_PREFIX is already resolved above, so this no longer expands
    # to "/../torch.libs".
    link_directories(${TORCH_INSTALL_PREFIX}/../torch.libs)

    # find torch_npu (same interpreter as the PyTorch lookup above)
    execute_process(
        COMMAND python3 -c "import torch; import torch_npu; print(torch_npu.__path__[0])"
        OUTPUT_VARIABLE TORCH_NPU_PATH
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    message(STATUS "Found PTA at: ${TORCH_NPU_PATH}")
    find_library(PTA_LIBRARY torch_npu PATHS "${TORCH_NPU_PATH}/lib")
endif()

# The compiler has to be chosen before project() enables the CXX language.
find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 g++ REQUIRED)
set(CMAKE_CXX_COMPILER ${GCC_COMPILER})

# Show the selected compiler
message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}")

project(balance_serve VERSION 0.1.0)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC")
set(CMAKE_BUILD_TYPE "Debug")
# set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC")
# set(CMAKE_BUILD_TYPE "Release")

# Unless the caller supplied it, take the libstdc++ ABI setting from the
# installed PyTorch.
if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)
    find_package(Python3 REQUIRED COMPONENTS Interpreter)

    execute_process(
        COMMAND ${Python3_EXECUTABLE} -c
                "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')"
        OUTPUT_VARIABLE ABI_FLAG
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE)
endif()

# Pass the value to the compiler whether it was auto-detected or user-provided
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
message(STATUS "_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}")

# `format` target: run clang-format in place over every source/header file.
file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")

add_custom_target(
    format
    COMMAND clang-format
    -i
    -style=file
    ${FMT_SOURCES}
    COMMENT "Running clang-format on all source files"
)

set(BUILD_SHARED_LIBS ON)
set(ENABLE_PUSH OFF)
set(ENABLE_COMPRESSION OFF)

# set(CMAKE_BUILD_TYPE "Release")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
set(THIRD_PARTY_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party)
add_subdirectory(${THIRD_PARTY_DIR}/prometheus-cpp ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp EXCLUDE_FROM_ALL)
add_subdirectory(${THIRD_PARTY_DIR}/xxHash/cmake_unofficial ${THIRD_PARTY_BUILD_DIR}/xxHash EXCLUDE_FROM_ALL)

set_target_properties(xxhash PROPERTIES POSITION_INDEPENDENT_CODE ON)

# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/prometheus-cpp)
set(SPDLOG_DIR ${THIRD_PARTY_DIR}/spdlog)
set(FMT_DIR ${THIRD_PARTY_DIR}/fmt)

set(KVC2_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/kvc2/src)

include_directories(${THIRD_PARTY_DIR})

add_subdirectory(${THIRD_PARTY_DIR}/pybind11 ${THIRD_PARTY_BUILD_DIR}/pybind11)

find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
find_package(Torch REQUIRED PATHS "${TORCH_INSTALL_PREFIX}/share/cmake/Torch" NO_DEFAULT_PATH)

add_subdirectory(kvc2)
add_subdirectory(sched)

# add_subdirectory(test)
================================================
FILE: archive/csrc/custom_marlin/__init__.py
================================================
================================================
FILE: archive/csrc/custom_marlin/binding.cpp
================================================
/**
* @Description :
* @Author : Azure-Tang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:05:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "gptq_marlin/ops.h"
// Python bindings
#include
#include
#include
#include
#include
// namespace py = pybind11;
// Defines the `vLLMMarlin` Python extension module and registers the two
// Marlin entry points (presumably declared in gptq_marlin/ops.h).
// NOTE(review): `py::arg` is used below although the local
// `namespace py = pybind11;` alias above is commented out -- presumably one of
// the included headers provides `py`; confirm.
PYBIND11_MODULE(vLLMMarlin, m) {
// Disabled bindings for the dequantize_* helpers, kept for reference only.
/*m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k
data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize
iq4_xs data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));*/
// Quantized GEMM. The thirteen py::arg names below become the Python keyword
// names, so their order must match the C++ parameter order of
// gptq_marlin_gemm.
m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"Function to perform GEMM using Marlin quantization.", py::arg("a"),
py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m_tensor"),
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"),
py::arg("sms"), py::arg("is_k_full"));
// Weight repacking entry point; registered without py::arg names.
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
}
================================================
FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
================================================
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
/*
* Adapted from
* https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include
// Compile-time guard used by the templates in this file: expands to a
// static_assert that rejects any `scalar_t` other than the two supported
// element types (float16 / bfloat16, per the assertion message).
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same::value || \
std::is_same::value, \
"only float16 and bfloat16 is supported");
// Shorthand for std::to_string(x).
template inline std::string str(T x) { return std::to_string(x); }
namespace gptq_marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Empty stub of permute_cols_kernel. It is only compiled when __CUDA_ARCH__ is
// defined and below 800 (see the enclosing #if); the real implementation is in
// the #else branch.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}
// Empty stub of the Marlin kernel for the __CUDA_ARCH__ < 800 build path (see
// the enclosing #if). It keeps the template and parameter list so that code
// naming the kernel still compiles; the body does nothing.
// NOTE(review): this stub is declared __global__, while the implementation in
// the #else branch is declared __device__ -- confirm the difference is intended.
template shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void
Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {}
} // namespace gptq_marlin
// Fallback entry point for the __CUDA_ARCH__ < 800 build path: it fails
// unconditionally through TORCH_CHECK_NOT_IMPLEMENTED(false, ...). The return
// statement only satisfies the signature.
// NOTE(review): this stub takes 11 parameters, while the Python binding in
// binding.cpp names 13 arguments (it additionally lists `size_m_tensor` and
// `sms`) -- confirm that this signature still matches the declaration in ops.h.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({ 1, 1 });
}
#else
// m16n8k16 tensor core mma instruction with 16-bit inputs (f16 or bf16,
// selected at compile time) and fp32 output/accumulation.
//
// a_frag supplies four 32-bit registers and frag_b two. frag_c supplies four
// floats that are both read (as the accumulator input) and overwritten with
// the result, so the call accumulates into frag_c in place.
template
__device__ inline void mma(const typename ScalarType::FragA& a_frag,
const typename ScalarType::FragB& frag_b,
typename ScalarType::FragC& frag_c) {
// Reinterpret the fragments as the raw registers the PTX instruction expects.
const uint32_t* a = reinterpret_cast(&a_frag);
const uint32_t* b = reinterpret_cast(&frag_b);
float* c = reinterpret_cast(&frag_c);
if constexpr (std::is_same::value) {
// f16 x f16 -> f32 variant.
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
else if constexpr (std::is_same::value) {
// bf16 x bf16 -> f32 variant; same operand layout as above.
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
else {
// Any other scalar type is rejected at compile time.
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
//
// frag_a receives four 32-bit registers from one `ldmatrix ... .x4`
// instruction. smem_ptr is a generic pointer; it is converted to a
// shared-state-space address with __cvta_generic_to_shared before the load.
template
__device__ inline void ldsm4(typename ScalarType::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast(&frag_a);
uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
//
// `lut` is passed to the instruction as an immediate ("n" constraint), so it
// must be a compile-time constant. Returns the 32-bit result of applying that
// truth table bitwise to (a, b, c).
template __device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
//
// The first source is the runtime value `a`. The second source (`start_byte`)
// and the byte selector (`mask`) are both passed as immediates ("n"
// constraint), so they must be compile-time constants.
template
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
// Primary template: it contains only the scalar-type static assertion and no
// return statement. The usable implementations are the explicit
// specializations that follow.
template
__device__ inline typename ScalarType::FragB dequant_4bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
// fp16 specialization. Unpacks four 4-bit values from q (bits 0-3 and 16-19
// into `lo`, bits 4-7 and 20-23 into `hi`) and converts them to signed values
// in [-8, 7] without any integer-to-float instruction.
template <>
__device__ inline typename ScalarType::FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
// 0x6400 is 1024.0 in fp16; OR-ing a nibble into the low mantissa bits yields
// 1024 + v (for `lo`) or 1024 + 16*v (for `hi`, whose nibble sits 4 bits up).
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// SUB = 1032.0 per lane: (1024 + v) - 1032 = v - 8.
const int SUB = 0x64086408;
// MUL = 1/16 and ADD = -72.0 per lane: (1024 + 16*v) / 16 - 72 = v - 8.
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
typename ScalarType::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast(&lo),
*reinterpret_cast(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast(&hi),
*reinterpret_cast(&MUL),
*reinterpret_cast(&ADD));
return frag_b;
}
// bf16 specialization. The same nibble mask is applied twice: once to q as
// given (`lo`) and once after shifting q right by 4 (`hi`), so both
// intermediate values have their nibbles in the lowest mantissa bits.
template <>
__device__ inline typename ScalarType::FragB
dequant_4bit(int q) {
static constexpr uint32_t MASK = 0x000f000f;
// 0x4300 is 128.0 in bf16; OR-ing a nibble into the low mantissa bits yields
// 128 + v.
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType::FragB frag_b;
// MUL = 1.0 and ADD = -136.0 per lane: (128 + v) * 1 - 136 = v - 8, i.e. the
// -8 zero point is folded into the fused multiply-add.
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast(&lo),
*reinterpret_cast(&MUL),
*reinterpret_cast(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast(&hi),
*reinterpret_cast(&MUL),
*reinterpret_cast(&ADD));
return frag_b;
}
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
// Primary template: it contains only the scalar-type static assertion and no
// return statement. The usable implementations are the explicit
// specializations that follow.
template
__device__ inline typename ScalarType::FragB dequant_8bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
// fp16 specialization. Builds two packed-fp16 words from the bytes of q with
// prmt, then subtracts a magic constant from each lane.
template <>
__device__ inline typename ScalarType::FragB dequant_8bit(int q) {
// Byte selectors and filler byte for prmt. Presumably `lo` is built with
// mask_for_elt_01 and `hi` with mask_for_elt_23, each pairing one byte of q
// with the 0x64 filler byte -- confirm against the prmt template arguments.
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt(q);
uint32_t hi = prmt(q);
// 0x6480 is 1152.0 (= 1024 + 128) in fp16. Subtracting it from a lane of the
// form 0x64XX (= 1024 + XX) leaves XX - 128.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
typename ScalarType::FragB frag_b;
frag_b[0] =
__hsub2(*reinterpret_cast(&lo),
*reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] =
__hsub2(*reinterpret_cast(&hi),
*reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// bf16 specialization. Goes through fp32: each byte of q is placed into the
// low byte of an fp32 bit pattern, the bias is subtracted in fp32, and the
// upper 16 bits of each fp32 result are packed pairwise into the output.
template <>
__device__ inline typename ScalarType::FragB
dequant_8bit(int q) {
typename ScalarType::FragB frag_b;
float fp32_intermediates[4];
// Integer view of the same storage, used for the byte-level construction.
uint32_t* fp32_intermediates_casted =
reinterpret_cast(fp32_intermediates);
// 0x4B000000 is 2^23 (8388608.0f); with a byte b in the lowest mantissa bits
// the value reads as 8388608 + b.
static constexpr uint32_t fp32_base = 0x4B000000;
// Selectors 0x765N keep the top three bytes of fp32_base and take byte N of q
// as the low byte. Note the 0, 2, 1, 3 byte order.
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
// 8388736 = 8388608 + 128: removes the fp32 base and the 128 bias in one step,
// leaving b - 128.
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
// Selector 0x7632 packs the upper halves of two fp32 words into one 32-bit
// word (first operand in the low half, second in the high half).
uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
//
// Element i of frag_s is turned into a packed two-lane value via
// ScalarType::num2num2 (presumably a broadcast of the one scalar to both lanes
// -- confirm), and both halves of frag_b are multiplied by it in place.
template
__device__ inline void scale(typename ScalarType::FragB& frag_b,
typename ScalarType::FragS& frag_s,
int i) {
using scalar_t2 = typename ScalarType::scalar_t2;
scalar_t2 s = ScalarType::num2num2(
reinterpret_cast(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
// Same as above, but for act_order (each K is multiplied individually)
//
// Four separate scale fragments are read at element i. Scales 1 and 2 form the
// (x, y) lanes applied to frag_b[0]; scales 3 and 4 form the (x, y) lanes
// applied to frag_b[1]. frag_b is modified in place.
template
__device__ inline void scale4(typename ScalarType::FragB& frag_b,
typename ScalarType::FragS& frag_s_1,
typename ScalarType::FragS& frag_s_2,
typename ScalarType::FragS& frag_s_3,
typename ScalarType::FragS& frag_s_4,
int i) {
using scalar_t2 = typename ScalarType::scalar_t2;
scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast(&frag_s_2)[i];
scalar_t2 s_val_3_4;
s_val_3_4.x = reinterpret_cast(&frag_s_3)[i];
s_val_3_4.y = reinterpret_cast(&frag_s_4)[i];
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
// Given 2 floats multiply by 2 scales (halves)
//
// c[0] and c[1] are scaled in place by the first and second scalar stored in
// `s`. Each scalar is converted to float via ScalarType::num2float and
// multiplied with __fmul_rn (round-to-nearest).
template
__device__ inline void scale_float(float* c,
typename ScalarType::FragS& s) {
scalar_t* s_ptr = reinterpret_cast(&s);
c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
//
// Only thread 0 of the block polls: it repeatedly performs a gpu-scope acquire
// load of *lock until the loaded value equals `count`. Every thread then meets
// at __syncthreads(), so no thread of the block proceeds before thread 0 has
// observed the expected value.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be
// visible globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
//
// All threads of the block first meet at __syncthreads(). Thread 0 then either
// resets the lock to 0 (when `reset` is true) or atomically adds 1 to it.
// Note that the reset path is a plain store and returns before the fence below.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
// Relaxed gpu-scope atomic add; ordering comes from the fence above.
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
//
// Each thread block handles up to `block_rows` consecutive rows, starting at
// row block_rows * blockIdx.x and clamped to size_m. Within a row the threads
// cooperate: thread t copies elements t, t + default_threads, ... so that
// out[row][k] = a[row][perm[k]]. Input and output are addressed as `half`
// elements through int4 pointers.
// NOTE(review): row_stride is computed as size_k * sizeof(half) / 16, which
// assumes a row's byte size is a multiple of 16 -- confirm at the call site.
// `default_threads` is defined outside this block and is presumably the launch
// block size -- confirm.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
// Row range [start_row, finish_row) owned by this block.
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = finish_row - start_row;
// Row pitch measured in int4 (16-byte) units.
int row_stride = size_k * sizeof(half) / 16;
// Permutes one row; called by all threads of the block.
auto permute_row = [&](int row) {
// Number of full passes over the row, and the size of the final partial pass.
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int offset = row * row_stride;
half const* a_row_half =
reinterpret_cast(a_int4_ptr + offset);
half* out_half = reinterpret_cast(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
// Tail: only the first `rest` threads have a column left to copy.
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
// If start_row is already past size_m, cur_block_rows is <= 0 and the loop
// body never runs.
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
template shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__device__ void
Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using Dtype = ScalarType;
using scalar_t2 = typename ScalarType::scalar_t2;
using FragA = typename ScalarType::FragA;
using FragB = typename ScalarType::FragB;
using FragC = typename ScalarType::FragC;
using FragS = typename ScalarType::FragS;
constexpr int pack_factor = 32 / num_bits;
// int prob_m = *prob_m_ptr;
// const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
// constexpr int thread_m_blocks = template_thread_m_blocks;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters = (group_blocks / thread_k_blocks) *
div_ceil(iters, (group_blocks / thread_k_blocks));
}
}
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = div_ceil(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
// A sizes/strides
// stride of the A matrix in global memory
int a_gl_stride = prob_k / 8;
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
// number of shared write iterations for a tile
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start;
int slice_n_offset = act_s_col_tb_stride * slice_col;
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
}
else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
}
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
}
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
{
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses within a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
{
reinterpret_cast(frag_c)[i] = 0;
}
};
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}
int row_offset = first_group_id * s_gl_stride;
if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) +
slice_n_offset + threadIdx.x]);
}
}
}
else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
threadIdx.x];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
if (cur_k < prob_k && cur_k < slice_k_finish) {
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int4 const* cur_g_idx_stage_ptr =
reinterpret_cast(&g_idx[cur_k]);
if (threadIdx.x < g_idx_stage) {
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
&cur_g_idx_stage_ptr[threadIdx.x]);
}
}
}
else {
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
}
else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
{
ldsm4(frag_a[k % 2][i],
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
bool is_same_group[stages];
int same_group_id[stages];
auto init_same_group = [&](int pipe) {
if constexpr (!has_act_order) {
is_same_group[pipe] = false;
same_group_id[pipe] = 0;
return;
}
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage);
int group_id_1 = sh_g_idx_int_ptr[0];
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
is_same_group[pipe] = group_id_1 == group_id_2;
same_group_id[pipe] = group_id_1;
};
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages;
if constexpr (!has_act_order) {
// No act-order case
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
}
else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
}
}
return;
}
// Act-order case
// Determine K of the "current" thread-block
int cur_k = slice_k_start + tb_k * full_pipe;
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
return;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k = 0;
// Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
(th_id / 4) * act_s_col_stride;
if (is_same_group[pipe]) {
if (k % 2 == 0) {
*(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) =
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
s_col_shift];
}
else {
*(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) =
*(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0])));
}
for (int i = 1; i < 4; i++) {
*(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) =
*(reinterpret_cast(&(act_frag_s[k % 2][0][0])));
}
return;
}
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage);
constexpr int k_frag_offsets[4] = { 0, 1, 8,
9 }; // Tensor core offsets per thread
#pragma unroll
for (int i = 0; i < 4; i++) {
int actual_k = cur_k + k_frag_offsets[i];
int group_id = sh_g_idx_int_ptr[actual_k];
int rel_group_id = group_id - sh_first_group_id;
*(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) =
sh_s[rel_group_id * s_sh_stride + s_col_shift];
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
if constexpr (num_bits == 4) {
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_4bit(b_quant);
frag_b1 = dequant_4bit(b_quant_shift);
}
else {
int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit(b_quant_0);
frag_b1 = dequant_8bit(b_quant_1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
scale4(frag_b0, act_frag_s[k % 2][0][j],
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
act_frag_s[k % 2][3][j], 0);
}
else {
if constexpr (group_blocks != -1) {
scale(frag_b0, frag_s[k % 2][j], 0);
}
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j],
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
act_frag_s[k % 2][3][j], 1);
}
else {
if constexpr (group_blocks != -1) {
scale(frag_b1, frag_s[k % 2][j], 1);
}
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location, which we have to reduce over in the end. We do this in shared
// memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd =
reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
reinterpret_cast(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd =
reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
Dtype::num2float(reinterpret_cast(&c_red)[j]);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast(&c)[j] =
Dtype::float2num(reinterpret_cast(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
c;
}
}
}
}
};
// Write out the reduced final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta =
c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
res = __hmul2(res, s[0]);
}
((scalar_t2*)sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0;
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
if (has_act_order && i == 0) {
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
fetch_to_shared(i, i, i < slice_iters);
}
zero_accums();
wait_for_stage();
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
if (slice_iters) {
start_pipes();
}
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worsen performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (num_bits == 8) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
else {
if (last) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
}
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (num_bits == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
else {
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float(
reinterpret_cast(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(
reinterpret_cast(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(
reinterpret_cast(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float(
reinterpret_cast(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
}
}
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
}
else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}
}
}
// Kernel entry point that resolves the batch size at run time and forwards
// to the Marlin kernel body.
//
// prob_m is read through prob_m_ptr inside the kernel, capped at 1024, and
// used to derive thread_m_blocks = min(div_ceil(prob_m, 16),
// template_thread_m_blocks). When prob_m is larger than one tile of
// 16 * thread_m_blocks rows it is rounded up to a whole number of such tiles.
// Every other argument is passed through to Marlin unchanged.
//
// NOTE(review): the template parameter list directly below and the explicit
// template arguments on the Marlin(...) calls appear truncated in this copy
// of the file -- confirm against the upstream source before editing them.
template shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void
Marlin_wrapper(const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
const int* __restrict__ prob_m_ptr, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// The batch size is fetched through the pointer here instead of being
// passed by value.
int prob_m = *prob_m_ptr;
// Cap the batch size at 1024 rows.
// NOTE(review): the rationale for 1024 is not visible in this file section.
prob_m = min(prob_m, 1024);
// Number of 16-row M blocks per tile, bounded by template_thread_m_blocks.
const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
// If the batch does not fit in a single tile, round prob_m up to a whole
// number of tiles of 16 * thread_m_blocks rows.
if(prob_m > 16 * thread_m_blocks)
prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
/*if (blockIdx.x == 0 && threadIdx.x == 0)
printf("marlin prob_m %d\n", prob_m);*/
// Dispatch on the run-time thread_m_blocks value. Presumably each branch
// selects the Marlin instantiation for that block count (the template
// arguments are not visible here). Values outside 1..4 match no branch, in
// which case the kernel returns without doing any work.
if (thread_m_blocks == 1) {
Marlin(
A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,
prob_k, locks);
}
else if (thread_m_blocks == 2) {
Marlin(
A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,
prob_k, locks);
}
else if (thread_m_blocks == 3) {
Marlin(
A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,
prob_k, locks);
}
else if (thread_m_blocks == 4) {
Marlin(
A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,
prob_k, locks);
}
}
// Expands to one `else if` arm of the kernel-dispatch chain. The arm is taken
// when the run-time values num_bits, thread_m_blocks, thread_n_blocks,
// thread_k_blocks, has_act_order, group_blocks and num_threads all equal the
// corresponding macro arguments; it then sets the kernel's
// cudaFuncAttributeMaxDynamicSharedMemorySize to max_shared_mem and launches
// Marlin_wrapper.
//
// Because the expansion begins with `else`, it must directly follow an `if`
// statement or another arm. The expansion site must also have A_ptr, B_ptr,
// C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, prob_k, locks and
// max_shared_mem in scope.
//
// NOTE(review): the result of cudaFuncSetAttribute is not checked, and the
// template arguments / launch configuration of Marlin_wrapper appear
// truncated in this copy of the file -- confirm against upstream.
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin_wrapper, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_wrapper<<>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \
prob_k, locks); \
}
// Tile shape and size of one thread block.
struct thread_config_t {
  int thread_k;     // tile extent along the K (reduction) dimension
  int thread_n;     // tile extent along the N (output) dimension
  int num_threads;  // threads per thread block
};

// A thread-block configuration paired with the number of 16-row M blocks
// that may be processed per kernel invocation.
struct exec_config_t {
  int max_m_blocks;
  thread_config_t tb_cfg;
};
// Candidate thread-block configurations tried when prob_m <= 16.
// Entries are listed from highest to lowest priority; the first valid one is
// used.
thread_config_t small_batch_thread_configs[] = {
    thread_config_t{/*thread_k=*/128, /*thread_n=*/128, /*num_threads=*/256},
    thread_config_t{/*thread_k=*/64, /*thread_n=*/128, /*num_threads=*/128},
    thread_config_t{/*thread_k=*/128, /*thread_n=*/64, /*num_threads=*/128},
};

// Candidate thread-block configurations tried when prob_m > 16, again in
// priority order.
thread_config_t large_batch_thread_configs[] = {
    thread_config_t{/*thread_k=*/64, /*thread_n=*/256, /*num_threads=*/256},
    // Disabled candidate: {128, 128, 256}
    thread_config_t{/*thread_k=*/64, /*thread_n=*/128, /*num_threads=*/128},
    thread_config_t{/*thread_k=*/128, /*thread_n=*/64, /*num_threads=*/128},
};
// Returns the amount of shared memory to reserve for quantization scales
// under the given thread-block configuration.
//
// group_size == -1 means a single scale group per tile; group_size == 0 is
// treated as the worst case of 32-wide groups; any other value is used as the
// group width along K. When has_act_order is set and is_k_full is not, a
// chunk of scale groups covering twice the pipeline depth (and at least 32
// groups) is cached; otherwise one set of tile scales per pipeline stage is
// reserved. The factor of 2 is presumably the byte size of one fp16 scale.
//
// prob_m, prob_n, prob_k and num_bits are accepted but not used.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  const int tile_n = th_config.thread_n;
  const int tile_k = th_config.thread_k;

  // Upper bound on the number of scale groups a single tile can span.
  int groups_per_tile = 1;
  if (group_size == 0) {
    groups_per_tile = div_ceil(tile_k, 32);  // Worst case is 32 group size
  } else if (group_size != -1) {
    groups_per_tile = div_ceil(tile_k, group_size);
  }

  if (has_act_order && !is_k_full) {
    // Chunk size is 2x pipeline over dim K, but never fewer than 32 groups.
    int chunk_groups = groups_per_tile * pipe_stages * 2;
    if (chunk_groups < 32) {
      chunk_groups = 32;
    }
    return chunk_groups * tile_n * 2;
  }

  const int scales_per_stage = groups_per_tile * tile_n * 2;
  return scales_per_stage * pipe_stages;
}
// Checks whether the A/B tile pipeline for th_config fits into the shared
// memory left over once scales_cache_size has been reserved.
//
// The B tile is sized as thread_k * thread_n packed weights (32 / num_bits per
// 32-bit word, 4 bytes per word). The A tile covers
// 16 * min(div_ceil(prob_m, 16), max_m_blocks) rows by thread_k columns at
// 2 bytes per element. Both are multiplied by pipe_stages, and the total must
// stay below 95% of (max_shared_mem - scales_cache_size).
//
// Raises via TORCH_CHECK if the scales cache is not smaller than half of
// max_shared_mem. prob_n and prob_k are accepted but not used.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  const int tile_k = th_config.thread_k;
  const int tile_n = th_config.thread_n;

  // B tile footprint.
  const int weights_per_word = 32 / num_bits;
  const int b_tile_size = (tile_k * tile_n / weights_per_word) * 4;

  // A tile footprint: the row count is clamped to the M blocks actually
  // present, so small batches are not charged for a full max_m_blocks tile.
  const int m_blocks_needed = div_ceil(prob_m, 16);
  const int tile_rows = 16 * std::min(m_blocks_needed, max_m_blocks);
  const int a_tile_size = (tile_rows * tile_k) * 2;

  const float pipeline_size = (a_tile_size + b_tile_size) * pipe_stages;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  const float budget = 0.95f * (max_shared_mem - scales_cache_size);
  return pipeline_size < budget;
}
// Decides whether th_config can be used for the given problem.
//
// A configuration is rejected when any of its fields is the -1 placeholder,
// when prob_k / prob_n are not multiples of thread_k / thread_n, when the
// tile is smaller than min_thread_k / min_thread_n, when fewer than 128
// threads are requested, or when the tile pipeline plus scales cache does not
// fit in max_shared_mem.
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  const int tile_k = th_config.thread_k;
  const int tile_n = th_config.thread_n;
  const int n_threads = th_config.num_threads;

  // Placeholder ("no config") entries are never valid.
  const bool is_placeholder = tile_k == -1 || tile_n == -1 || n_threads == -1;
  if (is_placeholder) {
    return false;
  }

  // The problem must tile evenly along both K and N.
  const bool tiles_evenly = prob_k % tile_k == 0 && prob_n % tile_n == 0;
  if (!tiles_evenly) {
    return false;
  }

  // Tiles below the minimum supported extents are rejected.
  if (tile_n < min_thread_n || tile_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps).
  if (n_threads < 128) {
    return false;
  }

  // Finally, the pipeline must fit next to the scales cache in shared memory.
  const int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  return is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                             num_bits, scales_cache_size, max_shared_mem);
}
// Picks a thread-block configuration automatically.
//
// Starting from 4 M blocks per invocation and counting down to 1, the
// candidate table for the batch size (small_batch_thread_configs when
// prob_m <= 16, large_batch_thread_configs otherwise) is scanned in priority
// order and the first entry accepted by is_valid_config is returned together
// with the current M-block count. If nothing fits, the result has
// max_m_blocks == 0 and a {-1, -1, -1} thread configuration.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  const bool small_batch = prob_m <= 16;

  // Fewer M blocks per invocation means a smaller A tile, hence less shared
  // memory; keep shrinking until some candidate fits.
  for (int m_blocks = 4; m_blocks > 0; --m_blocks) {
    auto fits = [&](thread_config_t const& candidate) {
      return is_valid_config(candidate, m_blocks, prob_m, prob_n, prob_k,
                             num_bits, group_size, has_act_order, is_k_full,
                             max_shared_mem);
    };

    if (small_batch) {
      for (thread_config_t const& candidate : small_batch_thread_configs) {
        if (fits(candidate)) {
          return exec_config_t{ m_blocks, candidate };
        }
      }
    } else {
      for (thread_config_t const& candidate : large_batch_thread_configs) {
        if (fits(candidate)) {
          return exec_config_t{ m_blocks, candidate };
        }
      }
    }
  }

  return exec_config_t{ 0, {-1, -1, -1} };
}
// Emits the __CALL_IF dispatch arms for one (NUM_BITS, N_BLOCKS, K_BLOCKS,
// NUM_THREADS) combination: thread_m_blocks 1 through 4 crossed with
// group_blocks 4 and 8, always with has_act_order == false. The launcher
// derives group_blocks as group_size / 16, so these arms correspond to group
// sizes 64 and 128.
// NOTE(review): no arms are generated here for has_act_order == true or for
// group_blocks values of -1 / 0 -- confirm this restriction is intentional.
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
// Host-side driver for the Marlin mixed-precision GEMM.
//
// Validates the problem description, selects a thread-block configuration,
// optionally permutes the columns of A (act-order case) and then walks over
// the M dimension in chunks, dispatching each chunk through the CALL_IF
// macro chain.
//
// Parameters, as used by the visible body:
//   A, B, C, s      - type-erased buffers reinterpreted as int4 pointers; A
//                     and C are advanced after every chunk of the main loop.
//   g_idx, perm     - reinterpreted as int pointers; perm drives the column
//                     permutation of A when has_act_order is set.
//   a_tmp           - scratch buffer that receives the permuted copy of A.
//   prob_m_ptr      - not referenced directly in this body; presumably
//                     consumed by the CALL_IF expansion - TODO confirm.
//   prob_m/n/k      - GEMM extents; all must be strictly positive.
//   workspace       - reinterpreted as an int array named `locks`.
//   num_bits        - must be 4 or 8 to pass the first check, but only the
//                     4-bit dispatch branches are enabled below.
//   group_size      - -1 selects group_blocks = -1, 0 is required for
//                     act-order without a full K, otherwise it is converted
//                     to 16-row blocks.
//   thread_k/n      - -1 for both selects determine_thread_config().
//   sms             - -1 queries the device's multiprocessor count.
//   max_par         - upper bound for the per-iteration `par` factor.
// NOTE(review): the template parameter list is not visible in this copy of
// the file - confirm against the build.
template
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
bool has_act_order, bool is_k_full, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [",
prob_m, ", ", prob_n, ", ", prob_k, "]");
// M is processed in blocks of 16 rows; `pad` is the number of rows missing
// from the last block.
int tot_m = prob_m;
int tot_m_blocks = div_ceil(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg = exec_config_t{
4, thread_config_t{thread_k, thread_n, default_threads} };
}
else {
// Auto config
exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full,
max_shared_mem);
}
// Reject both a failed auto-selection (max_m_blocks <= 0) and a
// user-supplied configuration that is_valid_config() does not accept.
TORCH_CHECK(
exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem);
// From here on thread_k / thread_n hold the effective (possibly
// auto-selected) tile extents.
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
// Translate group_size into 16-row blocks:
//   act-order + full K    -> group_size / 16 (group_size must not be -1)
//   act-order + partial K -> 0               (group_size must be 0)
//   no act-order          -> -1 when group_size == -1, else group_size / 16
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
}
else {
if (group_size == -1) {
group_blocks = -1;
}
else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
int* locks = (int*)workspace;
if (has_act_order) {
// Permute A columns
// The permuted copy lands in a_tmp, which then replaces A as the GEMM input.
int block_rows = div_ceil(prob_m, blocks);
permute_cols_kernel << > > (
A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
A_ptr = a_tmp_ptr;
}
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) {
has_act_order = false;
}
// Main loop
// Each iteration covers up to max_m_blocks * par blocks of 16 rows; when
// par > 1 the loop index is bumped inside the body in addition to the
// regular increment in the loop header.
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > exec_cfg.max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without
// any padding
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par)
par = max_par;
prob_m = (16 * exec_cfg.max_m_blocks) * par;
i += exec_cfg.max_m_blocks * (par - 1);
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
#define undefined_error \
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \
str(prob_n) + ", " + str(prob_k) + "]" + \
", has_act_order = " + str(has_act_order) + \
", num_groups = " + str(num_groups) + \
", group_size = " + str(group_size) + \
", thread_m_blocks = " + str(thread_m_blocks) + \
", thread_n_blocks = " + str(thread_n_blocks) + \
", thread_k_blocks = " + str(thread_k_blocks));
/* std::cout << "MNK = [" + str(prob_m) + ", " + \
str(prob_n) + ", " + str(prob_k) + "]" + \
", has_act_order = " + str(has_act_order) + \
", num_groups = " + str(num_groups) + \
", group_size = " + str(group_size) + \
", thread_m_blocks = " + str(thread_m_blocks) + \
", thread_n_blocks = " + str(thread_n_blocks) + \
", thread_k_blocks = " + str(thread_k_blocks) << std::endl;*/
/*if (false) {
}
// CALL_IF(4, 32, 2, 256)
// CALL_IF(4, 16, 4, 256)
__CALL_IF(4, 1, 16, 4, false, 4, 256)
__CALL_IF(4, 2, 16, 4, false, 4, 256)
// CALL_IF(4, 8, 8, 256)
__CALL_IF(4, 1, 8, 8, false, 4, 256)
__CALL_IF(4, 2, 8, 8, false, 4, 256)
// CALL_IF(4, 16, 4, 128)
__CALL_IF(4, 1, 16, 4, false, 4, 128)
__CALL_IF(4, 2, 16, 4, false, 4, 128)
// CALL_IF(4, 8, 8, 128)
__CALL_IF(4, 1, 8, 8, false, 4, 128)
__CALL_IF(4, 2, 8, 8, false, 4, 128)
else {undefined_error}*/
// Dispatch on (num_bits, num_threads). Only the 4-bit variants are enabled;
// the 8-bit branches are commented out, so num_bits == 8 ends in
// undefined_error even though it passes the check at the top.
if (num_bits == 4 && num_threads == 256)
{
// The empty `if (false)` gives the CALL_IF entries and the trailing `else`
// something to chain onto.
if (false) {
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
else {
undefined_error
}
}
else if (num_bits == 4 && num_threads == 128)
{
if (false) {
}
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 16, 4, 128)
CALL_IF(4, 4, 8, 128)
else {
undefined_error
}
}
// else if (num_bits == 8 && num_threads == 256)
// {
// if (false) {
// }
// CALL_IF(8, 32, 2, 256)
// CALL_IF(8, 16, 4, 256)
// CALL_IF(8, 8, 8, 256)
// else {
// undefined_error
// }
// }
// else if (num_bits == 8 && num_threads == 128)
// {
// if (false) {
// }
// CALL_IF(8, 8, 4, 128)
// CALL_IF(8, 16, 4, 128)
// CALL_IF(8, 4, 8, 128)
// else {
// undefined_error
// }
// }
else {
undefined_error
}
// Step past the rows handled in this iteration: 16 * thread_m_blocks * par
// rows, each prob_k / 8 (A) or prob_n / 8 (C) int4 elements wide.
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
}
} // namespace gptq_marlin
// Torch entry point for the GPTQ-Marlin GEMM.
//
// Validates shapes, devices and contiguity of the inputs, derives the
// quantization group layout from b_scales and g_idx, allocates the output
// and a scratch copy of A, and dispatches to gptq_marlin::marlin_mm_f16i4
// according to the scalar type of `a`.
//
// Arguments (constraints are the ones enforced below):
//   a             - [size_m, size_k], CUDA, contiguous, Half or BFloat16.
//   b_q_weight    - [size_k / tile_size, size_n * tile_size / pack_factor],
//                   CUDA, contiguous, with pack_factor = 32 / num_bits.
//   b_scales      - rank 2, [num_groups, size_n], CUDA, contiguous.
//   g_idx, perm   - either both empty or both of length size_k; a non-empty
//                   g_idx enables the act-order path.
//   workspace     - at least (size_n / min_thread_n) * max_par elements.
//   num_bits      - 4 or 8.
//   size_m_tensor - its data pointer is forwarded to the kernel launcher;
//                   no check is applied to it here.
//   sms           - forwarded to the launcher unchanged.
//   is_k_full     - forwarded to the launcher; also selects how group_size
//                   is derived when act-order is active.
// Returns a freshly allocated [size_m, size_n] tensor with the dtype and
// device of `a`.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,
int64_t size_k, int sms, bool is_k_full) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
// Number of quantized values stored in one 32-bit word.
int pack_factor = 32 / num_bits;
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k,
", tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size);
// Recover the logical N implied by the packed weight layout and compare it
// with the caller-supplied size_n.
int actual_size_n =
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Alloc buffers
// c is the result; a_tmp is handed to the launcher as scratch space for the
// column-permuted copy of A. Both are allocated unconditionally.
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({ size_m, size_n }, options);
torch::Tensor a_tmp = torch::empty({ size_m, size_k }, options);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
// int sms = -1; //zbx
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
(g_idx.size(0) == size_k && perm.size(0) == size_k),
"Unexpected g_idx.size(0) = ", g_idx.size(0),
" and perm.size(0) = ", perm.size(0),
", where size_k = ", size_k);
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
bool has_act_order = g_idx.size(0) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n,
"b_scales dim 1 = ", b_scales.size(1),
" is not size_n = ", size_n);
num_groups = b_scales.size(0);
// group_size encoding handed to the launcher:
//   act-order + full K    -> size_k / num_groups (num_groups must be > 1)
//   act-order + partial K -> 0
//   no act-order          -> size_k / num_groups when num_groups > 1,
//                            otherwise -1
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1,
"For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
}
else {
group_size = 0;
}
}
else {
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
}
else {
group_size = -1;
}
}
// Verify workspace size
TORCH_CHECK(
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
// NOTE(review): the Half and BFloat16 branches read identically in this copy
// of the file; presumably they differ only in template arguments that are
// not visible here - confirm against the build.
if (a.scalar_type() == at::ScalarType::Half) {
gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(),
c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
size_m_tensor.data_ptr(),
size_m, size_n, size_k, workspace.data_ptr(), num_bits,
has_act_order, is_k_full, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
}
else if (a.scalar_type() == at::ScalarType::BFloat16) {
gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(),
c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
size_m_tensor.data_ptr(),
size_m, size_n, size_k, workspace.data_ptr(), num_bits,
has_act_order, is_k_full, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
}
else {
TORCH_CHECK(false,
"gpt_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
================================================
FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh
================================================
// Adapted from
// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
// Copyright 2024 The vLLM team.
// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#pragma once
#include
#include
#include
#include
#include
#include
#include
namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages =
4; // 4 pipeline stages fit into shared memory
// gptq_marlin_gemm requires size_n to be a multiple of min_thread_n and
// sizes its workspace as (size_n / min_thread_n) * max_par.
static constexpr int min_thread_n = 64;
// Presumably the K-direction counterpart of min_thread_n - TODO confirm.
static constexpr int min_thread_k = 64;
// Edge length used for the packed weight layout: gptq_marlin_gemm expects
// size_k / tile_size rows in b_q_weight.
static constexpr int tile_size = 16;
// Upper bound passed to the GEMM launcher as `max_par`.
static constexpr int max_par = 16;
// Fixed-size array wrapper holding `n` elements of type `T`, with a
// device-side subscript operator that returns a mutable reference.
// No bounds checking is performed.
template struct Vec {
T elems[n];
__device__ T &operator[](int i) { return elems[i]; }
};
// Alias for a Vec instantiation.
// NOTE(review): the template arguments are not visible in this copy of the
// file - confirm the element type and count.
using I4 = Vec;
// Integer division of `a` by `b`, rounded up for non-negative `a` and
// positive `b` (e.g. div_ceil(17, 16) == 2).
constexpr int div_ceil(int a, int b) {
  const int numerator = a + b - 1;
  return numerator / b;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
// Predicated asynchronous copy of 16 bytes from global to shared memory.
// The PTX predicate is set when `pred` is non-zero; the
// cp.async.cg.shared.global instruction is only issued in that case, so a
// false `pred` leaves the shared-memory destination untouched.
// Only compiled when the surrounding __CUDA_ARCH__ guard does not select the
// pre-sm_80 branch.
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
bool pred = true) {
const int BYTES = 16;
// cp.async takes the destination as a 32-bit shared-space address.
uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Unconditional asynchronous copy of 16 bytes from global memory
// (`glob_ptr`) to shared memory (`smem_ptr`) via cp.async.cg.shared.global.
// The copy is only issued here; completion is tracked through
// cp_async_fence() / cp_async_wait().
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
const int BYTES = 16;
// cp.async takes the destination as a 32-bit shared-space address.
uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Emits cp.async.commit_group, closing the current batch of previously
// issued cp.async operations into one group that cp_async_wait() can wait on.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Emits cp.async.wait_group with the compile-time immediate `n`. Per the PTX
// ISA this waits until at most `n` of the most recently committed cp.async
// groups are still pending.
// NOTE(review): `n` is presumably an integer template parameter; the
// parameter list is not visible in this copy of the file.
template __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace gptq_marlin
================================================
FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh
================================================
// Adapted from
// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
// Copyright 2024 The vLLM team.
// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include
#include
namespace gptq_marlin {
// Primary template: intentionally empty. The specializations below provide
// the type aliases and conversion helpers for each supported scalar type.
template class ScalarType {};
// Specialization whose scalar type is `half`: exposes the scalar / packed
// pair aliases, the fragment aliases, and thin wrappers around the CUDA half
// conversion intrinsics.
// NOTE(review): the specialization argument and the Vec arguments of the
// fragment aliases are not visible in this copy of the file.
template <> class ScalarType {
public:
using scalar_t = half;
using scalar_t2 = half2;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec;
using FragB = Vec;
using FragC = Vec;
using FragS = Vec;
// half -> float.
static __device__ float inline num2float(const half x) {
return __half2float(x);
}
// Duplicate one half into both lanes of a half2.
static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
// Pack two halves into a half2 (x1, x2).
static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
// float -> half; callable from host and device.
static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
};
// Specialization whose scalar type is `nv_bfloat16`. The aliases are always
// available; the conversion helpers are only declared when compiling device
// code for __CUDA_ARCH__ >= 800 (the guard is false whenever __CUDA_ARCH__
// is undefined or lower).
// NOTE(review): the specialization argument and the Vec arguments of the
// fragment aliases are not visible in this copy of the file.
template <> class ScalarType {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
using FragA = Vec;
using FragB = Vec;
using FragC = Vec;
using FragS = Vec;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// bfloat16 -> float.
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
// Duplicate one bfloat16 into both lanes of a bfloat162.
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
// Pack two bfloat16 values into a bfloat162 (x1, x2).
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
// float -> bfloat16.
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif
};
} // namespace gptq_marlin
#endif
================================================
FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu
================================================
#include "gptq_marlin.cuh"
namespace gptq_marlin {
// Number of shared-memory pipeline slots cycled through by the repack kernel.
static constexpr int repack_stages = 8;
// Thread-block size used for the repack kernel.
// NOTE(review): the launch configuration is not visible in this copy.
static constexpr int repack_threads = 256;
// A repack tile is tile_k_size (16) rows by tile_n_size (64) columns.
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Empty placeholder compiled when __CUDA_ARCH__ is defined and below 800,
// where the cp.async helpers are not available. It keeps the symbol present
// but performs no work.
template
__global__ void marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace gptq_marlin
// Fallback used in the pre-sm_80 compilation branch: repacking is not
// implemented there, so the call always fails with a not-implemented error.
// The trailing return only exists to satisfy the signature.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  constexpr bool kSupportedOnThisArch = false;
  TORCH_CHECK_NOT_IMPLEMENTED(
      kSupportedOnThisArch,
      "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  torch::Tensor placeholder = torch::empty({1, 1});
  return placeholder;
}
#else
// Repacks GPTQ-packed weights (`b_q_weight_ptr`, one uint32 holding
// 32 / num_bits values) into the tile layout written to `out_ptr`.
//
// Work split: the k tiles (tile_k_size rows each) are divided into
// contiguous ranges, one per thread block; within a k tile the block walks
// over the n tiles (tile_n_size columns each) using a shared-memory pipeline
// of repack_stages slots filled with cp.async.
//
// When has_perm is set, the row permutation for the current k tile is first
// staged in shared memory and each tile row is fetched from its permuted
// source row.
// NOTE(review): the template parameter list (num_bits and has_perm are used
// as compile-time values below) is not visible in this copy of the file.
template
__global__ void marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
// Quantized values per 32-bit word.
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
// Contiguous slice of k tiles owned by this block; surplus blocks exit.
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait();
__syncthreads();
};
// Dynamic shared memory layout: [perm_size int4 of permutation indices,
// only when has_perm] followed by the pipeline stages.
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
// tile_ints: packed 32-bit rows per k tile. A stage holds stage_k_threads
// rows of stage_n_threads int4 each; with a permutation every one of the
// tile_k_size rows is fetched individually.
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
// Copies the permutation entries of k tile `k_tile_id` into shared memory
// (first perm_size threads, one int4 each) and synchronizes the block.
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
// Starts the asynchronous load of tile (k_tile_id, n_tile_id) into pipeline
// slot `pipe`. A cp.async group is committed on every call, including the
// early return for out-of-range n tiles.
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr =
reinterpret_cast(sh_perm_ptr);
// Source row comes from the permutation; divide by pack_factor to get
// the packed row that contains it.
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
// Converts the tile held in pipeline slot `pipe` and writes it to out_ptr.
// Only the first four warps take part; every participating thread gathers
// eight values (four from column cur_n, four from column cur_n + 8).
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
// Row stride of a stage in 32-bit words (equals tile_n_size).
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr);
uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr);
uint32_t vals[8];
if constexpr (has_perm) {
// Rows were fetched one per k index, so the value's position inside its
// word is given by the permuted source index modulo pack_factor.
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
// Without a permutation the stage holds tile_ints packed rows; load the
// words for both columns, then extract the required elements.
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
// Number of 32-bit output words per repacked tile; note that this local
// shadows gptq_marlin::tile_size inside the lambda.
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
// Eight 4-bit values interleaved into a single word.
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
// Otherwise values are packed 8 bits apart, four per word, into two words.
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
// Primes the pipeline with the first repack_stages - 1 tiles of a k tile.
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
// Steady state: for every slot, issue the fetch that lies
// repack_stages - 1 tiles ahead, repack the tile currently in the slot,
// then wait for the next stage.
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace gptq_marlin
// One `else if` link of the dispatch chain in gptq_marlin_repack(): when the
// runtime (num_bits, has_perm) pair matches, raise the kernel's dynamic
// shared-memory limit to max_shared_mem and launch marlin_repack_kernel.
// Expects num_bits, has_perm, max_shared_mem, b_q_weight_ptr, perm_ptr,
// out_ptr, size_k and size_n to be in scope at the expansion site.
// NOTE(review): the kernel template arguments and launch configuration are
// not visible in this copy of the file.
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel \
<<>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
// Torch entry point that repacks GPTQ-format quantized weights.
//
// Arguments (constraints are the ones enforced below):
//   b_q_weight - [size_k / pack_factor, size_n], CUDA, contiguous, kInt,
//                with pack_factor = 32 / num_bits.
//   perm       - CUDA, contiguous, kInt; an empty tensor disables the
//                permutation path.
//   size_k     - multiple of tile_k_size.
//   size_n     - multiple of tile_n_size.
//   num_bits   - 4 or 8.
// Returns a new tensor of shape
// [size_k / tile_size, size_n * tile_size / pack_factor] with the dtype and
// device of b_q_weight.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
// Number of quantized values stored in one 32-bit word.
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
// Alloc buffers
// The guard makes b_q_weight's device current for the allocation and launch.
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast(b_q_weight.data_ptr());
uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr());
uint32_t* out_ptr = reinterpret_cast(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
// `blocks` receives the device's multiprocessor count.
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// The empty `if (false)` gives the CALL_IF `else if` links and the final
// `else` something to chain onto.
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return out;
}
#endif
================================================
FILE: archive/csrc/custom_marlin/gptq_marlin/ops.h
================================================
/**
* @Description :
* @Author : Azure
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 08:35:00
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
#include
#include
#include
// GPTQ-Marlin mixed-precision GEMM.
// Takes the activations `a`, the packed quantized weights `b_q_weight`, the
// scales `b_scales`, the optional act-order tensors `g_idx` / `perm`, a
// scratch `workspace`, the weight bit width, the problem sizes
// (size_m, size_n, size_k) plus a tensor carrying the M size, the number of
// SMs to use and the is_k_full flag. Returns the result tensor.
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,
int64_t size_k, int sms, bool is_k_full);
// Repacks GPTQ-format quantized weights `b_q_weight` (optionally applying
// the row permutation `perm`) for a size_k x size_n weight matrix quantized
// with `num_bits` bits per value. Returns the repacked tensor.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor&perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
================================================
FILE: archive/csrc/custom_marlin/setup.py
================================================
from setuptools import setup, Extension
from torch.utils import cpp_extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Translation units compiled into the extension module.
_marlin_sources = [
    #'custom_gguf/dequant.cu',
    'binding.cpp',
    'gptq_marlin/gptq_marlin.cu',
    'gptq_marlin/gptq_marlin_repack.cu',
]

# Per-compiler flags: host C++ compiler and nvcc.
_marlin_compile_args = {
    'cxx': ['-O3'],
    'nvcc': [
        '-O3',
        '--use_fast_math',
        '-Xcompiler', '-fPIC',
    ],
}

# Build the `vLLMMarlin` CUDA extension through PyTorch's build helper.
setup(
    name='vLLMMarlin',
    ext_modules=[
        CUDAExtension(
            'vLLMMarlin',
            _marlin_sources,
            extra_compile_args=_marlin_compile_args,
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
================================================
FILE: archive/csrc/custom_marlin/test_cuda_graph.py
================================================
import csv
import torch
import torch.nn as nn
import vLLMMarlin
torch.set_grad_enabled(False)
from utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
)
def setup_seed(seed):
    """Seed the default torch RNG and the RNGs of all CUDA devices with `seed`."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
# Fixed seed so the randomly generated weights and inputs are reproducible.
setup_seed(20241223)
# Benchmark only runs forward passes; autograd stays disabled.
torch.set_grad_enabled(False)
# Floating-point tensors created without an explicit dtype are bfloat16.
torch.set_default_dtype(torch.bfloat16)
global_dtype=torch.bfloat16
# All benchmark tensors live on CUDA device 0.
global_device=torch.device("cuda",0)
# Lower bound on the number of (module, input) cases per configuration.
global_num_cases:int=int(50)
torch.cuda.set_device(0)
torch.backends.cudnn.enabled =True
torch.backends.cudnn.benchmark = True
# Not referenced in the visible part of this file.
max_batch_size = 512
# Upper bound of the tensor-parallel degrees swept by benchLinearMarlin.
max_tp = 8
# 72 MiB. benchLinearMarlin creates enough cases that their combined size is
# at least twice this value; presumably the GPU's L2 cache size - TODO confirm
# for the target device.
L2_size = 73728 * 1024
def get_usable_mem(headroom_bytes=512 * 1024 ** 2):
    """Estimate how many bytes can still be allocated on `global_device`.

    The estimate is the device's total memory minus the memory currently
    allocated by torch tensors minus a safety margin.

    Args:
        headroom_bytes: Safety margin kept free, in bytes. Defaults to
            512 MiB, the value that used to be hard-coded.

    Returns:
        Estimated number of usable bytes; may be negative when the device is
        already nearly full.
    """
    properties = torch.cuda.get_device_properties(global_device)
    allocated_memory = torch.cuda.memory_allocated(global_device)
    # Memory reserved by the caching allocator is deliberately not subtracted:
    # callers release it with torch.cuda.empty_cache() before measuring.
    return properties.total_memory - headroom_bytes - allocated_memory
def exp_range(start, stop, step = 2):
    """Yield start, start*step, start*step**2, ... while the value is <= stop.

    `stop` is inclusive. Nothing is yielded when `start` already exceeds
    `stop`.
    """
    current = start
    while True:
        if not current <= stop:
            return
        yield current
        current *= step
def timing(func, iters, epochs=100):
    """Measure the average time of one `func` call using CUDA graph replay.

    `func(idx)` is called for idx in range(iters), first eagerly as a warm-up
    and then once more while being captured into a CUDA graph. The graph is
    replayed 2000 times to settle the device, then timed twice: 10 replays as
    a baseline and `epochs + 10` replays for the measurement. Subtracting the
    baseline leaves the time of `epochs` replays without the start-up cost.

    Args:
        func: Callable taking the case index; must be CUDA-graph capturable.
        iters: Number of cases captured into the graph (must be > 0).
        epochs: Number of graph replays that contribute to the result.

    Returns:
        Average time per `func` call in milliseconds.
    """
    # Warm-up outside of graph capture.
    for idx in range(iters):
        func(idx)
    torch.cuda.synchronize()

    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph):
        for idx in range(iters):
            func(idx)

    for _ in range(2000):
        cuda_graph.replay()

    # Baseline: 10 replays.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(10):
        cuda_graph.replay()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms0 = start_event.elapsed_time(end_event)

    # Measurement: epochs + 10 replays, baseline subtracted afterwards.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(epochs+10):
        cuda_graph.replay()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0

    #print(elapsed_time_ms0, elapsed_time_ms)
    return elapsed_time_ms/iters/epochs
class LinearMarlin(nn.Linear):
    """Benchmark stand-in for a Marlin-quantized linear layer.

    The layer does not quantize real weights: the packed weight and scale
    tensors are filled with random values of the shapes the kernel expects,
    which is sufficient for timing `vLLMMarlin.gptq_marlin_gemm`.

    Feature sizes that are not multiples of the kernel's minimum tile sizes
    are rounded up; `forward` then zero-pads the input and slices the output
    back to the requested size.
    """

    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
    g_idx: torch.Tensor
    sort_indices: torch.Tensor
    has_bias: bool

    def __init__(
        self,
        in_features,
        out_features,
        bias = False,
        device: str = "cuda",
        num_bits: int = 4,  # 4-bit/8-bit is supported
        group_size: int = 64,  # -1, 32, 64, 128
        act_order: bool = False,
        is_k_full=True,
        sms = -1, # sms in GPU
        **kwargs,
    ):
        """Create the layer and its random quantized parameters.

        Args:
            in_features: Requested input width (K).
            out_features: Requested output width (N).
            bias: Whether the underlying nn.Linear keeps a bias.
            device: Target device; must not be the CPU.
            num_bits: Weight bit width forwarded to the kernel.
            group_size: Stored on the instance. The scale tensor below is
                always created for groups of 64 rows.
            act_order: Stored on the instance; g_idx/sort_indices are always
                created empty, so the kernel's act-order path is not used.
            is_k_full: Forwarded to the kernel.
            sms: Number of SMs forwarded to the kernel (-1 = auto).
            **kwargs: Accepted and ignored.
        """
        self.padding = False
        assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
        # K is checked against the minimum K tile and N against the minimum N
        # tile, matching the rounding applied to each dimension below.
        if in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or out_features%GPTQ_MARLIN_MIN_THREAD_N!=0:
            #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
            self.padding = True
            self.orin_in_features = in_features
            self.orin_out_features = out_features
            in_features = (in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
            out_features = (out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
            #print(f"After padding: in_features={in_features}, out_features={out_features}")

        super().__init__(in_features, out_features, bias, device)
        self.has_bias = bias
        self.device = device
        self.num_bits = num_bits
        self.group_size = group_size
        self.act_order = act_order
        # TODO: optimize every shape GEMM
        self.sms = sms
        self.is_k_full = is_k_full

        self.weight.requires_grad = False
        # Transposed in place so that shape[0] is K and shape[1] is N.
        self.weight.t_()

        # Pack Marlin linear
        #w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        #    self.weight, self.num_bits, self.group_size, self.act_order
        #)
        # Random stand-ins for the packed weights and scales. The shapes
        # correspond to 4-bit weights and a group size of 64.
        marlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features//16, out_features*2), device=device, dtype=torch.int)
        marlin_s = torch.randn((in_features//64, out_features), device=device)

        self.workspace = MarlinWorkspace(
            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
        )
        self.marlin_q_w = marlin_q_w
        self.marlin_s = marlin_s
        # Empty g_idx / sort_indices: no activation reordering.
        self.g_idx = torch.empty((0), dtype=torch.int32, device=self.device)
        self.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device)
        self.k = self.weight.shape[0]
        self.n = self.weight.shape[1]
        # The dense weight is not needed once the sizes are recorded.
        self.weight = None

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
        """Run the Marlin GEMM on `x`.

        Args:
            x: Input whose last dimension is the requested in_features; all
                leading dimensions are flattened into the row dimension.
            bsz_tensor: Tensor forwarded to the kernel as the M-size tensor.

        Returns:
            Tensor with the leading shape of `x`, the requested out_features
            as last dimension, and the dtype of `x`.
        """
        # Only support input x as BF16 and FP16
        x = x.to(self.device)
        orig_shape = list(x.shape)
        orig_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])
        if self.padding:
            # Zero-filled so the padded columns contribute nothing to the
            # product; an uninitialized buffer could inject arbitrary values
            # (including NaN/Inf) into every output element.
            padding_input = torch.zeros(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
            padding_input[:,:self.orin_in_features] = x
            x = padding_input
        marlin_s = self.marlin_s.to(x.dtype)
        #print(self.sms * ((orig_shape[0]+63)//64))

        sms = self.sms

        x = vLLMMarlin.gptq_marlin_gemm(
            x,
            self.marlin_q_w,
            marlin_s,
            self.g_idx,
            self.sort_indices,
            self.workspace.scratch,
            self.num_bits,
            bsz_tensor,
            x.shape[0],
            self.n,
            x.shape[-1],
            sms,
            self.is_k_full,
        )
        # TODO: don't padding bias
        if self.has_bias:
            x = x + self.bias
        if self.padding:
            # Drop the padded output columns again.
            x = x[:,:self.orin_out_features]
            orig_shape[-1] = self.orin_out_features
        else:
            orig_shape[-1] = self.out_features
        return x.reshape(orig_shape).to(orig_dtype)
def benchLinearMarlin(input_dim, output_dim):#, out_file
    """Benchmark LinearMarlin over batch sizes and tensor-parallel degrees.

    For every batch size in 1, 2, 4, ..., 64 and every tensor-parallel degree
    in 1, 2, ..., max_tp that divides `output_dim`, a set of independent
    (module, input) cases is created and timed with `timing`. The number of
    cases is at least `global_num_cases`, large enough that the combined
    footprint is at least twice `L2_size`, and capped at 80% of the usable
    device memory. One result line is printed per configuration.

    Args:
        input_dim: Input width of the benchmarked linear layer.
        output_dim: Total output width; divided by the tensor-parallel degree.
    """
    print("benchmarking MLP Marlin")
    print("-----------------------------------------------------------")
    headers = ["batch_size", "tp", "used_time", "bandwidth GB/s", "TFLOPS", "cases", "padding", "sms"]
    print(" | ".join(headers) + "\n")
    rows = []
    # Number of SMs requested from the kernel for every benchmarked module.
    bench_sms = 56
    for batch_size in exp_range(1, 64):
        for tp in exp_range(1, max_tp):
            torch.cuda.empty_cache()
            if output_dim % tp != 0:
                continue
            cur_output_dim = output_dim // tp
            modules = []
            inputs = []
            # Byte estimates: 0.53125 bytes per weight element (presumably
            # 4-bit weights plus per-group scales - TODO confirm) and 2 bytes
            # per activation element.
            data_size = int(0.53125*input_dim*cur_output_dim)
            input_size = int(2*batch_size*input_dim)
            output_size = int(2*batch_size*cur_output_dim)
            usable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim
            min_cases = max(global_num_cases, (2*L2_size) // (data_size+input_size))
            cases = int(min(min_cases, (usable_mem * 0.8) // (data_size+input_size)))
            #print(usable_mem, data_size, input_size, cases)
            bsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32)

            if cases == 0:
                # One cell per header column, so the row lines up with `headers`.
                row = [f"{batch_size}", f"{tp}", "OOM", "OOM", "OOM", "0", "False", f"{bench_sms}"]
                rows.append(row)
                break

            for _ in range(cases):
                modules.append(LinearMarlin(input_dim, cur_output_dim, sms=bench_sms, non_equal_division=False).to(device=global_device).eval())
                inputs.append(torch.randn(batch_size, 1, input_dim, device=global_device))

            def forward(case_id):
                modules[case_id](inputs[case_id], bsz_tensor)

            used_time = timing(forward, iters=cases)
            bandwidth = (data_size+input_size+output_size)/used_time/1e6
            flops = 2*batch_size*input_dim*cur_output_dim
            tflops = flops/used_time/1e9
            cur_sms = modules[0].sms
            # Every cell is stored as a string so the rows can be joined or
            # written out without further conversion.
            row = [f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", f"{modules[0].padding}", f"{cur_sms}"]
            rows.append(row)
            print(f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms)

    """
    with open(out_file, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(headers)
        for row in rows:
            csvwriter.writerow(row)
    """

    """
    markdown_table = " | ".join(headers) + "\n"
    markdown_table += " | ".join(["---"] * len(headers)) + "\n"
    for row in rows:
        markdown_table += " | ".join(row) + "\n"

    print(markdown_table)
    """

    #print("finish write file", out_file)
    #print("-------------------------------------------------------------")
if __name__ == "__main__":
    # Run the benchmark sweep and stop. Everything after exit(0) is a
    # CUDA-graph capture/replay experiment that is unreachable while the
    # exit call remains in place.
    benchLinearMarlin(5120, 3584)
    exit(0)

    def printMinMax(tensor):
        """Print the smallest and largest absolute value of ``tensor`` and where they occur."""
        magnitudes = torch.abs(tensor)
        smallest = torch.min(magnitudes)
        largest = torch.max(magnitudes)
        smallest_at = (magnitudes == smallest).nonzero(as_tuple=True)
        largest_at = (magnitudes == largest).nonzero(as_tuple=True)
        print(f"min: {smallest.item()}")
        print(f"min idx: {smallest_at}")
        print(f"max: {largest.item()}")
        print(f"max idx: {largest_at}")

    max_batch = 1
    cur_batch = 1

    # Eager reference run, then capture the same call in a CUDA graph and
    # replay it many times.
    marlin_linear = LinearMarlin(5120, 3584)
    input_tensor = torch.randn(max_batch, 1, 5120, device="cuda", dtype=torch.bfloat16)
    bsz_tensor = torch.tensor([max_batch], device="cuda", dtype=torch.int32)
    out_truth = marlin_linear(input_tensor, bsz_tensor)
    print(out_truth)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out_buf = marlin_linear(input_tensor, bsz_tensor)
    for i in range(10000):
        g.replay()
    #torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3)

    # Fresh module and graph: overwrite the captured input buffers in place,
    # replay, and compare against an eager run on the new input.
    marlin_linear = LinearMarlin(5120, 3584)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out_buf = marlin_linear(input_tensor, bsz_tensor)
    new_input = torch.randn(cur_batch, 1, 5120, device="cuda", dtype=torch.bfloat16)
    bsz_tensor.copy_(torch.tensor([cur_batch], device="cuda", dtype=torch.int32))
    new_out_truth = marlin_linear(new_input, bsz_tensor)
    input_tensor[:cur_batch].copy_(new_input)
    input_tensor[cur_batch:] = 0
    g.replay()
    torch.cuda.synchronize()

    print(out_buf[:cur_batch].shape)
    print(new_out_truth.shape)
    printMinMax(out_buf[:cur_batch])
    printMinMax(new_out_truth)
    #torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3)
================================================
FILE: archive/csrc/custom_marlin/utils/__init__.py
================================================
================================================
FILE: archive/csrc/custom_marlin/utils/format24.py
================================================
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import torch
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,
device):
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x = 64
group_y = 32 if meta_dtype.itemsize == 2 else 16
dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +
(dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +
((dst_rows % group_x) // 8) * 4)
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
dst_rows += topright - bottomleft
dst_cols -= topright - bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave = 2
cols_maj = dst_cols // interleave
cols_min = dst_cols % interleave
return (cols_maj * m * interleave + dst_rows * interleave +
cols_min).view(-1)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
    """Compress a 2-D dense tensor into (values, reordered metadata).

    :param dense: 2-D tensor of dtype int8, half, bfloat16, float or int32
    :return: tuple ``(sparse, meta)`` where ``sparse`` has shape ``(m, k // 2)``
        and ``meta`` has shape ``(m, meta_ncols)``
    :raises RuntimeError: on wrong rank, unsupported dtype, or a shape that
        does not satisfy the divisibility checks below
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # The metadata element type is chosen from the dense dtype.
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError(
            "Invalid number of elements per meta element calculated")

    row_multiple = 16 if meta_dtype == torch.int32 else 32
    if m % row_multiple != 0:
        raise RuntimeError(
            f"Number of rows of dense matrix {m} must be divisible by {row_multiple}")
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    is_float32 = dense.dtype == torch.float
    if is_float32:
        # float32 is handled as pairs: each non-zero flag stands for two of
        # the four positions of the encoding.
        ksparse = 2
        dense_groups = dense.view(-1, k // ksparse, ksparse)
        nz_first, nz_second = (dense_groups != 0).unbind(-1)
        m0 = m1 = nz_first
        m3 = nz_second
    else:
        ksparse = 4
        dense_groups = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _, m3 = (dense_groups != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are less than two True values, then False value or
    # values at some index or indices are considered True for the
    # encoding.  In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding.  The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits.  Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.
    both = m0 & m1
    only_second = ~m0 & m1
    neither = ~m0 & ~m1

    bit0 = only_second
    bit1 = neither
    bit2 = both | neither | m3
    bit3 = only_second | ~m1

    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if is_float32:
        sparse = dense_groups.gather(
            -1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)
    else:
        kept_first = dense_groups.gather(-1, idxs0.unsqueeze(-1))
        kept_second = dense_groups.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((kept_first, kept_second),
                             dim=-1).view(m, k // 2)

    # One 4-bit code per quadruple, then pack quadbits_per_meta_elem codes
    # into each metadata element, lowest code in the lowest bits.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view(
        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
    meta = meta_n[:, :, 0]
    for slot in range(1, quadbits_per_meta_elem):
        meta = meta | (meta_n[:, :, slot] << (4 * slot))

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty((m * meta_ncols, ))
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
if sparse.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
)
m, k = sparse.shape
device = sparse.device
if meta_reordered.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
)
if meta_reordered.device != device:
raise RuntimeError(
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
)
meta_dtype = meta_reordered.dtype
if meta_dtype not in (torch.int16, torch.int32):
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
ksparse = 4 if sparse.dtype != torch.float else 2
meta_nrows, meta_ncols = meta_reordered.shape
if meta_nrows != m:
raise RuntimeError(
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
)
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
raise RuntimeError(
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
"expected according to the number of columns of meta matrix")
# Undo meta tensor elements reordering.
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device)
meta = torch.gather(meta_reordered.view(-1), 0,
meta_offsets).view(m, meta_ncols)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2 = torch.empty(
(m, meta_ncols, 2 * quadbits_per_meta_elem),
dtype=meta_dtype,
device=device,
)
if quadbits_per_meta_elem == 4:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
elif quadbits_per_meta_elem == 8:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
meta_2[:, :, 8] = (meta >> 16) & 0b11
meta_2[:, :, 9] = (meta >> 18) & 0b11
meta_2[:, :, 10] = (meta >> 20) & 0b11
meta_2[:, :, 11] = (meta >> 22) & 0b11
meta_2[:, :, 12] = (meta >> 24) & 0b11
meta_2[:, :, 13] = (meta >> 26) & 0b11
meta_2[:, :, 14] = (meta >> 28) & 0b11
meta_2[:, :, 15] = (meta >> 30) & 0b11
dense_offsets = meta_2.view(-1) + (
torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(
-1, 1).repeat(1, 2).view(-1)
dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)
if sparse.dtype != torch.float:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
else:
dense.view(torch.half).scatter_(0, dense_offsets,
sparse.view(torch.half).view(-1))
return dense.view(m, 2 * k)
def mask_creator(tensor):
    """
    Build a 2:4 sparsity mask for ``tensor``.

    The flattened tensor is split into consecutive groups of 4 values; in
    every group the 2 entries with the smallest absolute value get mask 0
    and the other 2 get mask 1.

    :param tensor: tensor whose element count is a multiple of 4
    :return: tensor of ones and zeros with the same shape as ``tensor``,
        created with ``torch.ones`` on the tensor's device
    :raises ValueError: if the element count is not a multiple of 4
    """
    N = 2
    M = 4

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into "
            f"{M} groups")

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    magnitudes = tensor.detach().abs().reshape(num_groups, M)
    # Ascending sort: the first M - N columns are the entries to prune.
    pruned = torch.argsort(magnitudes, dim=1)[:, :int(M - N)]

    keep = torch.ones(magnitudes.shape, device=magnitudes.device)
    keep.scatter_(dim=1, index=pruned, value=0)
    return keep.reshape(tensor.shape)
================================================
FILE: archive/csrc/custom_marlin/utils/marlin_24_perms.py
================================================
'''
Date: 2024-11-08 02:46:07
LastEditors: djw
LastEditTime: 2024-11-08 02:46:41
'''
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List
import numpy
import torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits: int):
    """Build the Marlin24 weight and scale permutations for one bit width.

    :param num_bits: quantization width, 4 or 8
    :return: ``(perm, scale_perm, scale_perm_single)`` where ``perm`` is a
        1-D torch tensor of 1024 indices and the scale permutations are
        plain lists of 64 indices each
    :raises ValueError: if ``num_bits`` is neither 4 nor 8
    """
    flat_perm: List[int] = []
    for i in range(32):
        col = i // 4
        col_o = col // 2
        base_row = 2 * (i % 4)
        rows = (base_row, base_row + 1, base_row + 8, base_row + 9)
        tile: List[int] = [
            16 * row + col_o * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1)
            for row in rows
        ]
        for j in range(4):
            flat_perm.extend(idx + 1 * j for idx in tile)
    perm = numpy.array(flat_perm)

    interleave_by_bits = {
        4: [0, 2, 4, 6, 1, 3, 5, 7],
        8: [0, 2, 1, 3],
    }
    if num_bits not in interleave_by_bits:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(interleave_by_bits[num_bits])

    # Interleave within consecutive chunks of len(interleave) entries.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = [
        i * 8 + j for i in range(8) for j in [0, 4, 1, 5, 2, 6, 3, 7]
    ]
    scale_perm_single: List[int] = [
        8 * i + j for i in range(8) for j in [0, 1, 2, 3, 4, 5, 6, 7]
    ]
    return perm, scale_perm, scale_perm_single
# Permutation tables keyed by bit width, filled once at import time from
# get_perms_24.
marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in (4, 8):
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
================================================
FILE: archive/csrc/custom_marlin/utils/marlin_perms.py
================================================
'''
Date: 2024-11-08 02:46:47
LastEditors: djw
LastEditTime: 2024-11-08 02:46:55
'''
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List
import numpy
import torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits: int):
    """Build the Marlin weight and scale permutations for one bit width.

    :param num_bits: quantization width, 4 or 8
    :return: ``(perm, scale_perm, scale_perm_single)`` where ``perm`` is a
        1-D torch tensor of 1024 indices, ``scale_perm`` a list of 64
        indices and ``scale_perm_single`` a list of 32 indices
    :raises ValueError: if ``num_bits`` is neither 4 nor 8 (a subclass of the
        previously raised ``Exception``, matching ``get_perms_24``)
    """
    # Validate first so an unsupported width fails before any table is built.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)
    # Interleave within consecutive chunks of len(interleave) entries.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single
# Permutation tables keyed by bit width, filled once at import time from
# get_perms.
marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in (4, 8):
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
================================================
FILE: archive/csrc/custom_marlin/utils/marlin_utils.py
================================================
"""This file is used for /tests and /benchmarks"""
import random
import numpy
import torch
from .format24 import (
mask_creator, sparse_semi_structured_from_dense_cutlass)
from .marlin_24_perms import (
marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
from .marlin_perms import (
marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from .quant_utils import (
get_pack_factor, quantize_weights, sort_weights, dequantize_weights)
# Compute capability of the current CUDA device, queried once at import time;
# is_marlin_supported() reads its first element.
__cuda_arch = torch.cuda.get_device_capability()
# Default tile edge used by marlin_permute_weights.
MARLIN_TILE = 16
# NOTE(review): the GPTQ_MARLIN_* values below are not referenced by the code
# visible in this module; presumably they mirror limits of the GPTQ Marlin
# kernel -- confirm against the kernel sources.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
def is_marlin_supported():
    """Return True when the first element of the cached CUDA capability is at least 8."""
    major_version = __cuda_arch[0]
    return major_version >= 8
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
    """Rearrange a ``(size_k, size_n)`` weight matrix into tiles and apply ``perm``.

    :param q_w: tensor of shape ``(size_k, size_n)``
    :param size_k: row count, must be a multiple of ``tile``
    :param size_n: column count, must be a multiple of ``tile``
    :param perm: 1-D index tensor applied to every chunk of ``perm.numel()``
        consecutive elements
    :param tile: tile edge length
    :return: tensor of shape ``(size_k // tile, size_n * tile)``
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message used to be labelled "size_k" while reporting size_n.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute quantized weights into Marlin layout and pack them into int32 words.

    :param q_w: quantized weights of shape ``(size_k, size_n)``
    :param num_bits: bits per quantized value; sets how many values share a word
    :param perm: permutation handed to ``marlin_permute_weights``
    :return: int32 tensor on ``q_w``'s device whose last dimension is the
        permuted width divided by the pack factor
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = permuted.device

    values = permuted.cpu().numpy().astype(numpy.uint32)
    n_rows, n_cols = values.shape[0], values.shape[1]
    q_packed = numpy.zeros((n_rows, n_cols // pack_factor),
                           dtype=numpy.uint32)
    # Value ``lane`` of every group of ``pack_factor`` columns lands at bit
    # offset ``num_bits * lane`` of the packed word.
    for lane in range(pack_factor):
        shift = num_bits * lane
        q_packed |= values[:, lane::pack_factor] << shift

    return torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
                          scale_perm_single):
    """Permute quantization scales and reshape them to ``(-1, size_n)``.

    ``scale_perm`` is applied when ``group_size`` is a real group size smaller
    than ``size_k``; otherwise ``scale_perm_single`` is applied.

    :return: contiguous tensor of shape ``(-1, size_n)``
    """
    is_grouped = group_size != -1 and group_size < size_k
    chosen_perm = scale_perm if is_grouped else scale_perm_single
    permuted = s.reshape((-1, len(chosen_perm)))[:, chosen_perm]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    """Quantize ``w`` and convert weights and scales to Marlin layout.

    :param w: weight matrix of shape ``(size_k, size_n)``
    :param num_bits: quantization width
    :param group_size: rows per scale group; ``-1`` means one group over all rows
    :param act_order: when true, rows are permuted by ``quantize_weights`` and
        then sorted by group id
    :return: list ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
        rand_perm]``, every tensor moved to ``w.device``
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = marlin_perm[num_bits]
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                     marlin_scale_perm[num_bits],
                                     marlin_scale_perm_single[num_bits])

    # Create result
    results = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in results]
def inject_24(w, size_k, size_n):
    """Zero out entries of ``w`` according to the mask from ``mask_creator``.

    The mask is computed on the transpose of ``w`` and transposed back.

    :param w: tensor of shape ``(size_k, size_n)``
    :return: ``(masked_w, mask)``, both contiguous; ``mask`` is boolean
    """
    assert w.shape == (size_k, size_n)

    # Keep the mask on w's own device. The previous unconditional .cuda()
    # moved it to the current CUDA device, which made ``mask * w`` fail for
    # CPU tensors and for tensors on a different GPU.
    mask = mask_creator(w.t()).t().to(w.device).bool()

    return (mask * w).contiguous(), mask.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Sample rows of ``w.t()`` and report 4-wide blocks with more than 2 non-zeros.

    Results are printed; nothing is returned.

    :param w: 2-D tensor; it is transposed before the rows are sampled
    :param num_rows_to_sample: number of rows drawn (with replacement)
    :param _verbose: also print the sampled row indices
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # The "+ 1" makes the last full block part of the scan; the previous
        # bound stopped one block early. A trailing partial block is skipped.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress quantized weights with ``sparse_semi_structured_from_dense_cutlass``.

    :param q_24: quantized weights of shape ``(size_k, size_n)``
    :param num_bits: quantization width, used to derive the zero point
    :return: ``(q_24_comp, meta)`` -- compressed weights with the zero point
        restored, and the metadata tensor resized in place
    """
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2
    centered = (q_24 - zp).t().contiguous()

    # Compress (the helper works on the transposed matrix)
    compressed_t, meta = sparse_semi_structured_from_dense_cutlass(centered)
    compressed = compressed_t.t().contiguous()

    # Restore zp
    q_24_comp = compressed + zp

    # Resize meta to its actual shape (without moving any data)
    n_rows, n_cols = meta.shape[0], meta.shape[1]
    meta = meta.resize_(n_cols // 2, n_rows * 2)

    return q_24_comp, meta
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Apply 2:4 masking, quantize, compress and convert ``w`` to Marlin24 layout.

    :param w: weight matrix of shape ``(size_k, size_n)``
    :param num_bits: quantization width
    :param group_size: rows per scale group; ``-1`` means one group over all rows
    :return: list ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]``, every
        tensor moved to ``w.device``
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(
        w_24, num_bits, group_size, act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = marlin_24_perm[num_bits]
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, weight_perm)
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    results = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [tensor.to(w.device) for tensor in results]
def compute_max_diff(output, output_ref):
    """Return mean(|output - output_ref|) / mean(|output_ref|).

    Despite the name, this is a ratio of means, not a maximum.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
class MarlinWorkspace:
    """Holds a zero-filled int32 scratch tensor sized from the output width."""

    def __init__(self, out_features, min_thread_n, max_parallel, device):
        """Allocate ``(out_features // min_thread_n) * max_parallel`` int32 zeros.

        :param out_features: must be a multiple of ``min_thread_n``
        :param min_thread_n: divisor applied to ``out_features``
        :param max_parallel: multiplier applied to the quotient
        :param device: device on which ``self.scratch`` is allocated
        """
        remainder = out_features % min_thread_n
        assert (remainder == 0), (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        scratch_len = (out_features // min_thread_n) * max_parallel
        self.scratch = torch.zeros(scratch_len,
                                   dtype=torch.int,
                                   device=device)
================================================
FILE: archive/csrc/custom_marlin/utils/quant_utils.py
================================================
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into 32 bits."""
    is_supported = num_bits in SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Apply one random row permutation to ``q_w``, ``w_ref`` and the group ids.

    :param q_w: quantized weights, 2-D
    :param w_ref: reference weights with the same shape as ``q_w``
    :param group_size: number of consecutive rows that share a group id
    :return: ``(w_ref, q_w, g_idx, rand_perm)`` on ``q_w``'s original device;
        ``g_idx`` is int32 and holds ``row // group_size`` for the permuted rows
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Vectorised group ids; this replaces a Python loop that wrote one tensor
    # element per iteration.
    g_idx = torch.div(torch.arange(k_size, dtype=torch.int32),
                      group_size,
                      rounding_mode="floor")

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
# Function: Dequantize quantized weights
def dequantize_weights(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):
    """Unpack int32-packed weights and zero points and return the dequantized matrix.

    The result is ``scales * (weight - (zero + 1))``, transposed.

    :param qweight: packed weights; ``32 // bits`` values per int32 along dim 0
    :param qzeros: packed zero points; ``32 // bits`` values per int32 along dim 1
    :param scales: per-group scales; the last dimension is the output width
    :param g_idx: accepted but not used by this implementation
    :param bits: bits per packed value
    :param group_size: rows per scale group
    :param device: device on which the shift table is created
    :return: transposed dequantized weight
    """
    values_per_word = 32 // bits
    value_mask = (2 ** bits) - 1
    # int8 cannot hold 8-bit zero points after the +1 below, hence int16 there.
    unpacked_dtype = torch.int16 if bits == 8 else torch.int8

    # Bit offsets of the packed values inside one int32 word.
    shifts = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=device).unsqueeze(0)

    # Unpack the zero points, mask them to ``bits`` bits and add one.
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, values_per_word),
        shifts.unsqueeze(0)).to(unpacked_dtype)
    zeros &= value_mask
    zeros = zeros + 1
    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

    scales = scales.reshape(-1, 1, scales.shape[-1])

    # Unpack the weights the same way and split them into groups.
    weight = torch.bitwise_right_shift(
        torch.unsqueeze(qweight, 1).expand(-1, values_per_word, -1),
        shifts.unsqueeze(-1)).to(unpacked_dtype)
    weight &= value_mask
    weight = weight.reshape(-1, group_size, weight.shape[2])

    # Dequantize, flatten the groups back into rows and transpose.
    weight = scales * (weight - zeros)
    n_groups, rows_per_group, n_cols = weight.shape
    weight = weight.reshape(n_groups * rows_per_group, n_cols)
    return weight.transpose(0, 1)
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    """Symmetrically quantize ``w`` per group and return quantized and reference weights.

    :param w: floating-point weight matrix of shape ``(size_k, size_n)``
    :param num_bits: quantization width; must be in ``SUPPORTED_NUM_BITS``
    :param group_size: rows per scale group; ``-1`` or ``size_k`` means a
        single group over all rows
    :param act_order: when true, rows are randomly permuted via ``permute_rows``
    :return: ``(w_ref, q_w, s, g_idx, rand_perm)`` on ``w``'s original device;
        ``g_idx`` and ``rand_perm`` are empty unless ``act_order`` is set
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    allowed_group_sizes = SUPPORTED_GROUP_SIZES + [size_k]
    assert group_size in allowed_group_sizes, f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2
    is_grouped = group_size < size_k

    # Reshape to [groupsize, -1]
    if is_grouped:
        w = w.view((-1, group_size, size_n)).permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if is_grouped:

        def ungroup(t):
            # Inverse of the [groupsize, -1] reshape above.
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w = ungroup(q_w)
        w_ref = ungroup(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    outputs = (w_ref, q_w, s, g_idx, rand_perm)
    return tuple(t.to(device=orig_device) for t in outputs)
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that ``g_idx`` becomes ascending.

    :return: ``(q_w, g_idx, sort_indices)`` on ``q_w``'s original device;
        ``sort_indices`` is the int32 argsort of the incoming ``g_idx``
    """
    orig_device = q_w.device

    # Sort based on g_idx
    order = torch.argsort(g_idx).to(dtype=torch.int32)
    sorted_groups = g_idx[order].contiguous()
    sorted_weights = q_w[order, :].contiguous()

    outputs = (sorted_weights, sorted_groups, order)
    return tuple(t.to(device=orig_device) for t in outputs)
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized values along the row dimension into int32 words.

    :param q_w: quantized weights of shape ``(size_k, size_n)``
    :param num_bits: bits per value; ``size_k`` must be a multiple of the
        resulting pack factor
    :return: int32 tensor of shape ``(size_k // pack_factor, size_n)`` on
        ``q_w``'s device
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device
    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    # Row ``lane`` of every group of ``pack_factor`` rows lands at bit offset
    # ``num_bits * lane`` of the packed word.
    for lane in range(pack_factor):
        shift = num_bits * lane
        packed |= unpacked[lane::pack_factor, :] << shift

    return torch.from_numpy(packed.astype(numpy.int32)).to(orig_device)
def gptq_unpack(
    q_res: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Expand 32-bit packed words back into one value per row.

    Bits ``[num_bits * i, num_bits * (i + 1))`` of each word become row ``i``
    of the corresponding group of ``32 // num_bits`` output rows.

    Args:
        q_res: packed tensor; it is shifted and masked element-wise.
        num_bits: width of one packed value.
        size_k: number of output rows; must be a multiple of ``32 // num_bits``.
        size_n: number of output columns.

    Returns:
        int32 tensor of shape ``(size_k, size_n)`` on the device of ``q_res``.
    """
    values_per_word = 32 // num_bits
    assert size_k % values_per_word == 0

    source_device = q_res.device
    packed = q_res.cpu().numpy()

    field_mask = (1 << num_bits) - 1
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    for slot in range(values_per_word):
        shift = num_bits * slot
        # Masking after the shift keeps only the field that belongs to `slot`.
        unpacked[slot::values_per_word, :] = (packed >> shift) & field_mask

    return torch.from_numpy(unpacked.astype(numpy.int32)).to(source_device)
================================================
FILE: archive/csrc/ktransformers_ext/CMakeLists.txt
================================================
# Build configuration for the cpuinfer_ext Python extension module.
cmake_minimum_required(VERSION 3.16)
project(cpuinfer_ext VERSION 0.1.0)

set(CMAKE_CXX_STANDARD 17)
# Every C++ source is built optimised, with fast-math and OpenMP enabled.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp")

# Forward the libstdc++ ABI selector only when the caller actually supplied
# one. Emitting "-D_GLIBCXX_USE_CXX11_ABI=" with an empty value would replace
# the standard library's own default with an empty macro.
if (NOT "${_GLIBCXX_USE_CXX11_ABI}" STREQUAL "")
    add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
endif()

set(CMAKE_BUILD_TYPE "Release")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
include(CheckCXXCompilerFlag)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# User-facing build switches. LLAMA_NATIVE (default ON) makes the non-MSVC x86
# branch further down append -march=native to ARCH_FLAGS.
option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
# instruction set specific
# INS_ENB is the inverse of LLAMA_NATIVE. It is not read again in the visible
# part of this file; the individual instruction-set options below all default
# to OFF regardless of its value.
if (LLAMA_NATIVE)
set(INS_ENB OFF)
else()
set(INS_ENB ON)
endif()
# Explicit instruction-set switches, consulted when ARCH_FLAGS is assembled.
option(LLAMA_AVX "llama: enable AVX" OFF)
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA "llama: enable FMA" OFF)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
# Accelerator backend selection. On UNIX the backend block below checks these
# in the order ROCM, MUSA, XPU, CUDA and uses the first one that is ON.
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF)
option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF)
# The NPU switch is also exposed to the sources as a preprocessor definition.
if(KTRANSFORMERS_USE_NPU)
add_definitions(-DKTRANSFORMERS_USE_NPU=1)
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
#       feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")

if (MSVC)
    # Visual Studio generators carry the target platform in
    # CMAKE_GENERATOR_PLATFORM; keep a lower-case copy for the matching below.
    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
    message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
else ()
    set(CMAKE_GENERATOR_PLATFORM_LWR "")

    # Optional fully static link (non-MSVC toolchains only).
    if (LLAMA_STATIC)
        add_link_options(-static)
        if (MINGW)
            add_link_options(-static-libgcc -static-libstdc++)
        endif()
    endif()

    # Optional gprof instrumentation.
    if (LLAMA_GPROF)
        add_compile_options(-pg)
    endif()
endif ()
# Collect architecture-specific compiler flags in ARCH_FLAGS and add the
# matching feature macros. Three families are recognised (ARM, x86, PowerPC);
# anything else only logs "Unknown architecture".
set(ARCH_FLAGS "")
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
     CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
    message(STATUS "ARM detected")
    if (MSVC)
        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
        add_compile_definitions(__ARM_NEON)
        add_compile_definitions(__ARM_FEATURE_FMA)

        # check_cxx_source_compiles() lives in its own module; include it
        # explicitly instead of relying on another module pulling it in.
        include(CheckCXXSourceCompiles)

        # Probe optional NEON features with /arch:armv8.2 temporarily added to
        # the flags used for the test compilations.
        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
        string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")
        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
        if (GGML_COMPILER_SUPPORT_DOTPROD)
            add_compile_definitions(__ARM_FEATURE_DOTPROD)
        endif ()
        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
        if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
        endif ()
        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
    else()
        if(KTRANSFORMERS_USE_NPU)
            # NOTE(review): -lnuma is a linker flag but ARCH_FLAGS is applied
            # through add_compile_options(); confirm whether it is needed here.
            list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)
        endif()
        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
        if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
            list(APPEND ARCH_FLAGS -mfp16-format=ieee)
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
            # Raspberry Pi 1, Zero
            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
            if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
                # Android armeabi-v7a
                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
            else()
                # Raspberry Pi 2
                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
            endif()
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
            # Android arm64-v8a
            # Raspberry Pi 3, 4, Zero 2 (32-bit)
            list(APPEND ARCH_FLAGS -mno-unaligned-access)
        endif()
    endif()
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
         CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
    message(STATUS "x86 detected")
    if(NOT KTRANSFORMERS_USE_NPU)
        set(HOST_IS_X86 TRUE)
        set(HAS_AVX512 TRUE)
        set(__HAS_AMX__ TRUE)
        add_compile_definitions(__x86_64__)

        # check AVX512
        # Detection is a substring search in the output of `lscpu` on the
        # build host; an empty output (e.g. lscpu missing) counts as "absent".
        execute_process(
            COMMAND lscpu
            OUTPUT_VARIABLE LSCPU_OUTPUT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}")
        string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F)
        if (COMPILER_SUPPORTS_AVX512F GREATER -1)
            message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)")
            add_compile_definitions(__HAS_AVX512F__)
        else()
            message(STATUS "Compiler and/or CPU do NOT support AVX512F")
            set(HAS_AVX512 False)
        endif()

        # check AMX
        # NOTE(review): unlike HAS_AVX512, the CMake variable __HAS_AMX__ stays
        # TRUE when "amx" is not found; only the preprocessor definition is
        # skipped. Confirm that this asymmetry is intended.
        string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX)
        if(COMPILER_SUPPORTS_AMX GREATER -1)
            message(STATUS "Compiler supports AMX")
            add_compile_definitions(__HAS_AMX__)
        else()
            message(STATUS "Compiler does NOT support AMX")
        endif()
    endif()
    if (MSVC)
        # instruction set detection for MSVC only
        if (LLAMA_NATIVE)
            include(cmake/FindSIMD.cmake)
        endif ()
        if (LLAMA_AVX512)
            list(APPEND ARCH_FLAGS /arch:AVX512)
            # MSVC has no compile-time flags enabling specific
            # AVX512 extensions, neither it defines the
            # macros corresponding to the extensions.
            # Do it manually.
            if (LLAMA_AVX512_VBMI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
            endif()
            if (LLAMA_AVX512_VNNI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
            endif()
            if (LLAMA_AVX512_FANCY_SIMD)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
            endif()
            if (LLAMA_AVX512_BF16)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
            endif()
        elseif (LLAMA_AVX2)
            list(APPEND ARCH_FLAGS /arch:AVX2)
        elseif (LLAMA_AVX)
            list(APPEND ARCH_FLAGS /arch:AVX)
        endif()
    else()
        if (LLAMA_NATIVE)
            list(APPEND ARCH_FLAGS -mfma -mavx -mavx2)
            list(APPEND ARCH_FLAGS -march=native)
        endif()
        if (LLAMA_F16C)
            list(APPEND ARCH_FLAGS -mf16c)
        endif()
        if (LLAMA_FMA)
            list(APPEND ARCH_FLAGS -mfma)
        endif()
        if (LLAMA_AVX)
            list(APPEND ARCH_FLAGS -mavx)
        endif()
        if (LLAMA_AVX2)
            list(APPEND ARCH_FLAGS -mavx2)
        endif()
        if (LLAMA_AVX512)
            list(APPEND ARCH_FLAGS -mavx512f)
            list(APPEND ARCH_FLAGS -mavx512bw)
        endif()
        if (LLAMA_AVX512_VBMI)
            list(APPEND ARCH_FLAGS -mavx512vbmi)
        endif()
        if (LLAMA_AVX512_VNNI)
            list(APPEND ARCH_FLAGS -mavx512vnni)
        endif()
        if (LLAMA_AVX512_FANCY_SIMD)
            message(STATUS "AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled")
            list(APPEND ARCH_FLAGS -mavx512vl)
            list(APPEND ARCH_FLAGS -mavx512bw)
            list(APPEND ARCH_FLAGS -mavx512dq)
            list(APPEND ARCH_FLAGS -mavx512vnni)
            list(APPEND ARCH_FLAGS -mavx512vpopcntdq)
        endif()
        if (LLAMA_AVX512_BF16)
            list(APPEND ARCH_FLAGS -mavx512bf16)
        endif()
    endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
    message(STATUS "PowerPC detected")
    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
    else()
        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
        #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
    endif()
else()
    message(STATUS "Unknown architecture")
endif()
# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
# find_package(FindCUDAToolkit REQUIRED)
# if(CUDAToolkit_FOUND)
# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
# else()
# message(STATUS "Can't found CUDA lib")
# endif()
# Locate the ROCm installation: use $ROCM_PATH when it names an existing path,
# otherwise fall back to /opt/rocm and finally /usr. The environment value is
# quoted so that an unset or empty variable still yields a well-formed
# if(EXISTS "") test instead of an argument-less one.
if (NOT EXISTS "$ENV{ROCM_PATH}")
    if (NOT EXISTS /opt/rocm)
        set(ROCM_PATH /usr)
    else()
        set(ROCM_PATH /opt/rocm)
    endif()
else()
    set(ROCM_PATH "$ENV{ROCM_PATH}")
endif()

list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}")
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")

# Same scheme for the MUSA toolkit: $MUSA_PATH, then /opt/musa, then
# /usr/local/musa. Its cmake/ directory is added to the module search path.
if (NOT EXISTS "$ENV{MUSA_PATH}")
    if (NOT EXISTS /opt/musa)
        set(MUSA_PATH /usr/local/musa)
    else()
        set(MUSA_PATH /opt/musa)
    endif()
else()
    set(MUSA_PATH "$ENV{MUSA_PATH}")
endif()

list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
# Apply the collected architecture flags to C++ and C sources only. The
# COMPILE_LANGUAGE condition keeps them away from other languages that may be
# enabled in this project (for example CUDA).
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")

# Vendored dependencies are built as part of this project.
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
# Pick the accelerator backend: set up its include directories and define the
# matching KTRANSFORMERS_USE_* macro. Windows always uses CUDA from
# $CUDA_PATH; on UNIX the first enabled option in the order
# ROCM, MUSA, XPU, CUDA wins.
if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
    if (KTRANSFORMERS_USE_ROCM)
        find_package(HIP REQUIRED)
        if(HIP_FOUND)
            include_directories("${HIP_INCLUDE_DIRS}")
            add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)
        endif()
    elseif (KTRANSFORMERS_USE_MUSA)
        # $MUSA_PATH is quoted so that an unset variable still produces a
        # well-formed if(EXISTS "") test.
        if (NOT EXISTS "$ENV{MUSA_PATH}")
            if (NOT EXISTS /opt/musa)
                set(MUSA_PATH /usr/local/musa)
            else()
                set(MUSA_PATH /opt/musa)
            endif()
        else()
            set(MUSA_PATH "$ENV{MUSA_PATH}")
        endif()

        list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")

        # The link step later in this file uses MUSA::musart unconditionally
        # for this backend, so a missing toolkit is reported here, at the
        # point of discovery, rather than as a missing target afterwards.
        find_package(MUSAToolkit REQUIRED)
        if (MUSAToolkit_FOUND)
            message(STATUS "MUSA Toolkit found")
            add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
        endif()
    elseif (KTRANSFORMERS_USE_XPU)
        add_compile_definitions(KTRANSFORMERS_USE_XPU=1)
    elseif (KTRANSFORMERS_USE_CUDA)
        find_package(CUDA REQUIRED)
        include_directories("${CUDA_INCLUDE_DIRS}")
        include(CheckLanguage)
        check_language(CUDA)
        if(CMAKE_CUDA_COMPILER)
            message(STATUS "CUDA detected")
            find_package(CUDAToolkit REQUIRED)
            include_directories(${CUDAToolkit_INCLUDE_DIRS})
        endif()
        message(STATUS "enabling CUDA")
        enable_language(CUDA)
        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
    endif()
endif()
# Sources that live in this directory tree, gathered one directory at a time.
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} EXT_ROOT_SOURCES)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend CPU_BACKEND_SOURCES)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile LLAMAFILE_OP_SOURCES)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache KVCACHE_OP_SOURCES)

# Vendored llamafile sources: every .cpp except the two sgemm dispatch files.
# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
file(GLOB LLAMAFILE_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/*.cpp")
list(REMOVE_ITEM LLAMAFILE_SOURCES
    "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_arm.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_x86.cpp"
)

set(ALL_SOURCES
    ${EXT_ROOT_SOURCES}
    ${CPU_BACKEND_SOURCES}
    ${LLAMAFILE_OP_SOURCES}
    ${LLAMAFILE_SOURCES}
    ${KVCACHE_OP_SOURCES}
)

# The AMX operators are only part of the build when the x86 detection above
# left all three variables true.
if (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)
    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx AMX_OP_SOURCES)
    list(APPEND ALL_SOURCES ${AMX_OP_SOURCES})
endif()

# `format` target: run clang-format in place over every C++ file below this
# directory.
file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h")
add_custom_target(
    format
    COMMAND clang-format
    -i
    -style=file
    ${FMT_SOURCES}
    COMMENT "Running clang-format on all source files"
)

add_library(llamafile STATIC ${LLAMAFILE_SOURCES})

message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")

# The Python extension module itself, linked against the vendored llama target.
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
# Link the runtime library of the selected accelerator backend.
if(WIN32)
    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
    if (KTRANSFORMERS_USE_ROCM)
        message(STATUS "Building for HIP")
        add_compile_definitions(USE_HIP=1)
        target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so")
    elseif(KTRANSFORMERS_USE_MUSA)
        target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
    elseif(KTRANSFORMERS_USE_XPU)
        # Nothing extra is linked for XPU builds.
    elseif(KTRANSFORMERS_USE_CUDA)
        # Reached only when MUSA is off, because of the elseif chain above.
        target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
    endif()
endif()
# Define the USE_NUMA option. It is OFF by default and turns NUMA support on,
# so the help text says "Enable" (it previously read "Disable").
option(USE_NUMA "Enable NUMA support" OFF)

# Check if the USE_NUMA environment variable is set.
# Its mere presence enables NUMA; the value is not inspected.
if(DEFINED ENV{USE_NUMA})
    set(USE_NUMA ON)
endif()

if(USE_NUMA)
    message(STATUS "NUMA support is enabled")
else()
    message(STATUS "NUMA support is disabled")
endif()

find_library(NUMA_LIBRARY NAMES numa)

if(NUMA_LIBRARY AND USE_NUMA)
    message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
elseif(USE_NUMA)
    # NUMA was requested explicitly but libnuma is missing: stop the configure.
    message(FATAL_ERROR "NUMA library not found - maybe sudo apt install libnuma-dev")
else()
    message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
endif()
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_attention.py
================================================
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-28 10:32:05
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-28 10:32:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import torch
# Benchmark configuration shared by bench_linear() below.
# Number of layers; bench_linear fills and queries layers 0..layer_num-1.
layer_num = 10
# Head counts and per-head width used for the K/V and query tensor shapes.
kv_head_num = 8
q_head_num = 32
head_dim = 128
# The following values are passed positionally to KVCacheConfig; their exact
# meaning is defined by the cpuinfer_ext extension (not visible here).
block_len = 128
anchor_num = 1
anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC
kv_type = cpuinfer_ext.kvcache.ggml_type.FP16
retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER
layer_step: int = 1
token_step: int = 1
layer_offset: int = 0
max_thread_num: int = 64
max_batch_size: int = 1
# Also the length of the block table built in bench_linear.
max_block_num: int = 1024
# Task runner shared by all benchmarks; constructed with max_thread_num
# (presumably the worker-thread count - confirm against the extension).
CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def bench_linear(cache_seqlen: int):
    """Benchmark the extension's attention kernel for one cache length.

    Builds a KV cache through ``cpuinfer_ext``, loads ``cache_seqlen`` random
    fp16 tokens into every layer, then issues ``warm_up_iter`` untimed and
    ``test_iter`` timed ``attn`` calls (cycling through the layers) and prints
    the elapsed time and the derived bandwidth figure.

    Args:
        cache_seqlen: number of cached tokens per layer.
    """
    with torch.inference_mode(mode=True):
        cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu")
        seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu")

        config = cpuinfer_ext.kvcache.KVCacheConfig(
            layer_num,
            kv_head_num,
            q_head_num,
            head_dim,
            block_len,
            anchor_num,
            anchor_type,
            kv_type,
            retrieval_type,
            layer_step,
            token_step,
            layer_offset,
            max_block_num,
            max_batch_size,
            max_thread_num,
        )
        local_kvcache = cpuinfer_ext.kvcache.KVCache(config)

        block_table = (
            torch.arange(max_block_num, dtype=torch.int32, device="cpu")
            .contiguous()
            .view(1, -1)
        )

        # Load random keys/values into each layer of the cache.
        kv_shape = (1, cache_seqlen, kv_head_num, head_dim)
        for layer_idx in range(layer_num):
            k_cache = torch.randn(kv_shape, dtype=torch.float16, device="cpu").contiguous()
            v_cache = torch.randn(kv_shape, dtype=torch.float16, device="cpu").contiguous()

            CPUInfer.submit(
                local_kvcache.update_kvcache_fp16(
                    k_cache.data_ptr(),
                    v_cache.data_ptr(),
                    layer_idx,
                    block_table.data_ptr(),
                    1,
                    max_block_num,
                    seqlens_zero.data_ptr(),
                    cache_seqlen,
                )
            )
            CPUInfer.sync()

        q_shape = (1, 1, q_head_num, head_dim)
        query = torch.randn(q_shape, dtype=torch.float16, device="cpu").contiguous()
        output = torch.empty(q_shape, dtype=torch.float16, device="cpu").contiguous()

        # attn_lse: (bsz, q_len, q_head_num)
        attn_lse = torch.empty(
            (1, 1, q_head_num), dtype=torch.float32, device="cpu"
        ).contiguous()
        query = query / 100

        def run_attention(iterations):
            # One attn task per iteration, rotating over the layers.
            for step in range(iterations):
                CPUInfer.submit(
                    local_kvcache.attn(
                        query.data_ptr(),
                        output.data_ptr(),
                        attn_lse.data_ptr(),
                        step % layer_num,
                        0,
                        1,
                        1,
                        max_block_num,
                        block_table.data_ptr(),
                        cache_seqlens.data_ptr(),
                        -1,
                        -1,
                        -1,
                    )
                )
                CPUInfer.sync()

        # warm up
        run_attention(warm_up_iter)

        # test
        start = time.perf_counter()
        run_attention(test_iter)
        end = time.perf_counter()

        total_time = end - start
        # Two fp16 bytes per element, for K and V, over all timed iterations.
        bytes_read = cache_seqlen * kv_head_num * head_dim * 2 * 2 * test_iter

        print("cache sequence length: ", cache_seqlen)
        print("Time(s): ", total_time)
        print("Iteration: ", test_iter)
        print("Time(us) per iteration: ", total_time / test_iter * 1000000)
        print(
            "Bandwidth: ",
            bytes_read / total_time / 1000 / 1000 / 1000,
            "GB/s",
        )
        print("")
# Run the benchmark for increasing cache lengths.
for _cache_len in (1024, 4096, 16384, 32768, 65536):
    bench_linear(_cache_len)
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_attention_torch.py
================================================
#!/usr/bin/env python
# coding=utf-8
"""
Description :
Author : Jianwei Dong
Date : 2024-08-28 10:32:05
Version : 1.0.0
LastEditors : Jianwei Dong
LastEditTime : 2024-08-28 10:32:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + "/../build")
import cpuinfer_ext
import torch
# Benchmark configuration shared by bench_linear() below.
# Number of independent (K, V) cache pairs the loops cycle through.
layer_num = 10
# NOTE(review): kv_head_num, block_len and anchor_num are not read by the
# bench_linear function visible in this file - presumably kept for parity with
# bench_attention.py; confirm.
kv_head_num = 8
q_head_num = 32
head_dim = 128
block_len = 128
anchor_num = 1
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def bench_linear(cache_seqlen: int, device):
    """Benchmark ``scaled_dot_product_attention`` for one cache length.

    Creates ``layer_num`` random fp16 (K, V) pairs on ``device``, runs
    ``warm_up_iter`` untimed and ``test_iter`` timed attention calls with a
    single query token, and prints the elapsed time and derived bandwidth.

    CUDA kernels are launched asynchronously, so on a CUDA device the stream
    is synchronised before the timer starts and before it stops; without this
    the measurement would only cover kernel launches.

    Args:
        cache_seqlen: number of cached tokens per head.
        device: torch device (or device string) to run on, e.g. "cpu"/"cuda".
    """
    with torch.inference_mode(mode=True):
        on_cuda = torch.device(device).type == "cuda"

        # K/V use the same head count as the query (no grouped heads here);
        # the bandwidth figure below relies on that as well.
        kv_shape = (1, q_head_num, cache_seqlen, head_dim)
        kvcaches = []
        for layer_idx in range(layer_num):
            k_cache = torch.randn(kv_shape, dtype=torch.float16, device=device).contiguous()
            v_cache = torch.randn(kv_shape, dtype=torch.float16, device=device).contiguous()
            kvcaches.append((k_cache, v_cache))

        input = torch.randn(
            (1, q_head_num, 1, head_dim), dtype=torch.float16, device=device
        ).contiguous()
        input = input / 100

        # warm up
        for i in range(warm_up_iter):
            k_cache, v_cache = kvcaches[i % layer_num]
            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
        if on_cuda:
            # Make sure no warm-up work leaks into the timed region.
            torch.cuda.synchronize(device)

        # test
        start = time.perf_counter()
        for i in range(test_iter):
            k_cache, v_cache = kvcaches[i % layer_num]
            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)
        if on_cuda:
            # Wait for all queued kernels before reading the clock.
            torch.cuda.synchronize(device)
        end = time.perf_counter()

        total_time = end - start
        print("cache sequence length: ", cache_seqlen)
        print("Time(s): ", total_time)
        print("Iteration: ", test_iter)
        print("Time(us) per iteration: ", total_time / test_iter * 1000000)
        print(
            "Bandwidth: ",
            cache_seqlen
            * q_head_num
            * head_dim
            * 2
            * 2
            * test_iter
            / total_time
            / 1000
            / 1000
            / 1000,
            "GB/s",
        )
        print("")
# CPU runs first (short caches only), then the CUDA sweep.
for _cache_len, _device in (
    (1024, "cpu"),
    (4096, "cpu"),
    (1024, "cuda"),
    (4096, "cuda"),
    (16384, "cuda"),
    (32768, "cuda"),
    (65536, "cuda"),
):
    bench_linear(_cache_len, _device)
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_linear.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-25 10:31:59
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:35:35
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
# Benchmark configuration shared by bench_linear() below.
# Each projection weight is created with shape (output_size, input_size).
input_size = 16384
output_size = 5120
# Passed through to LinearConfig; their meaning is defined by the
# cpuinfer_ext extension (not visible here).
stride = 16
group_max_len = 1024
# Number of independent layers the loops cycle through.
layer_num = 10
# Number of rows (tokens) in each input.
qlen = 1
# Task runner; the argument is presumably the worker-thread count - confirm.
CPUInfer = cpuinfer_ext.CPUInfer(64)
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def bench_linear(quant_mode: str):
    """Benchmark the extension's ``Linear`` operator for one weight format.

    Creates ``layer_num`` random projections, runs ``warm_up_iter`` untimed
    and ``test_iter`` timed forward calls (cycling through the layers) and
    prints the elapsed time and the derived weight bandwidth.

    Args:
        quant_mode: one of the keys of ``quant_specs`` below.

    Raises:
        AssertionError: if ``quant_mode`` is not a known mode. The error is
            raised explicitly (with the offending value) so that it is still
            reported when Python runs with assertions disabled.
    """
    # quant mode -> (ggml_type id, bytes per weight element)
    quant_specs = {
        "fp32": (0, 4.000000),      # ggml_type::GGML_TYPE_F32
        "fp16": (1, 2.000000),      # ggml_type::GGML_TYPE_F16
        "bf16": (30, 2.000000),     # ggml_type::GGML_TYPE_BF16
        "q8_0": (8, 1.062500),      # ggml_type::GGML_TYPE_Q8_0
        "q6_k": (14, 0.820312),     # ggml_type::GGML_TYPE_Q6_K
        "q5_k_m": (13, 0.687500),   # ggml_type::GGML_TYPE_Q5_K
        "q4_k_m": (12, 0.562500),   # ggml_type::GGML_TYPE_Q4_K
        "q3_k_m": (11, 0.429688),   # ggml_type::GGML_TYPE_Q3_K
        "q2_k": (10, 0.328125),     # ggml_type::GGML_TYPE_Q2_K
        "iq3_xs": (21, 0.429688),   # ggml_type::GGML_TYPE_IQ3_S
        "iq2_xxs": (16, 0.257812),  # ggml_type::GGML_TYPE_IQ2_XXS
    }

    with torch.inference_mode(mode=True):
        hidden_type = 30  # ggml_type::GGML_TYPE_BF16
        if quant_mode not in quant_specs:
            raise AssertionError(
                "unsupported quant_mode {!r}; expected one of {}".format(
                    quant_mode, sorted(quant_specs)
                )
            )
        proj_type, bytes_per_elem = quant_specs[quant_mode]

        # Random data is generated on the GPU when one is present and copied
        # to the host; without CUDA it is generated directly on the CPU so the
        # CPU benchmark does not depend on a GPU being installed.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"

        linears = []
        projs = []  # keeps the weight tensors alive while their pointers are in use
        for _ in range(layer_num):
            proj = torch.randn((output_size, input_size), dtype=torch.float32, device=gen_device).to("cpu").contiguous()
            config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
            linear = cpuinfer_ext.linear.Linear(config)
            projs.append(proj)
            linears.append(linear)

        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device=gen_device).to("cpu").contiguous()
        output = torch.empty((layer_num, qlen, output_size), dtype=torch.bfloat16, device=gen_device).to("cpu").contiguous()

        # warm up
        for i in range(warm_up_iter):
            CPUInfer.submit(
                linears[i % layer_num].forward(
                    qlen,
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr()
                )
            )
            CPUInfer.sync()

        # test
        start = time.perf_counter()
        for i in range(test_iter):
            CPUInfer.submit(
                linears[i % layer_num].forward(
                    qlen,
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr()
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()

        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark for every weight format supported on this platform.
for _mode in (
    "fp32",
    "fp16",
    "bf16",
    "q8_0",
    "q6_k",
    "q5_k_m",
    "q4_k_m",
    "q3_k_m",
    "q2_k",
):
    bench_linear(_mode)
# Not supported on __x86_64__
# bench_linear("iq3_xs")
# bench_linear("iq2_xxs")
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_linear_torch.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-25 10:31:59
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:32:48
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
import torch
import torch.nn.quantized as nnq
# Quantisation parameters used for every quantize_per_tensor call below.
scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset
# Each projection weight is created with shape (output_size, input_size).
input_size = 16384
output_size = 5120
# Number of independent layers the loops cycle through.
layer_num = 10
# Number of rows (tokens) in each input.
qlen = 1
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def bench_linear(quant_mode: str):
    """Benchmark a plain PyTorch linear projection for one weight format.

    Creates ``layer_num`` random projections (float matrices, or quantized
    ``nnq.Linear`` modules for "qint8"), runs ``warm_up_iter`` untimed and
    ``test_iter`` timed products (cycling through the layers) and prints the
    elapsed time and the derived weight bandwidth.

    Args:
        quant_mode: one of the keys of ``quant_specs`` below.

    Raises:
        AssertionError: if ``quant_mode`` is not a known mode. The error is
            raised explicitly (with the offending value) so that it is still
            reported when Python runs with assertions disabled.
    """
    # quant mode -> (torch dtype of the weights, bytes per weight element)
    quant_specs = {
        "fp32": (torch.float32, 4.000000),
        "fp16": (torch.float16, 2.000000),
        "bf16": (torch.bfloat16, 2.000000),
        "qint8": (torch.qint8, 1.000000),
    }

    with torch.inference_mode(mode=True):
        if quant_mode not in quant_specs:
            raise AssertionError(
                "unsupported quant_mode {!r}; expected one of {}".format(
                    quant_mode, sorted(quant_specs)
                )
            )
        proj_type, bytes_per_elem = quant_specs[quant_mode]

        # Random data is generated on the GPU when one is present and copied
        # to the host; without CUDA it is generated directly on the CPU so the
        # CPU benchmark does not depend on a GPU being installed.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"

        projs = []
        for _ in range(layer_num):
            proj = torch.randn((output_size, input_size), dtype = torch.float32, device = gen_device).to("cpu").contiguous()
            if quant_mode == "qint8":
                proj_q = torch.quantize_per_tensor(proj, scale, zero_point, torch.qint8)
                quantized_layer = nnq.Linear(input_size, output_size)
                quantized_layer.set_weight_bias(proj_q, None)
                projs.append(quantized_layer)
            else:
                projs.append(proj.to(proj_type))
        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = gen_device).to("cpu").contiguous()

        def run_layer(i):
            # One projection for iteration i, on layer i % layer_num.
            layer = projs[i % layer_num]
            if isinstance(layer, nnq.Linear):
                input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)
                return layer(input_q)
            return torch.mm(input[i % layer_num].to(proj_type), layer.t())

        # warm up
        for i in range(warm_up_iter):
            t_output = run_layer(i)

        # test
        start = time.perf_counter()
        for i in range(test_iter):
            t_output = run_layer(i)
        end = time.perf_counter()

        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark for every supported weight format.
for _mode in ("fp32", "fp16", "bf16", "qint8"):
    bench_linear(_mode)
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_mlp.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-16 10:43:18
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:36:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
# Benchmark configuration shared by bench_mlp() below.
# gate/up weights are (intermediate_size, hidden_size); down is the transpose shape.
hidden_size = 5120
intermediate_size = 3072
# Passed through to MLPConfig; their meaning is defined by the cpuinfer_ext
# extension (not visible here).
stride = 16
group_max_len = 1024
# Number of independent layers the loops cycle through.
layer_num = 10
# Number of rows (tokens) in each input.
qlen = 1
# Task runner; the argument is presumably the worker-thread count - confirm.
CPUInfer = cpuinfer_ext.CPUInfer(64)
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def bench_mlp(quant_mode: str):
    """Benchmark the extension's ``MLP`` operator for one weight format.

    Creates ``layer_num`` random gate/up/down projections, runs
    ``warm_up_iter`` untimed and ``test_iter`` timed forward calls (cycling
    through the layers) and prints the elapsed time and the derived weight
    bandwidth.

    Args:
        quant_mode: one of the keys of ``quant_specs`` below.

    Raises:
        AssertionError: if ``quant_mode`` is not a known mode. The error is
            raised explicitly (with the offending value) so that it is still
            reported when Python runs with assertions disabled.
    """
    # quant mode -> (gate type, up type, down type, bytes per weight element).
    # Type ids are ggml_type values: 0 F32, 1 F16, 30 BF16, 8 Q8_0, 14 Q6_K,
    # 13 Q5_K, 12 Q4_K, 11 Q3_K, 10 Q2_K, 21 IQ3_S, 16 IQ2_XXS.
    # For mixed modes bytes_per_elem is the mean over the three matrices,
    # which all have the same number of elements.
    quant_specs = {
        "fp32": (0, 0, 0, 4.000000),
        "fp16": (1, 1, 1, 2.000000),
        "bf16": (30, 30, 30, 2.000000),
        "q8_0": (8, 8, 8, 1.062500),
        "q6_k": (14, 14, 14, 0.820312),
        "q5_k_m": (13, 13, 14, 0.731771),
        "q4_k_m": (12, 12, 14, 0.648437),
        "q3_k_m": (11, 11, 13, 0.515625),
        # (2 * 0.328125 [Q2_K] + 0.429688 [Q3_K]) / 3; the previous value
        # 0.328125 ignored the larger Q3_K down projection.
        "q2_k": (10, 10, 11, 0.361979),
        "iq3_xs": (21, 21, 21, 0.429688),
        "iq2_xxs": (16, 16, 16, 0.257812),
    }

    with torch.inference_mode(mode=True):
        hidden_type = 30  # ggml_type::GGML_TYPE_BF16
        if quant_mode not in quant_specs:
            raise AssertionError(
                "unsupported quant_mode {!r}; expected one of {}".format(
                    quant_mode, sorted(quant_specs)
                )
            )
        gate_type, up_type, down_type, bytes_per_elem = quant_specs[quant_mode]

        # Random data is generated on the GPU when one is present and copied
        # to the host; without CUDA it is generated directly on the CPU so the
        # CPU benchmark does not depend on a GPU being installed.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"

        mlps = []
        # The weight tensors are kept in lists so they stay alive while their
        # raw pointers are in use.
        gate_projs = []
        up_projs = []
        down_projs = []
        for _ in range(layer_num):
            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device=gen_device).to("cpu").contiguous()
            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device=gen_device).to("cpu").contiguous()
            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device=gen_device).to("cpu").contiguous()
            config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
            mlp = cpuinfer_ext.mlp.MLP(config)
            gate_projs.append(gate_proj)
            up_projs.append(up_proj)
            down_projs.append(down_proj)
            mlps.append(mlp)

        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=gen_device).to("cpu").contiguous()
        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=gen_device).to("cpu").contiguous()

        # warm up
        for i in range(warm_up_iter):
            CPUInfer.submit(
                mlps[i % layer_num].forward(
                    qlen,
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr()
                )
            )
            CPUInfer.sync()

        # test
        start = time.perf_counter()
        for i in range(test_iter):
            CPUInfer.submit(
                mlps[i % layer_num].forward(
                    qlen,
                    input[i % layer_num].data_ptr(),
                    output[i % layer_num].data_ptr()
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()

        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark for every weight format supported on this platform.
for _mode in (
    "fp32",
    "fp16",
    "bf16",
    "q8_0",
    "q6_k",
    "q5_k_m",
    "q4_k_m",
    "q3_k_m",
    "q2_k",
):
    bench_mlp(_mode)
# Not supported on __x86_64__
# bench_mlp("iq3_xs")
# bench_mlp("iq2_xxs")
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_mlp_torch.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-16 10:43:18
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:32:53
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
import torch
import torch.nn.quantized as nnq
# Quantisation parameters used for every quantize_per_tensor call below.
scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset
# gate/up weights are (intermediate_size, hidden_size); down is the transpose shape.
hidden_size = 5120
intermediate_size = 3072
# Number of independent layers the loops cycle through.
layer_num = 10
# Number of rows (tokens) in each input.
qlen = 1
# Untimed and timed iteration counts.
warm_up_iter = 1000
test_iter = 10000
def act_fn(x):
    """Element-wise activation ``x / (1 + exp(-x))`` applied to the gate output."""
    denominator = 1.0 + torch.exp(-x)
    return x / denominator
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Reference gated MLP: ``down(act_fn(gate(x)) * up(x))``.

    Two paths are supported, chosen by the type of ``gate_proj``:

    * ``nnq.Linear`` modules: the input and the intermediate activation are
      quantized to quint8 with the module-level ``scale``/``zero_point``, and
      the module outputs are dequantized again.
    * plain weight tensors: each product is a ``torch.mm`` against the
      transposed weight, with the operand cast to that weight's dtype.

    Returns:
        The MLP output tensor.
    """
    if not isinstance(gate_proj, nnq.Linear):
        # Float path.
        gate_out = torch.mm(input.to(gate_proj.dtype), gate_proj.t())
        up_out = torch.mm(input.to(up_proj.dtype), up_proj.t())
        hidden = act_fn(gate_out) * up_out
        return torch.mm(hidden.to(down_proj.dtype), down_proj.t())

    # Quantized path.
    quantized_input = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)
    gate_out = gate_proj(quantized_input)
    up_out = up_proj(quantized_input)
    gate_out = gate_out.dequantize()
    up_out = up_out.dequantize()
    hidden = act_fn(gate_out) * up_out
    quantized_hidden = torch.quantize_per_tensor(hidden, scale, zero_point, torch.quint8)
    return down_proj(quantized_hidden).dequantize()
def bench_mlp(quant_mode: str):
    """Time ``mlp_torch`` for one weight format and print the results.

    Builds ``layer_num`` random weight sets (generated on CUDA, then moved to
    CPU), runs ``warm_up_iter`` untimed and ``test_iter`` timed calls cycling
    through the sets, and prints total time, per-iteration latency and the
    derived weight bandwidth.

    quant_mode: "fp32", "fp16", "bf16" or "qint8"; anything else trips the assert.
    """
    # quant_mode -> (weight dtype, bytes per weight element used in the bandwidth figure)
    mode_table = {
        "fp32": (torch.float32, 4.000000),
        "fp16": (torch.float16, 2.000000),
        "bf16": (torch.bfloat16, 2.000000),
        "qint8": (torch.qint8, 1.000000),
    }

    def random_weight(rows, cols):
        # Random float32 matrix created on the GPU and copied to the CPU.
        return torch.randn((rows, cols), dtype=torch.float32, device = "cuda").to("cpu").contiguous()

    def quantized_linear(weight, in_features, out_features):
        # Bias-free quantized linear module holding the qint8-quantized weight.
        weight_q = torch.quantize_per_tensor(weight, scale, zero_point, torch.qint8)
        layer = nnq.Linear(in_features, out_features)
        layer.set_weight_bias(weight_q, None)
        return layer

    with torch.inference_mode(mode=True):
        if quant_mode in mode_table:
            proj_type, bytes_per_elem = mode_table[quant_mode]
        else:
            assert(False)

        gate_projs, up_projs, down_projs = [], [], []
        for _ in range(layer_num):
            gate_w = random_weight(intermediate_size, hidden_size)
            up_w = random_weight(intermediate_size, hidden_size)
            down_w = random_weight(hidden_size, intermediate_size)
            if quant_mode == "qint8":
                gate_projs.append(quantized_linear(gate_w, hidden_size, intermediate_size))
                up_projs.append(quantized_linear(up_w, hidden_size, intermediate_size))
                down_projs.append(quantized_linear(down_w, intermediate_size, hidden_size))
            else:
                gate_projs.append(gate_w.to(proj_type))
                up_projs.append(up_w.to(proj_type))
                down_projs.append(down_w.to(proj_type))
        inputs = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()

        def run(iterations):
            # One MLP evaluation per iteration, rotating through the weight sets.
            for step in range(iterations):
                layer = step % layer_num
                mlp_torch(inputs[layer], gate_projs[layer], up_projs[layer], down_projs[layer])

        # warm up
        run(warm_up_iter)

        # test
        start = time.perf_counter()
        run(test_iter)
        end = time.perf_counter()
        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark once per supported weight format, in a fixed order.
for _quant_mode in ("fp32", "fp16", "bf16", "qint8"):
    bench_mlp(_quant_mode)
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_moe.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-08-06 10:41:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
# Model dimensions handed to cpuinfer_ext.moe.MOEConfig and used for the weight shapes.
expert_num = 160
hidden_size = 5120
intermediate_size = 1536
# Forwarded unchanged to MOEConfig; their meaning is defined by the extension and not visible here.
stride = 16
group_min_len = 10
group_max_len = 1024
# Experts selected per input row (last dimension of expert_ids / weights).
n_routed_experts = 6
# Number of independent MOE instances; iterations cycle through them (i % layer_num).
layer_num = 10
# Input rows per forward call.
qlen = 1
# Shared CPUInfer instance used to submit and sync every forward call.
# NOTE(review): the constructor argument is presumably a thread count — confirm against the extension.
CPUInfer = cpuinfer_ext.CPUInfer(64)
# Untimed warm-up iterations followed by timed iterations.
warm_up_iter = 1000
test_iter = 10000
def bench_moe(quant_mode: str):
    """Benchmark ``cpuinfer_ext.moe.MOE`` forward calls for one weight format.

    Creates ``layer_num`` MOE instances over random weight buffers, runs
    ``warm_up_iter`` untimed and ``test_iter`` timed forward calls (each
    submitted through ``CPUInfer`` and synced immediately), then prints total
    time, per-iteration latency and the derived weight bandwidth.

    quant_mode: a key of ``type_table`` below; anything else trips the assert.
    """
    hidden_type = 30 # ggml_type::GGML_TYPE_BF16
    # quant_mode -> (gate_type, up_type, down_type, bytes_per_elem); integers are ggml_type codes.
    type_table = {
        "fp32": (0, 0, 0, 4.000000), # F32 / F32 / F32
        "fp16": (1, 1, 1, 2.000000), # F16 / F16 / F16
        "bf16": (30, 30, 30, 2.000000), # BF16 / BF16 / BF16
        "q8_0": (8, 8, 8, 1.062500), # Q8_0 / Q8_0 / Q8_0
        "q6_k": (14, 14, 14, 0.820312), # Q6_K / Q6_K / Q6_K
        "q5_k_m": (13, 13, 14, 0.731771), # Q5_K / Q5_K / Q6_K
        "q4_k_m": (12, 12, 14, 0.648437), # Q4_K / Q4_K / Q6_K
        "q3_k_m": (11, 11, 13, 0.515625), # Q3_K / Q3_K / Q5_K
        "q2_k": (10, 10, 11, 0.328125), # Q2_K / Q2_K / Q3_K
        "iq3_xs": (21, 21, 21, 0.429688), # IQ3_S / IQ3_S / IQ3_S
        "iq2_xxs": (16, 16, 16, 0.257812), # IQ2_XXS / IQ2_XXS / IQ2_XXS
    }

    def random_weights(rows, cols):
        # Per-expert random float32 buffer created on the GPU and copied to the CPU.
        return torch.randn((expert_num, rows, cols), dtype=torch.float32, device = "cuda").to("cpu").contiguous()

    with torch.inference_mode(mode=True):
        if quant_mode in type_table:
            gate_type, up_type, down_type, bytes_per_elem = type_table[quant_mode]
        else:
            assert(False)

        moes = []
        # The MOE configs only receive raw data pointers, so the tensors are kept
        # referenced here for as long as the MOE instances are used.
        weight_refs = []
        for _ in range(layer_num):
            gate_w = random_weights(intermediate_size, hidden_size)
            up_w = random_weights(intermediate_size, hidden_size)
            down_w = random_weights(hidden_size, intermediate_size)
            config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_w.data_ptr(), up_w.data_ptr(), down_w.data_ptr(), gate_type, up_type, down_type, hidden_type)
            moes.append(cpuinfer_ext.moe.MOE(config))
            weight_refs.append((gate_w, up_w, down_w))

        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
        inputs = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
        outputs = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()

        def run(iterations):
            # Submit one forward call per iteration and wait for it before the next.
            for step in range(iterations):
                layer = step % layer_num
                CPUInfer.submit(
                    moes[layer].forward(
                        qlen,
                        n_routed_experts,
                        expert_ids[layer].data_ptr(),
                        weights[layer].data_ptr(),
                        inputs[layer].data_ptr(),
                        outputs[layer].data_ptr()
                    )
                )
                CPUInfer.sync()

        # warm up
        run(warm_up_iter)

        # test
        start = time.perf_counter()
        run(test_iter)
        end = time.perf_counter()
        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark once per quantization mode, in a fixed order.
for _quant_mode in ("fp32", "fp16", "bf16", "q8_0", "q6_k", "q5_k_m", "q4_k_m", "q3_k_m", "q2_k"):
    bench_moe(_quant_mode)
# Not supported on __x86_64__
# bench_moe("iq3_xs")
# bench_moe("iq2_xxs")
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_moe_amx.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2025-04-25 18:28:12
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2025-04-25 18:28:12
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
# Model dimensions handed to cpuinfer_ext.moe.AMX_MOEConfig and used for the weight shapes.
expert_num = 8
hidden_size = 7168
intermediate_size = 2048
# Forwarded unchanged to AMX_MOEConfig; its meaning is defined by the extension and not visible here.
max_len = 25600
# Experts selected per input row (last dimension of expert_ids / weights).
n_routed_experts = 8
# Number of independent MOE instances; iterations cycle through them (i % layer_num).
layer_num = 10
# Input rows per forward call.
qlen = 1024
# Shared CPUInfer instance used to submit and sync weight loading and forward calls.
# NOTE(review): the constructor argument is presumably a thread count — confirm against the extension.
CPUInfer = cpuinfer_ext.CPUInfer(65)
# Untimed warm-up iterations followed by timed iterations.
warm_up_iter = 100
test_iter = 100
def bench_moe(quant_mode: str):
    """Benchmark the AMX MOE kernels of ``cpuinfer_ext`` for one weight format.

    Creates ``layer_num`` MOE instances over random weights (loading the
    weights through ``CPUInfer``), runs ``warm_up_iter`` untimed and
    ``test_iter`` timed forward calls, then prints total time, per-iteration
    latency, the derived weight bandwidth and the derived FLOP rate.

    quant_mode: "bf16" or "int8"; anything else trips the assert.
    """
    def random_weights(rows, cols):
        # Per-expert random float32 buffer created on the GPU and copied to the CPU.
        return torch.randn((expert_num, rows, cols), dtype=torch.float32, device = "cuda").to("cpu").contiguous()

    with torch.inference_mode(mode=True):
        # Bytes per weight element, used only for the bandwidth figure.
        if quant_mode == "bf16":
            bytes_per_elem = 2.000000
        elif quant_mode == "int8":
            bytes_per_elem = 1.000000
        else:
            assert(False)

        moes = []
        # The config only receives raw data pointers, so the tensors are kept
        # referenced here for as long as the MOE instances are used.
        weight_refs = []
        for _ in range(layer_num):
            gate_w = random_weights(intermediate_size, hidden_size)
            up_w = random_weights(intermediate_size, hidden_size)
            down_w = random_weights(hidden_size, intermediate_size)
            config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_w.data_ptr(), up_w.data_ptr(), down_w.data_ptr())
            if quant_mode == "bf16":
                moe = cpuinfer_ext.moe.AMXBF16_MOE(config)
            elif quant_mode == "int8":
                moe = cpuinfer_ext.moe.AMXInt8_MOE(config)
            # Both kernels load their weights the same way before first use.
            CPUInfer.submit(moe.load_weights())
            CPUInfer.sync()
            weight_refs.append((gate_w, up_w, down_w))
            moes.append(moe)

        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
        inputs = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
        outputs = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        def run(iterations):
            # Submit one forward call per iteration and wait for it before the next.
            for step in range(iterations):
                layer = step % layer_num
                CPUInfer.submit(
                    moes[layer].forward(
                        qlen,
                        n_routed_experts,
                        expert_ids[layer].data_ptr(),
                        weights[layer].data_ptr(),
                        inputs[layer].data_ptr(),
                        outputs[layer].data_ptr(),
                        qlen_tensor.data_ptr()
                    )
                )
                CPUInfer.sync()

        # warm up
        run(warm_up_iter)

        # test
        start = time.perf_counter()
        run(test_iter)
        end = time.perf_counter()
        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')
        print('')
# Run the benchmark once per AMX weight format, in a fixed order.
for _quant_mode in ("bf16", "int8"):
    bench_moe(_quant_mode)
================================================
FILE: archive/csrc/ktransformers_ext/bench/bench_moe_torch.py
================================================
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:32:57
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
import time
import torch
import torch.nn.quantized as nnq
# Quantization parameters shared by every torch.quantize_per_tensor call in this script.
scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset
# Model dimensions: per-expert gate/up weights are (intermediate_size, hidden_size),
# per-expert down weights are (hidden_size, intermediate_size).
expert_num = 160
hidden_size = 5120
intermediate_size = 1536
# Experts selected per input row (last dimension of expert_ids / weights).
n_routed_experts = 6
# Number of independent weight sets; iterations cycle through them (i % layer_num).
layer_num = 10
# Input rows per call.
qlen = 1
# Untimed warm-up iterations followed by timed iterations.
warm_up_iter = 1000
test_iter = 10000
def act_fn(x):
    """Return ``x / (1 + exp(-x))`` computed element-wise on tensor ``x``."""
    denominator = torch.exp(-x) + 1.0
    return x / denominator
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Evaluate a gated MLP: ``down(act_fn(gate(input)) * up(input))``.

    Two weight representations are handled:
    - plain tensors: each projection is ``torch.mm`` against the transposed
      weight, with the operand cast to that weight's dtype;
    - ``nnq.Linear`` modules (detected on ``gate_proj``): operands are quantized
      to quint8 with the module-level ``scale``/``zero_point`` before each
      projection and dequantized afterwards.

    Returns the result of the down projection.
    """
    if not isinstance(gate_proj, nnq.Linear):
        gate_out = torch.mm(input.to(gate_proj.dtype), gate_proj.t())
        up_out = torch.mm(input.to(up_proj.dtype), up_proj.t())
        hidden = act_fn(gate_out) * up_out
        return torch.mm(hidden.to(down_proj.dtype), down_proj.t())

    # Quantized path: the activation is computed in floating point between projections.
    quantized_input = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)
    gate_out = gate_proj(quantized_input).dequantize()
    up_out = up_proj(quantized_input).dequantize()
    hidden = act_fn(gate_out) * up_out
    quantized_hidden = torch.quantize_per_tensor(hidden, scale, zero_point, torch.quint8)
    return down_proj(quantized_hidden).dequantize()
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference mixture-of-experts forward pass built on ``mlp_torch``.

    Each input row is routed to the experts listed in its ``expert_ids`` row.
    Rows are grouped by expert, every expert with at least one row runs
    ``mlp_torch`` on its group, and the per-expert results are scattered back,
    scaled by ``weights`` and summed over the routed experts.

    ``gate_proj``/``up_proj``/``down_proj`` are indexed by expert id.
    Returns one output row per input row, in the dtype of the expert outputs.
    """
    # Count how many routed slots each expert received.
    hit_counts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    hit_counts.scatter_(1, expert_ids, 1)
    tokens_per_expert = hit_counts.sum(dim=0)

    # Sort the flattened (row, slot) pairs by expert id and gather their input rows.
    order = expert_ids.view(-1).argsort()
    sorted_tokens = input[order // expert_ids.shape[1]]

    expert_outputs = []
    cursor = 0
    for expert_idx, count in enumerate(tokens_per_expert):
        stop = cursor + count
        if count != 0:
            chunk = sorted_tokens[cursor:stop]
            expert_outputs.append(mlp_torch(chunk, gate_proj[expert_idx], up_proj[expert_idx], down_proj[expert_idx]))
            cursor = stop

    if len(expert_outputs):
        gathered = torch.cat(expert_outputs, dim=0)
    else:
        gathered = sorted_tokens.new_empty(0)

    # Undo the sort so results line up with the original (row, slot) layout.
    unsorted = torch.empty_like(gathered)
    unsorted[order] = gathered

    weighted = unsorted.view(*expert_ids.shape, -1).type(weights.dtype)
    weighted.mul_(weights.unsqueeze(dim=-1))
    t_output = weighted.sum(dim=1).type(unsorted.dtype)
    return t_output
def bench_moe(quant_mode: str):
    """Time ``moe_torch`` for one weight format and print the results.

    Builds ``layer_num`` random per-expert weight sets (generated on CUDA, then
    moved to CPU), runs ``warm_up_iter`` untimed and ``test_iter`` timed calls
    cycling through the sets, and prints total time, per-iteration latency and
    the derived weight bandwidth.

    quant_mode: "fp32", "fp16", "bf16" or "qint8"; anything else trips the assert.
    """
    # quant_mode -> (weight dtype, bytes per weight element used in the bandwidth figure)
    mode_table = {
        "fp32": (torch.float32, 4.000000),
        "fp16": (torch.float16, 2.000000),
        "bf16": (torch.bfloat16, 2.000000),
        "qint8": (torch.qint8, 1.000000),
    }

    def random_weights(rows, cols):
        # Per-expert random float32 buffer created on the GPU and copied to the CPU.
        return torch.randn((expert_num, rows, cols), dtype=torch.float32, device = "cuda").to("cpu").contiguous()

    def quantized_linear(weight, in_features, out_features):
        # Bias-free quantized linear module holding the qint8-quantized weight.
        weight_q = torch.quantize_per_tensor(weight, scale, zero_point, torch.qint8)
        layer = nnq.Linear(in_features, out_features)
        layer.set_weight_bias(weight_q, None)
        return layer

    with torch.inference_mode(mode=True):
        if quant_mode in mode_table:
            proj_type, bytes_per_elem = mode_table[quant_mode]
        else:
            assert(False)

        gate_projs, up_projs, down_projs = [], [], []
        for _ in range(layer_num):
            gate_w = random_weights(intermediate_size, hidden_size)
            up_w = random_weights(intermediate_size, hidden_size)
            down_w = random_weights(hidden_size, intermediate_size)
            if quant_mode == "qint8":
                # One quantized module per expert and projection.
                layer_gate, layer_up, layer_down = [], [], []
                for expert in range(expert_num):
                    layer_gate.append(quantized_linear(gate_w[expert], hidden_size, intermediate_size))
                    layer_up.append(quantized_linear(up_w[expert], hidden_size, intermediate_size))
                    layer_down.append(quantized_linear(down_w[expert], intermediate_size, hidden_size))
                gate_projs.append(layer_gate)
                up_projs.append(layer_up)
                down_projs.append(layer_down)
            else:
                gate_projs.append(gate_w.to(proj_type))
                up_projs.append(up_w.to(proj_type))
                down_projs.append(down_w.to(proj_type))

        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous()
        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous()
        inputs = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous()

        def run(iterations):
            # One MoE evaluation per iteration, rotating through the weight sets.
            for step in range(iterations):
                layer = step % layer_num
                moe_torch(inputs[layer], expert_ids[layer], weights[layer], gate_projs[layer], up_projs[layer], down_projs[layer])

        # warm up
        run(warm_up_iter)

        # test
        start = time.perf_counter()
        run(test_iter)
        end = time.perf_counter()
        total_time = end - start
        print('Quant mode: ', quant_mode)
        print('Time(s): ', total_time)
        print('Iteration: ', test_iter)
        print('Time(us) per iteration: ', total_time / test_iter * 1000000)
        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
        print('')
# Run the benchmark once per supported weight format, in a fixed order.
for _quant_mode in ("fp32", "fp16", "bf16", "qint8"):
    bench_moe(_quant_mode)
================================================
FILE: archive/csrc/ktransformers_ext/cmake/FindSIMD.cmake
================================================
include(CheckCSourceRuns)
# Probe programs for check_sse(): each <TYPE>_CODE string is a small C program
# that check_c_source_runs() must build and run for <TYPE> to count as supported.

# AVX probe: uses a 256-bit float vector intrinsic.
set(AVX_CODE "
#include
int main()
{
__m256 a;
a = _mm256_set1_ps(0);
return 0;
}
")
# AVX512 probe: builds a 512-bit integer vector and does a byte-wise mask compare.
set(AVX512_CODE "
#include
int main()
{
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0);
__m512i b = a;
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
return 0;
}
")
# AVX2 probe: uses 256-bit integer intrinsics, including _mm256_extract_epi64.
set(AVX2_CODE "
#include
int main()
{
__m256i a = {0};
a = _mm256_abs_epi16(a);
__m256i x;
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
return 0;
}
")
# FMA probe: uses a fused multiply-add on 256-bit float vectors.
set(FMA_CODE "
#include
int main()
{
__m256 acc = _mm256_setzero_ps();
const __m256 d = _mm256_setzero_ps();
const __m256 p = _mm256_setzero_ps();
acc = _mm256_fmadd_ps( d, p, acc );
return 0;
}
")
# check_sse(<type> <flags>)
#   type  - instruction-set name; selects the probe source ${type}_CODE.
#   flags - ';'-separated list of candidate compiler flags, tried in order.
# Tries each candidate flag until the probe builds and runs, then caches
# ${type}_FOUND (BOOL) and ${type}_FLAGS (the flag that worked). If none work,
# caches ${type}_FOUND=FALSE and an empty ${type}_FLAGS.
macro(check_sse type flags)
# Counter that gives every attempt its own result variable (HAS_<type>_<n>).
set(__FLAG_I 1)
# Remember the caller's CMAKE_REQUIRED_FLAGS so it can be restored afterwards.
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
foreach (__FLAG ${flags})
# Skip the remaining candidates once one has succeeded.
if (NOT ${type}_FOUND)
set(CMAKE_REQUIRED_FLAGS ${__FLAG})
check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
if (HAS_${type}_${__FLAG_I})
set(${type}_FOUND TRUE CACHE BOOL "${type} support")
set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
endif()
math(EXPR __FLAG_I "${__FLAG_I}+1")
endif()
endforeach()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
# No candidate worked: record the negative result in the cache as well.
if (NOT ${type}_FOUND)
set(${type}_FOUND FALSE CACHE BOOL "${type} support")
set(${type}_FLAGS "" CACHE STRING "${type} flags")
endif()
mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endmacro()
# Probe each instruction set and derive the LLAMA_* switches from the results.
# Every call passes two candidates: a blank flag first, then an /arch: flag.
# flags are for MSVC only!
check_sse("AVX" " ;/arch:AVX")
if (NOT ${AVX_FOUND})
set(LLAMA_AVX OFF)
else()
set(LLAMA_AVX ON)
endif()
# LLAMA_AVX2 is only enabled when both the AVX2 and the FMA probes succeed.
check_sse("AVX2" " ;/arch:AVX2")
check_sse("FMA" " ;/arch:AVX2")
if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
set(LLAMA_AVX2 OFF)
else()
set(LLAMA_AVX2 ON)
endif()
check_sse("AVX512" " ;/arch:AVX512")
if (NOT ${AVX512_FOUND})
set(LLAMA_AVX512 OFF)
else()
set(LLAMA_AVX512 ON)
endif()
================================================
FILE: archive/csrc/ktransformers_ext/cpu_backend/backend.cpp
================================================
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:34
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#include "backend.h"
#ifdef USE_NUMA
#include
#include
thread_local int Backend::numa_node = -1;
#endif
thread_local int Backend::thread_local_id = -1;
// Builds the pool: allocates one ThreadState slot per thread (initial status
// WAITING) and starts a worker thread for every slot except slot 0. Slot 0 has
// no std::thread; do_work_stealing_job() runs it on the calling thread.
Backend::Backend(int max_thread_num) {
max_thread_num_ = max_thread_num;
thread_state_.resize(max_thread_num_);
for (int i = 0; i < max_thread_num_; i++) {
thread_state_[i].curr = std::make_unique>();
thread_state_[i].status =
std::make_unique>(ThreadStatus::WAITING);
}
workers_.resize(max_thread_num_);
// workers_[0] stays default-constructed (not joinable).
for (int i = 1; i < max_thread_num_; i++) {
workers_[i] = std::thread(&Backend::worker_thread, this, i);
}
}
// Shuts the pool down: flags every slot as EXIT, then waits for each spawned
// worker (slots 1..max_thread_num_-1) to return.
Backend::~Backend() {
    for (int slot = 0; slot < max_thread_num_; ++slot) {
        auto& status = *thread_state_[slot].status;
        status.store(ThreadStatus::EXIT, std::memory_order_release);
    }
    for (int slot = 1; slot < max_thread_num_; ++slot) {
        std::thread& worker = workers_[slot];
        if (!worker.joinable()) {
            continue;
        }
        worker.join();
    }
}
// Returns the pool size given at construction (the calling thread counts as slot 0).
int Backend::get_thread_num() {
    return max_thread_num_;
}
// Runs task ids [0, task_num) across the pool and returns once every
// participating thread has left the WORKING state.
//
// The id range is cut into contiguous per-thread slices: each thread gets
// `base` ids and the first `remain` threads get one extra. The calling thread
// acts as thread 0; threads 1..thread_num_-1 are released by setting their
// status to WORKING. init_func / finalize_func are invoked with the thread id
// (when non-null), compute_func with each task id.
void Backend::do_work_stealing_job(int task_num,
std::function init_func,
std::function compute_func,
std::function finalize_func) {
    init_func_ = init_func;
    compute_func_ = compute_func;
    finalize_func_ = finalize_func;
#ifdef USE_NUMA
    // numa node location will be calculated based on the number of threads
    thread_num_ = max_thread_num_;
#else
    // Never drop below one thread: with task_num == 0 a plain std::min() gives
    // 0 and the slice computation below would divide by zero.
    thread_num_ = std::max(1, std::min(max_thread_num_, task_num));
#endif
    int base = task_num / thread_num_;
    int remain = task_num % thread_num_;
    thread_state_[0].end = base + (0 < remain);

    // Set thread_local_id for the main (calling) thread.
    thread_local_id = 0;

    // Publish each worker's slice before flipping its status to WORKING; the
    // release store pairs with the acquire load in the worker loop.
    for (int i = 1; i < thread_num_; i++) {
        thread_state_[i].curr->store(thread_state_[i - 1].end,
                                     std::memory_order_relaxed);
        thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);
        thread_state_[i].status->store(ThreadStatus::WORKING,
                                       std::memory_order_release);
    }
    thread_state_[0].curr->store(0, std::memory_order_relaxed);
    thread_state_[0].status->store(ThreadStatus::WORKING,
                                   std::memory_order_release);
    process_tasks(0);

    // Busy-wait until every worker has finished its part of the job.
    for (int i = 1; i < thread_num_; i++) {
        while (thread_state_[i].status->load(std::memory_order_acquire) ==
               ThreadStatus::WORKING) {
        }
    }
}
// Per-thread part of a job. Runs init_func_, drains this thread's own slice of
// task ids, then visits the other threads in round-robin order and claims
// tasks from any that are still WORKING, runs finalize_func_, and finally
// marks this thread WAITING.
void Backend::process_tasks(int thread_id) {
#ifdef USE_NUMA
    // First job on this thread: pick a node from the thread's position in the
    // pool and bind to it. numa_node is thread_local, so this runs once per thread.
    if (numa_node == -1) {
        numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
        struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
        numa_bitmask_setbit(mask, numa_node);
        numa_bind(mask);
        // The mask is only needed for the numa_bind() call; release it so it
        // is not leaked.
        numa_bitmask_free(mask);
    }
#endif

    if (init_func_ != nullptr) {
        init_func_(thread_id);
    }

    // Own slice: claim ids one at a time through the shared counter, because
    // other threads may be claiming from the same counter.
    while (true) {
        int task_id = thread_state_[thread_id].curr->fetch_add(
            1, std::memory_order_acq_rel);
        if (task_id >= thread_state_[thread_id].end) {
            break;
        }
        compute_func_(task_id);
    }

    // Stealing: help every other thread that has not finished yet.
    for (int t_offset = 1; t_offset < thread_num_; t_offset++) {
        int t_i = (thread_id + t_offset) % thread_num_;
        if (thread_state_[t_i].status->load(std::memory_order_acquire) !=
            ThreadStatus::WORKING) {
            continue;
        }
        while (true) {
            int task_id = thread_state_[t_i].curr->fetch_add(
                1, std::memory_order_acq_rel);
            if (task_id >= thread_state_[t_i].end) {
                break;
            }
            compute_func_(task_id);
        }
    }

    if (finalize_func_ != nullptr) {
        finalize_func_(thread_id);
    }

    // Signals completion to do_work_stealing_job(), which polls this status.
    thread_state_[thread_id].status->store(ThreadStatus::WAITING,
                                           std::memory_order_release);
}
// Entry point of every spawned worker: polls this thread's status slot and
// runs process_tasks() whenever it is WORKING, until the status becomes EXIT.
void Backend::worker_thread(int thread_id) {
// Time of the last completed job (initially thread start); drives the idle back-off.
auto start = std::chrono::steady_clock::now();
thread_local_id = thread_id; // set the thread-local variable
while (true) {
ThreadStatus status =
thread_state_[thread_id].status->load(std::memory_order_acquire);
if (status == ThreadStatus::WORKING) {
process_tasks(thread_id);
start = std::chrono::steady_clock::now();
} else if (status == ThreadStatus::WAITING) {
// Spin while recently active; once idle for more than 50 units, sleep 1 ms
// between polls.
// NOTE(review): the duration_cast target type is not visible here, so the
// unit of the 50 threshold is unconfirmed.
auto now = std::chrono::steady_clock::now();
auto duration =
std::chrono::duration_cast(now -
start)
.count();
if (duration > 50) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
} else if (status == ThreadStatus::EXIT) {
return;
}
}
}
================================================
FILE: archive/csrc/ktransformers_ext/cpu_backend/backend.h
================================================
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-22 02:03:05
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:38
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_BACKEND_H
#define CPUINFER_BACKEND_H
#include
#include
#include
#include
#include
#include
#include
// State of one pool thread, stored in ThreadState::status.
enum ThreadStatus {
    WORKING = 0,  // a job is in progress on this thread
    WAITING = 1,  // idle, no job assigned
    EXIT = 2,     // the worker loop should return
};
// Per-thread bookkeeping for a job.
struct ThreadState {
// Current ThreadStatus of the thread; polled by the worker loop and by the job submitter.
std::unique_ptr> status;
// Next task id to claim from this thread's slice (advanced with fetch_add).
std::unique_ptr> curr;
// One past the last task id of this thread's slice.
int end;
};
// Fixed-size thread pool that splits a range of task ids across its threads
// and lets threads that finish early claim ids from the others.
class Backend {
public:
// Creates the pool; the argument is the maximum number of threads.
Backend(int);
// Signals EXIT to all threads and joins the spawned workers.
~Backend();
// Returns the maximum thread count given at construction.
int get_thread_num();
// Runs task ids [0, task_num) with (init, compute, finalize) callbacks and
// returns when all participating threads are done.
void do_work_stealing_job(int, std::function,
std::function,
std::function);
#ifdef USE_NUMA
// NUMA node this thread is bound to; -1 until the first job runs on it.
static thread_local int numa_node;
#endif
// Pool index of the current thread; -1 on threads that never ran a job.
static thread_local int thread_local_id;
private:
// Number of threads participating in the current job.
int thread_num_;
int max_thread_num_;
std::vector thread_state_; // [thread_num]
// Callbacks of the current job, copied in by do_work_stealing_job().
std::function init_func_;
std::function compute_func_;
std::function finalize_func_;
// Worker threads; index 0 is unused because the caller acts as thread 0.
std::vector workers_;
void process_tasks(int);
void worker_thread(int);
};
#endif
================================================
FILE: archive/csrc/ktransformers_ext/cpu_backend/cpuinfer.h
================================================
/**
* @Description :
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-07 09:47:43
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_CPUINFER_H
#define CPUINFER_CPUINFER_H
#include
#include
#include