Repository: kvcache-ai/ktransformers Branch: main Commit: 8561a71dd11e Files: 1146 Total size: 12.2 MB Directory structure: gitextract_0e22n38f/ ├── .github/ │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE/ │ │ ├── -bug-.yaml │ │ ├── -feature-.yaml │ │ └── config.yml │ ├── PULL_REQUEST_TEMPLATE.md │ ├── SECURITY.md │ └── workflows/ │ ├── book-ci.yml │ ├── deploy.yml │ ├── docker-image.yml │ ├── kt-kernel-tests.yml │ ├── release-fake-tag.yml │ ├── release-pypi.yml │ ├── release-sglang-kt.yml │ └── sync-sglang-submodule.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── MAINTAINERS.md ├── README.md ├── README_ZH.md ├── archive/ │ ├── .devcontainer/ │ │ ├── Dockerfile │ │ └── devcontainer.json │ ├── .flake8 │ ├── .gitmodules │ ├── .pylintrc │ ├── Dockerfile │ ├── Dockerfile.xpu │ ├── LICENSE │ ├── MANIFEST.in │ ├── Makefile │ ├── README.md │ ├── README_LEGACY.md │ ├── README_ZH.md │ ├── README_ZH_LEGACY.md │ ├── SECURITY.md │ ├── book.toml │ ├── config.json │ ├── csrc/ │ │ ├── balance_serve/ │ │ │ └── CMakeLists.txt │ │ ├── custom_marlin/ │ │ │ ├── __init__.py │ │ │ ├── binding.cpp │ │ │ ├── gptq_marlin/ │ │ │ │ ├── gptq_marlin.cu │ │ │ │ ├── gptq_marlin.cuh │ │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ │ ├── gptq_marlin_repack.cu │ │ │ │ └── ops.h │ │ │ ├── setup.py │ │ │ ├── test_cuda_graph.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── format24.py │ │ │ ├── marlin_24_perms.py │ │ │ ├── marlin_perms.py │ │ │ ├── marlin_utils.py │ │ │ └── quant_utils.py │ │ └── ktransformers_ext/ │ │ ├── CMakeLists.txt │ │ ├── bench/ │ │ │ ├── bench_attention.py │ │ │ ├── bench_attention_torch.py │ │ │ ├── bench_linear.py │ │ │ ├── bench_linear_torch.py │ │ │ ├── bench_mlp.py │ │ │ ├── bench_mlp_torch.py │ │ │ ├── bench_moe.py │ │ │ ├── bench_moe_amx.py │ │ │ └── bench_moe_torch.py │ │ ├── cmake/ │ │ │ └── FindSIMD.cmake │ │ ├── cpu_backend/ │ │ │ ├── backend.cpp │ │ │ ├── backend.h │ │ │ ├── cpuinfer.h │ │ │ ├── shared_mem_buffer.cpp │ │ │ ├── shared_mem_buffer.h │ │ │ ├── task_queue.cpp │ │ │ ├── task_queue.h │ │ │ └── vendors/ │ │ │ ├── README.md │ │ │ ├── cuda.h │ │ │ ├── hip.h │ │ │ ├── musa.h │ │ │ └── vendor.h │ │ ├── cuda/ │ │ │ ├── binding.cpp │ │ │ ├── custom_gguf/ │ │ │ │ ├── dequant.cu │ │ │ │ └── ops.h │ │ │ ├── gptq_marlin/ │ │ │ │ ├── gptq_marlin.cu │ │ │ │ ├── gptq_marlin.cuh │ │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ │ └── ops.h │ │ │ ├── setup.py │ │ │ └── test_dequant.py │ │ ├── examples/ │ │ │ ├── test_attention.py │ │ │ ├── test_linear.py │ │ │ ├── test_mlp.py │ │ │ └── test_moe.py │ │ ├── ext_bindings.cpp │ │ ├── operators/ │ │ │ ├── amx/ │ │ │ │ ├── la/ │ │ │ │ │ ├── amx.hpp │ │ │ │ │ └── utils.hpp │ │ │ │ └── moe.hpp │ │ │ ├── kvcache/ │ │ │ │ ├── kvcache.h │ │ │ │ ├── kvcache_attn.cpp │ │ │ │ ├── kvcache_load_dump.cpp │ │ │ │ ├── kvcache_read_write.cpp │ │ │ │ └── kvcache_utils.cpp │ │ │ └── llamafile/ │ │ │ ├── conversion.h │ │ │ ├── linear.cpp │ │ │ ├── linear.h │ │ │ ├── mlp.cpp │ │ │ ├── mlp.h │ │ │ ├── moe.cpp │ │ │ └── moe.h │ │ └── vendors/ │ │ ├── cuda.h │ │ ├── hip.h │ │ ├── musa.h │ │ └── vendor.h │ ├── install-with-cache.sh │ ├── install.bat │ ├── install.sh │ ├── ktransformers/ │ │ ├── __init__.py │ │ ├── configs/ │ │ │ ├── config.yaml │ │ │ └── log_config.ini │ │ ├── ktransformers_ext/ │ │ │ ├── operators/ │ │ │ │ └── custom_marlin/ │ │ │ │ └── quantize/ │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── format_24.py │ │ │ │ ├── marlin_24_perms.py │ │ │ │ ├── marlin_perms.py │ │ │ │ ├── marlin_utils.py │ │ │ │ └── quant_utils.py │ │ │ └── triton/ │ │ │ └── fp8gemm.py │ │ ├── local_chat.py │ │ ├── local_chat_test.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── ascend/ │ │ │ │ ├── custom_ascend_modeling_deepseek_v3.py │ │ │ │ └── custom_ascend_modeling_qwen3.py │ │ │ ├── configuration_deepseek.py │ │ │ ├── configuration_deepseek_v3.py │ │ │ ├── configuration_glm4_moe.py │ │ │ ├── configuration_llama.py │ │ │ ├── configuration_qwen2_moe.py │ │ │ ├── configuration_qwen3_moe.py │ │ │ ├── configuration_qwen3_next.py │ │ │ ├── configuration_smallthinker.py │ │ │ ├── custom_cache.py │ │ │ ├── custom_modeling_deepseek_v2.py │ │ │ ├── custom_modeling_deepseek_v3.py │ │ │ ├── custom_modeling_glm4_moe.py │ │ │ ├── custom_modeling_qwen2_moe.py │ │ │ ├── custom_modeling_qwen3_moe.py │ │ │ ├── custom_modeling_qwen3_next.py │ │ │ ├── custom_modeling_smallthinker.py │ │ │ ├── modeling_deepseek.py │ │ │ ├── modeling_deepseek_v3.py │ │ │ ├── modeling_glm4_moe.py │ │ │ ├── modeling_llama.py │ │ │ ├── modeling_mixtral.py │ │ │ ├── modeling_qwen2_moe.py │ │ │ ├── modeling_qwen3_moe.py │ │ │ ├── modeling_qwen3_next.py │ │ │ └── modeling_smallthinker.py │ │ ├── operators/ │ │ │ ├── RoPE.py │ │ │ ├── __init__.py │ │ │ ├── ascend/ │ │ │ │ ├── ascend_attention.py │ │ │ │ ├── ascend_experts.py │ │ │ │ ├── ascend_gate.py │ │ │ │ ├── ascend_layernorm.py │ │ │ │ ├── ascend_linear.py │ │ │ │ └── ascend_mlp.py │ │ │ ├── attention.py │ │ │ ├── balance_serve_attention.py │ │ │ ├── base_operator.py │ │ │ ├── cpuinfer.py │ │ │ ├── dynamic_attention.py │ │ │ ├── experts.py │ │ │ ├── flashinfer_batch_prefill_wrapper.py │ │ │ ├── flashinfer_wrapper.py │ │ │ ├── gate.py │ │ │ ├── layernorm.py │ │ │ ├── linear.py │ │ │ ├── mlp.py │ │ │ ├── models.py │ │ │ ├── triton_attention.py │ │ │ └── triton_attention_prefill.py │ │ ├── optimize/ │ │ │ ├── optimize.py │ │ │ └── optimize_rules/ │ │ │ ├── DeepSeek-V2-Chat-multi-gpu-4.yaml │ │ │ ├── DeepSeek-V2-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V2-Chat.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-gpu-cpu.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat.yaml │ │ │ ├── DeepSeek-V3-Chat-amx.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-4.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-8.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-marlin.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V3-Chat-npu.yaml │ │ │ ├── DeepSeek-V3-Chat-serve.yaml │ │ │ ├── DeepSeek-V3-Chat.yaml │ │ │ ├── Glm4Moe-serve.yaml │ │ │ ├── Internlm2_5-7b-Chat-1m.yaml │ │ │ ├── Mixtral.yaml │ │ │ ├── Moonlight-16B-A3B-serve.yaml │ │ │ ├── Moonlight-16B-A3B.yaml │ │ │ ├── Qwen2-57B-A14B-Instruct-multi-gpu.yaml │ │ │ ├── Qwen2-57B-A14B-Instruct.yaml │ │ │ ├── Qwen2-serve-amx.yaml │ │ │ ├── Qwen2-serve.yaml │ │ │ ├── Qwen3Moe-serve-amx.yaml │ │ │ ├── Qwen3Moe-serve.yaml │ │ │ ├── Qwen3Next-serve.yaml │ │ │ ├── Smallthinker-serve.yaml │ │ │ ├── npu/ │ │ │ │ ├── DeepSeek-V3-Chat-300IA2-npu-serve.yaml │ │ │ │ ├── DeepSeek-V3-Chat-300IA2-npu.yaml │ │ │ │ └── Qwen3-Chat-300IA2-npu-serve.yaml │ │ │ ├── rocm/ │ │ │ │ └── DeepSeek-V3-Chat.yaml │ │ │ └── xpu/ │ │ │ ├── DeepSeek-V2-Chat.yaml │ │ │ ├── DeepSeek-V3-Chat.yaml │ │ │ └── Qwen3Moe-Chat.yaml │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── api/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ollama/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── completions.py │ │ │ │ ├── openai/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assistants/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── assistants.py │ │ │ │ │ │ ├── messages.py │ │ │ │ │ │ ├── runs.py │ │ │ │ │ │ └── threads.py │ │ │ │ │ ├── endpoints/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── chat.py │ │ │ │ │ └── legacy/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── completions.py │ │ │ │ └── web/ │ │ │ │ ├── __init__.py │ │ │ │ └── system.py │ │ │ ├── args.py │ │ │ ├── backend/ │ │ │ │ ├── __init__.py │ │ │ │ ├── args.py │ │ │ │ ├── base.py │ │ │ │ ├── context_manager.py │ │ │ │ └── interfaces/ │ │ │ │ ├── __init__.py │ │ │ │ ├── balance_serve.py │ │ │ │ ├── exllamav2.py │ │ │ │ ├── ktransformers.py │ │ │ │ └── transformers.py │ │ │ ├── balance_serve/ │ │ │ │ ├── inference/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── config.py │ │ │ │ │ ├── distributed/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── communication_op.py │ │ │ │ │ │ ├── cuda_wrapper.py │ │ │ │ │ │ ├── custom_all_reduce.py │ │ │ │ │ │ ├── custom_all_reduce_utils.py │ │ │ │ │ │ ├── parallel_state.py │ │ │ │ │ │ ├── pynccl.py │ │ │ │ │ │ ├── pynccl_wrapper.py │ │ │ │ │ │ └── utils.py │ │ │ │ │ ├── forward_batch.py │ │ │ │ │ ├── model_runner.py │ │ │ │ │ ├── query_manager.py │ │ │ │ │ └── sampling/ │ │ │ │ │ ├── penaltylib/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── orchestrator.py │ │ │ │ │ │ └── penalizers/ │ │ │ │ │ │ ├── frequency_penalty.py │ │ │ │ │ │ ├── min_new_tokens.py │ │ │ │ │ │ ├── presence_penalty.py │ │ │ │ │ │ └── repetition_penalty.py │ │ │ │ │ └── sampler.py │ │ │ │ ├── sched_rpc.py │ │ │ │ └── settings.py │ │ │ ├── config/ │ │ │ │ ├── config.py │ │ │ │ ├── log.py │ │ │ │ └── singleton.py │ │ │ ├── crud/ │ │ │ │ ├── __init__.py │ │ │ │ └── assistants/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants.py │ │ │ │ ├── messages.py │ │ │ │ ├── runs.py │ │ │ │ └── threads.py │ │ │ ├── exceptions.py │ │ │ ├── main.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ └── assistants/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants.py │ │ │ │ ├── messages.py │ │ │ │ ├── run_steps.py │ │ │ │ ├── runs.py │ │ │ │ └── threads.py │ │ │ ├── requirements.txt │ │ │ ├── schemas/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assistants.py │ │ │ │ │ ├── messages.py │ │ │ │ │ ├── runs.py │ │ │ │ │ ├── streaming.py │ │ │ │ │ ├── threads.py │ │ │ │ │ └── tool.py │ │ │ │ ├── base.py │ │ │ │ ├── conversation.py │ │ │ │ ├── endpoints/ │ │ │ │ │ └── chat.py │ │ │ │ └── legacy/ │ │ │ │ ├── __init__.py │ │ │ │ └── completions.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── create_interface.py │ │ │ ├── multi_timer.py │ │ │ ├── serve_profiling.py │ │ │ └── sql_utils.py │ │ ├── tests/ │ │ │ ├── .gitignore │ │ │ ├── AIME_2024/ │ │ │ │ ├── eval_api.py │ │ │ │ ├── evaluation.py │ │ │ │ └── prompts.py │ │ │ ├── UT/ │ │ │ │ ├── test_kdeepseek_attention_w8a8a2serve_npu.py │ │ │ │ └── test_kdeepseek_ln_npu.py │ │ │ ├── dequant_gpu.py │ │ │ ├── dequant_gpu_t.py │ │ │ ├── function_call_test.py │ │ │ ├── humaneval/ │ │ │ │ ├── eval_api.py │ │ │ │ ├── evaluation.py │ │ │ │ └── prompts.py │ │ │ ├── mmlu_pro_test.py │ │ │ ├── mmlu_test.py │ │ │ ├── mmlu_test_multi.py │ │ │ ├── parse_cover_info.py │ │ │ ├── score.py │ │ │ ├── test_client.py │ │ │ ├── test_prefix.py │ │ │ ├── test_pytorch_q8.py │ │ │ ├── test_speed.py │ │ │ └── triton_fp8gemm_test.py │ │ ├── util/ │ │ │ ├── ascend/ │ │ │ │ └── ascend_utils.py │ │ │ ├── cuda_graph_runner.py │ │ │ ├── custom_gguf.py │ │ │ ├── custom_loader.py │ │ │ ├── modeling_rope_utils.py │ │ │ ├── npu_graph_runner.py │ │ │ ├── textstream.py │ │ │ ├── utils.py │ │ │ ├── vendors.py │ │ │ └── weight_loader.py │ │ └── website/ │ │ ├── .browserslistrc │ │ ├── .eslintrc.js │ │ ├── .gitignore │ │ ├── README.md │ │ ├── config.d.ts │ │ ├── jest.config.js │ │ ├── package.json │ │ ├── public/ │ │ │ ├── config.js │ │ │ ├── css/ │ │ │ │ └── reset.css │ │ │ └── index.html │ │ ├── src/ │ │ │ ├── App.vue │ │ │ ├── api/ │ │ │ │ ├── api-client.ts │ │ │ │ ├── assistant.ts │ │ │ │ ├── message.ts │ │ │ │ ├── run.ts │ │ │ │ └── thread.ts │ │ │ ├── assets/ │ │ │ │ ├── css/ │ │ │ │ │ └── mixins.styl │ │ │ │ └── iconfont/ │ │ │ │ ├── demo.css │ │ │ │ ├── demo_index.html │ │ │ │ ├── iconfont.css │ │ │ │ ├── iconfont.js │ │ │ │ └── iconfont.json │ │ │ ├── components/ │ │ │ │ └── chat/ │ │ │ │ └── index.vue │ │ │ ├── conf/ │ │ │ │ └── config.ts │ │ │ ├── locals/ │ │ │ │ ├── en.js │ │ │ │ ├── index.js │ │ │ │ └── zh.js │ │ │ ├── main.ts │ │ │ ├── router/ │ │ │ │ └── index.ts │ │ │ ├── shims-vue.d.ts │ │ │ ├── store/ │ │ │ │ └── index.ts │ │ │ ├── utils/ │ │ │ │ ├── copy.ts │ │ │ │ └── types.ts │ │ │ └── views/ │ │ │ └── home.vue │ │ ├── tests/ │ │ │ └── unit/ │ │ │ └── example.spec.ts │ │ ├── tsconfig.json │ │ └── vue.config.js │ ├── merge_tensors/ │ │ ├── merge_safetensor_gguf.py │ │ └── merge_safetensor_gguf_for_qwen3.py │ ├── pyproject.toml │ ├── requirements-local_chat.txt │ ├── setup.py │ └── third_party/ │ ├── llamafile/ │ │ ├── README.md │ │ ├── bench.h │ │ ├── flags.cpp │ │ ├── flags.h │ │ ├── iqk_mul_mat.inc │ │ ├── iqk_mul_mat_amd_avx2.cpp │ │ ├── iqk_mul_mat_amd_zen4.cpp │ │ ├── iqk_mul_mat_arm.inc │ │ ├── iqk_mul_mat_arm82.cpp │ │ ├── iqk_mul_mat_x86.inc │ │ ├── macros.h │ │ ├── micros.h │ │ ├── numba.h │ │ ├── sgemm.cpp │ │ ├── sgemm.h │ │ ├── sgemm_arm.cpp │ │ ├── sgemm_x86.cpp │ │ ├── tinyblas_cpu.h │ │ ├── tinyblas_cpu_mixmul.inc │ │ ├── tinyblas_cpu_mixmul_amd_avx.cpp │ │ ├── tinyblas_cpu_mixmul_amd_avx2.cpp │ │ ├── tinyblas_cpu_mixmul_amd_avx512f.cpp │ │ ├── tinyblas_cpu_mixmul_amd_avxvnni.cpp │ │ ├── tinyblas_cpu_mixmul_amd_fma.cpp │ │ ├── tinyblas_cpu_mixmul_amd_zen4.cpp │ │ ├── tinyblas_cpu_mixmul_arm80.cpp │ │ ├── tinyblas_cpu_mixmul_arm82.cpp │ │ ├── tinyblas_cpu_sgemm.inc │ │ ├── tinyblas_cpu_sgemm_amd_avx.cpp │ │ ├── tinyblas_cpu_sgemm_amd_avx2.cpp │ │ ├── tinyblas_cpu_sgemm_amd_avx512f.cpp │ │ ├── tinyblas_cpu_sgemm_amd_avxvnni.cpp │ │ ├── tinyblas_cpu_sgemm_amd_fma.cpp │ │ ├── tinyblas_cpu_sgemm_amd_zen4.cpp │ │ ├── tinyblas_cpu_sgemm_arm.inc │ │ ├── tinyblas_cpu_sgemm_arm80.cpp │ │ ├── tinyblas_cpu_sgemm_arm82.cpp │ │ ├── tinyblas_cpu_sgemm_x86.inc │ │ └── tinyblas_cpu_unsupported.cpp │ └── nlohmann/ │ ├── json.hpp │ └── json_fwd.hpp ├── book.toml ├── doc/ │ ├── SUMMARY.md │ ├── basic/ │ │ ├── note1.md │ │ └── note2.md │ ├── en/ │ │ ├── AMX.md │ │ ├── DeepseekR1_V3_tutorial.md │ │ ├── Docker.md │ │ ├── Docker_xpu.md │ │ ├── FAQ.md │ │ ├── Kimi-K2-Thinking.md │ │ ├── Kimi-K2.5.md │ │ ├── Kimi-K2.md │ │ ├── Kllama_tutorial_DeepSeekV2Lite.ipynb │ │ ├── MiniMax-M2.5.md │ │ ├── Qwen3-Next.md │ │ ├── Qwen3.5.md │ │ ├── ROCm.md │ │ ├── SFT/ │ │ │ ├── DPO_tutorial.md │ │ │ ├── KTransformers-Fine-Tuning_Developer-Technical-Notes.md │ │ │ ├── KTransformers-Fine-Tuning_User-Guide.md │ │ │ ├── README.md │ │ │ └── injection_tutorial.md │ │ ├── SFT_Installation_Guide_KimiK2.5.md │ │ ├── SFT_Installation_Guide_KimiK2.md │ │ ├── SmallThinker_and_Glm4moe.md │ │ ├── V3-success.md │ │ ├── api/ │ │ │ └── server/ │ │ │ ├── api.md │ │ │ ├── server.md │ │ │ ├── tabby.md │ │ │ └── website.md │ │ ├── balance-serve.md │ │ ├── benchmark.md │ │ ├── deepseek-v2-injection.md │ │ ├── fp8_kernel.md │ │ ├── install.md │ │ ├── kt-kernel/ │ │ │ ├── GLM-5-Tutorial.md │ │ │ ├── Kimi-K2-Thinking-Native.md │ │ │ ├── MiniMax-M2.1-Tutorial.md │ │ │ ├── Native-Precision-Tutorial.md │ │ │ ├── Qwen3-Coder-Next-Tutorial.md │ │ │ ├── README.md │ │ │ ├── amd_blis.md │ │ │ ├── deepseek-v3.2-sglang-tutorial.md │ │ │ ├── experts-sched-Tutorial.md │ │ │ └── kt-cli.md │ │ ├── llama4.md │ │ ├── long_context_introduction.md │ │ ├── long_context_tutorial.md │ │ ├── makefile_usage.md │ │ ├── multi-gpu-tutorial.md │ │ ├── operators/ │ │ │ └── llamafile.md │ │ ├── prefix_cache.md │ │ └── xpu.md │ └── zh/ │ ├── DeepseekR1_V3_tutorial_zh.md │ ├── DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md │ ├── KTransformers-Fine-Tuning_Developer-Technical-Notes_zh.md │ ├── KTransformers-Fine-Tuning_User-Guide_zh.md │ ├── Qwen3-MoE_tutorial_zh_for_Ascend_NPU.md │ ├── api/ │ │ └── server/ │ │ ├── api.md │ │ ├── server.md │ │ ├── tabby.md │ │ └── website.md │ └── clawdbot_integration_guide.md ├── docker/ │ ├── Dockerfile │ ├── README-packaging.md │ ├── docker-utils.sh │ └── push-to-dockerhub.sh ├── install.sh ├── kt-kernel/ │ ├── .clang-format │ ├── .githooks/ │ │ ├── commit-msg │ │ └── pre-commit │ ├── .gitignore │ ├── .gitmodules │ ├── CMakeLists.txt │ ├── CMakePresets.json │ ├── MANIFEST.in │ ├── README.md │ ├── README_zh.md │ ├── bench/ │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── bench_attention.py │ │ ├── bench_attention_torch.py │ │ ├── bench_bf16_moe.py │ │ ├── bench_fp8_moe.py │ │ ├── bench_fp8_perchannel_moe.py │ │ ├── bench_k2_moe_amx.py │ │ ├── bench_k2_write_buffer.py │ │ ├── bench_linear.py │ │ ├── bench_linear_torch.py │ │ ├── bench_mla.py │ │ ├── bench_mlp.py │ │ ├── bench_mlp_torch.py │ │ ├── bench_moe.py │ │ ├── bench_moe_amx.py │ │ ├── bench_moe_amx_k.py │ │ ├── bench_moe_kernel.py │ │ ├── bench_moe_kernel_tiling.py │ │ ├── bench_moe_kml.py │ │ ├── bench_moe_torch.py │ │ ├── bench_write_buffer.py │ │ ├── compare_moe_performance.py │ │ ├── multi_bench_moe.py │ │ └── upload-bench-json.py │ ├── cmake/ │ │ ├── DetectCPU.cmake │ │ └── FindSIMD.cmake │ ├── cpu_backend/ │ │ ├── cpuinfer.h │ │ ├── shared_mem_buffer.cpp │ │ ├── shared_mem_buffer.h │ │ ├── task_queue.cpp │ │ ├── task_queue.h │ │ ├── vendors/ │ │ │ ├── README.md │ │ │ ├── cuda.h │ │ │ ├── hip.h │ │ │ ├── musa.h │ │ │ └── vendor.h │ │ ├── worker_pool.cpp │ │ └── worker_pool.h │ ├── cuda/ │ │ ├── binding.cpp │ │ ├── custom_gguf/ │ │ │ ├── dequant.cu │ │ │ └── ops.h │ │ ├── gptq_marlin/ │ │ │ ├── gptq_marlin.cu │ │ │ ├── gptq_marlin.cuh │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ └── ops.h │ │ ├── moe/ │ │ │ ├── moe_topk_softmax_kernels.cu │ │ │ ├── ops.h │ │ │ └── utils.h │ │ ├── setup.py │ │ └── test_dequant.py │ ├── demo/ │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── bench_reorder_bandwidth.cpp │ │ ├── bf16-test.cpp │ │ ├── fp16-test.cpp │ │ ├── plot.py │ │ ├── simple_test.cpp │ │ ├── simple_test_aocl.cpp │ │ └── tflops.py │ ├── examples/ │ │ ├── .gitignore │ │ ├── bench_moe_amx_int8.py │ │ ├── configuration_deepseek_v3.py │ │ ├── modeling_deepseek_v3.py │ │ ├── repro_llamafile_re.py │ │ ├── test-debug.py │ │ ├── test_apply_rope.py │ │ ├── test_attention.py │ │ ├── test_awq_moe_amx.py │ │ ├── test_bf16_moe.py │ │ ├── test_deepseekv3.py │ │ ├── test_deepseekv3_prefill.py │ │ ├── test_deepseekv3_prefill_speed.py │ │ ├── test_fp8_moe.py │ │ ├── test_fp8_perchannel_moe.py │ │ ├── test_gate.py │ │ ├── test_k2_moe_amx.py │ │ ├── test_k2_write_buffer.py │ │ ├── test_linear.py │ │ ├── test_mla.py │ │ ├── test_mla_qlen.py │ │ ├── test_mla_quant.py │ │ ├── test_mla_simple.py │ │ ├── test_mla_torch.py │ │ ├── test_mlp.py │ │ ├── test_moe.py │ │ ├── test_moe_amx.py │ │ ├── test_moe_kernel.py │ │ ├── test_moe_kml.py │ │ ├── test_rope.cpp │ │ ├── test_rope.py │ │ ├── test_softmax.py │ │ ├── test_write_buffer.py │ │ └── torch_attention.py │ ├── ext_bindings.cpp │ ├── install.sh │ ├── operators/ │ │ ├── amx/ │ │ │ ├── awq-moe.hpp │ │ │ ├── bf16-moe.hpp │ │ │ ├── fp8-moe.hpp │ │ │ ├── fp8-perchannel-moe.hpp │ │ │ ├── k2-moe.hpp │ │ │ ├── la/ │ │ │ │ ├── amx-example.cpp │ │ │ │ ├── amx.hpp │ │ │ │ ├── amx_buffers.hpp │ │ │ │ ├── amx_config.hpp │ │ │ │ ├── amx_kernels.hpp │ │ │ │ ├── amx_quantization.hpp │ │ │ │ ├── amx_raw_buffers.hpp │ │ │ │ ├── amx_raw_kernels.hpp │ │ │ │ ├── amx_utils.hpp │ │ │ │ ├── pack.hpp │ │ │ │ └── utils.hpp │ │ │ ├── moe.hpp │ │ │ ├── moe_base.hpp │ │ │ └── test/ │ │ │ ├── amx-bkgroup-test.cpp │ │ │ ├── amx-c-reduce-test.cpp │ │ │ ├── amx-kgroup-test.cpp │ │ │ ├── amx-test.cpp │ │ │ ├── analyze-error.cpp │ │ │ ├── avx-test.cpp │ │ │ ├── debug-kgroup-details.cpp │ │ │ ├── debug-kgroup.cpp │ │ │ ├── debug-specific-dims.cpp │ │ │ ├── mat-test.hpp │ │ │ ├── mmq-test.cpp │ │ │ ├── mmq.cpp │ │ │ ├── mmq.h │ │ │ ├── test-kgroup-128.cpp │ │ │ ├── test-kgroup-kernel.cpp │ │ │ ├── test-specific-dims.cpp │ │ │ ├── thread_test.sh │ │ │ ├── timer.hh │ │ │ └── verify-kgroup.cpp │ │ ├── common.hpp │ │ ├── kvcache/ │ │ │ ├── kvcache.h │ │ │ ├── kvcache_attn.cpp │ │ │ ├── kvcache_load_dump.cpp │ │ │ ├── kvcache_read_write.cpp │ │ │ └── kvcache_utils.cpp │ │ ├── llamafile/ │ │ │ ├── conversion.h │ │ │ ├── linear.cpp │ │ │ ├── linear.h │ │ │ ├── mla.hpp │ │ │ ├── mlp.cpp │ │ │ ├── mlp.h │ │ │ └── moe.hpp │ │ ├── mla-tp.hpp │ │ ├── moe-tp.hpp │ │ ├── moe_kernel/ │ │ │ ├── api/ │ │ │ │ ├── common.h │ │ │ │ └── mat_kernel.h │ │ │ ├── la/ │ │ │ │ ├── kernel.hpp │ │ │ │ ├── mat_kernel.cpp │ │ │ │ └── utils.hpp │ │ │ ├── mat_kernel/ │ │ │ │ ├── aocl_kernel/ │ │ │ │ │ └── kernel.cpp │ │ │ │ └── batch_gemm_api.hpp │ │ │ ├── moe.hpp │ │ │ └── test/ │ │ │ ├── convert-test.cpp │ │ │ ├── debug.hpp │ │ │ ├── int4_mul-test.cpp │ │ │ ├── mat_test.cpp │ │ │ └── utils_test.cpp │ │ ├── reduce.hpp │ │ ├── rms-norm.hpp │ │ ├── rope.hpp │ │ ├── softmax.hpp │ │ └── tp.hpp │ ├── pyproject.toml │ ├── pytest.ini │ ├── python/ │ │ ├── __init__.py │ │ ├── _cpu_detect.py │ │ ├── cli/ │ │ │ ├── __init__.py │ │ │ ├── commands/ │ │ │ │ ├── __init__.py │ │ │ │ ├── bench.py │ │ │ │ ├── chat.py │ │ │ │ ├── config.py │ │ │ │ ├── doctor.py │ │ │ │ ├── model.py │ │ │ │ ├── quant.py │ │ │ │ ├── run.py │ │ │ │ ├── sft.py │ │ │ │ └── version.py │ │ │ ├── completions/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _kt │ │ │ │ ├── kt-completion.bash │ │ │ │ └── kt.fish │ │ │ ├── config/ │ │ │ │ ├── __init__.py │ │ │ │ └── settings.py │ │ │ ├── i18n.py │ │ │ ├── main.py │ │ │ ├── requirements/ │ │ │ │ ├── inference.txt │ │ │ │ └── sft.txt │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── analyze_moe_model.py │ │ │ ├── console.py │ │ │ ├── debug_configs.py │ │ │ ├── download_helper.py │ │ │ ├── environment.py │ │ │ ├── input_validators.py │ │ │ ├── kv_cache_calculator.py │ │ │ ├── model_discovery.py │ │ │ ├── model_registry.py │ │ │ ├── model_scanner.py │ │ │ ├── model_table_builder.py │ │ │ ├── model_verifier.py │ │ │ ├── port_checker.py │ │ │ ├── quant_interactive.py │ │ │ ├── repo_detector.py │ │ │ ├── run_configs.py │ │ │ ├── run_interactive.py │ │ │ ├── sglang_checker.py │ │ │ ├── tuna_engine.py │ │ │ └── user_model_registry.py │ │ ├── experts.py │ │ ├── experts_base.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── amx.py │ │ ├── llamafile.py │ │ ├── loader.py │ │ └── moe_kernel.py │ ├── requirements.txt │ ├── scripts/ │ │ ├── README.md │ │ ├── check.py │ │ ├── check_cpu_features.py │ │ ├── compare_weights.py │ │ ├── convert_cpu_weights.py │ │ ├── convert_gpu_weights.py │ │ ├── convert_kimi_k2_fp8_to_bf16_cpu.py │ │ ├── convert_moe_to_bf16.py │ │ └── install-git-hooks.sh │ ├── setup.py │ └── test/ │ ├── __init__.py │ ├── ci/ │ │ ├── __init__.py │ │ ├── ci_register.py │ │ └── ci_utils.py │ ├── per_commit/ │ │ ├── __init__.py │ │ ├── test_amd_placeholder.py │ │ ├── test_basic_cpu.py │ │ ├── test_cuda_placeholder.py │ │ ├── test_moe_amx_accuracy_int4.py │ │ ├── test_moe_amx_accuracy_int4_1.py │ │ ├── test_moe_amx_accuracy_int4_1k.py │ │ ├── test_moe_amx_accuracy_int8.py │ │ ├── test_moe_amx_bench_int4.py │ │ ├── test_moe_amx_bench_int4_1.py │ │ ├── test_moe_amx_bench_int4_1k.py │ │ └── test_moe_amx_bench_int8.py │ ├── run_suite.py │ └── test_generate_gpu_experts_masks.py ├── kt-sft/ │ ├── .flake8 │ ├── .gitignore │ ├── .gitmodules │ ├── .pylintrc │ ├── Dockerfile │ ├── Dockerfile.xpu │ ├── LICENSE │ ├── MANIFEST.in │ ├── Makefile │ ├── README.md │ ├── SECURITY.md │ ├── autosetup.sh │ ├── book.toml │ ├── csrc/ │ │ ├── custom_marlin/ │ │ │ ├── __init__.py │ │ │ ├── binding.cpp │ │ │ ├── gptq_marlin/ │ │ │ │ ├── gptq_marlin.cu │ │ │ │ ├── gptq_marlin.cuh │ │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ │ ├── gptq_marlin_repack.cu │ │ │ │ └── ops.h │ │ │ ├── setup.py │ │ │ ├── test_cuda_graph.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── format24.py │ │ │ ├── marlin_24_perms.py │ │ │ ├── marlin_perms.py │ │ │ ├── marlin_utils.py │ │ │ └── quant_utils.py │ │ └── ktransformers_ext/ │ │ ├── CMakeLists.txt │ │ ├── bench/ │ │ │ ├── bench_attention.py │ │ │ ├── bench_attention_torch.py │ │ │ ├── bench_linear.py │ │ │ ├── bench_linear_torch.py │ │ │ ├── bench_mlp.py │ │ │ ├── bench_mlp_torch.py │ │ │ ├── bench_moe.py │ │ │ ├── bench_moe_amx.py │ │ │ └── bench_moe_torch.py │ │ ├── cmake/ │ │ │ └── FindSIMD.cmake │ │ ├── cpu_backend/ │ │ │ ├── backend.cpp │ │ │ ├── backend.h │ │ │ ├── cpuinfer.h │ │ │ ├── shared_mem_buffer.cpp │ │ │ ├── shared_mem_buffer.h │ │ │ ├── task_queue.cpp │ │ │ ├── task_queue.h │ │ │ └── vendors/ │ │ │ ├── README.md │ │ │ ├── cuda.h │ │ │ ├── hip.h │ │ │ ├── musa.h │ │ │ └── vendor.h │ │ ├── cuda/ │ │ │ ├── binding.cpp │ │ │ ├── custom_gguf/ │ │ │ │ ├── dequant.cu │ │ │ │ └── ops.h │ │ │ ├── gptq_marlin/ │ │ │ │ ├── gptq_marlin.cu │ │ │ │ ├── gptq_marlin.cuh │ │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ │ └── ops.h │ │ │ ├── setup.py │ │ │ └── test_dequant.py │ │ ├── examples/ │ │ │ ├── test_attention.py │ │ │ ├── test_linear.py │ │ │ ├── test_mlp.py │ │ │ ├── test_moe.py │ │ │ ├── test_sft_amx_moe.py │ │ │ └── test_sft_moe.py │ │ ├── ext_bindings.cpp │ │ ├── operators/ │ │ │ ├── amx/ │ │ │ │ ├── debug_sft_moe.hpp │ │ │ │ ├── debug_tools_sft_moe.hpp │ │ │ │ ├── la/ │ │ │ │ │ ├── amx.hpp │ │ │ │ │ └── utils.hpp │ │ │ │ ├── moe.hpp │ │ │ │ └── sft_moe.hpp │ │ │ ├── kvcache/ │ │ │ │ ├── kvcache.h │ │ │ │ ├── kvcache_attn.cpp │ │ │ │ ├── kvcache_load_dump.cpp │ │ │ │ ├── kvcache_read_write.cpp │ │ │ │ └── kvcache_utils.cpp │ │ │ └── llamafile/ │ │ │ ├── conversion.h │ │ │ ├── linear.cpp │ │ │ ├── linear.h │ │ │ ├── mlp.cpp │ │ │ ├── mlp.h │ │ │ ├── moe.cpp │ │ │ ├── moe.h │ │ │ ├── sft_moe.cpp │ │ │ ├── sft_moe.h │ │ │ └── sft_moe_forward_cache.h │ │ └── vendors/ │ │ ├── cuda.h │ │ ├── hip.h │ │ ├── musa.h │ │ └── vendor.h │ ├── install-with-cache.sh │ ├── install.bat │ ├── install.sh │ ├── ktransformers/ │ │ ├── __init__.py │ │ ├── configs/ │ │ │ ├── config.yaml │ │ │ ├── log_config.ini │ │ │ └── model_config/ │ │ │ ├── config.json │ │ │ └── configuration_deepseek.py │ │ ├── ktransformers_ext/ │ │ │ ├── operators/ │ │ │ │ └── custom_marlin/ │ │ │ │ └── quantize/ │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── format_24.py │ │ │ │ ├── marlin_24_perms.py │ │ │ │ ├── marlin_perms.py │ │ │ │ ├── marlin_utils.py │ │ │ │ └── quant_utils.py │ │ │ └── triton/ │ │ │ └── fp8gemm.py │ │ ├── local_chat.py │ │ ├── local_chat.sh │ │ ├── lora_test_module.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── configuration_deepseek.py │ │ │ ├── configuration_deepseek_v3.py │ │ │ ├── configuration_llama.py │ │ │ ├── configuration_qwen2_moe.py │ │ │ ├── configuration_qwen3_moe.py │ │ │ ├── custom_cache.py │ │ │ ├── custom_modeling_deepseek_v2.py │ │ │ ├── custom_modeling_deepseek_v3.py │ │ │ ├── custom_modeling_qwen2_moe.py │ │ │ ├── custom_modeling_qwen3_moe.py │ │ │ ├── modeling_deepseek.py │ │ │ ├── modeling_deepseek_v3.py │ │ │ ├── modeling_llama.py │ │ │ ├── modeling_mixtral.py │ │ │ ├── modeling_qwen2_moe.py │ │ │ └── modeling_qwen3_moe.py │ │ ├── moe_test_module.py │ │ ├── moe_test_module_old.py │ │ ├── operators/ │ │ │ ├── RoPE.py │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── balance_serve_attention.py │ │ │ ├── base_operator.py │ │ │ ├── cpuinfer.py │ │ │ ├── dynamic_attention.py │ │ │ ├── experts.py │ │ │ ├── flashinfer_batch_prefill_wrapper.py │ │ │ ├── flashinfer_wrapper.py │ │ │ ├── gate.py │ │ │ ├── layernorm.py │ │ │ ├── linear.py │ │ │ ├── mlp.py │ │ │ ├── models.py │ │ │ ├── triton_attention.py │ │ │ └── triton_attention_prefill.py │ │ ├── optimize/ │ │ │ ├── optimize.py │ │ │ └── optimize_rules/ │ │ │ ├── DeepSeek-V2-Chat-multi-gpu-4.yaml │ │ │ ├── DeepSeek-V2-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V2-Chat-sft-amx.yaml │ │ │ ├── DeepSeek-V2-Chat.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-sft-amx-multi-gpu.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-sft-amx.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-sft.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat-use-adapter.yaml │ │ │ ├── DeepSeek-V2-Lite-Chat.yaml │ │ │ ├── DeepSeek-V3-Chat-amx.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml │ │ │ ├── DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-4.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-8.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu-marlin.yaml │ │ │ ├── DeepSeek-V3-Chat-multi-gpu.yaml │ │ │ ├── DeepSeek-V3-Chat-serve.yaml │ │ │ ├── DeepSeek-V3-Chat-sft-amx-multi-gpu-4.yaml │ │ │ ├── DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml │ │ │ ├── DeepSeek-V3-Chat-sft-amx.yaml │ │ │ ├── DeepSeek-V3-Chat.yaml │ │ │ ├── Internlm2_5-7b-Chat-1m.yaml │ │ │ ├── Mixtral.yaml │ │ │ ├── Moonlight-16B-A3B-serve.yaml │ │ │ ├── Moonlight-16B-A3B.yaml │ │ │ ├── Qwen2-57B-A14B-Instruct-multi-gpu.yaml │ │ │ ├── Qwen2-57B-A14B-Instruct.yaml │ │ │ ├── Qwen2-serve-amx.yaml │ │ │ ├── Qwen2-serve.yaml │ │ │ ├── Qwen3Moe-serve-amx.yaml │ │ │ ├── Qwen3Moe-serve.yaml │ │ │ ├── Qwen3Moe-sft-amx.yaml │ │ │ ├── rocm/ │ │ │ │ └── DeepSeek-V3-Chat.yaml │ │ │ └── xpu/ │ │ │ ├── DeepSeek-V2-Chat.yaml │ │ │ ├── DeepSeek-V3-Chat.yaml │ │ │ └── Qwen3Moe-Chat.yaml │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── api/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ollama/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── completions.py │ │ │ │ ├── openai/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assistants/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── assistants.py │ │ │ │ │ │ ├── messages.py │ │ │ │ │ │ ├── runs.py │ │ │ │ │ │ └── threads.py │ │ │ │ │ ├── endpoints/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── chat.py │ │ │ │ │ └── legacy/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── completions.py │ │ │ │ └── web/ │ │ │ │ ├── __init__.py │ │ │ │ └── system.py │ │ │ ├── args.py │ │ │ ├── backend/ │ │ │ │ ├── __init__.py │ │ │ │ ├── args.py │ │ │ │ ├── base.py │ │ │ │ ├── context_manager.py │ │ │ │ └── interfaces/ │ │ │ │ ├── __init__.py │ │ │ │ ├── balance_serve.py │ │ │ │ ├── exllamav2.py │ │ │ │ ├── ktransformers.py │ │ │ │ └── transformers.py │ │ │ ├── balance_serve/ │ │ │ │ ├── inference/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── config.py │ │ │ │ │ ├── distributed/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── communication_op.py │ │ │ │ │ │ ├── cuda_wrapper.py │ │ │ │ │ │ ├── custom_all_reduce.py │ │ │ │ │ │ ├── custom_all_reduce_utils.py │ │ │ │ │ │ ├── parallel_state.py │ │ │ │ │ │ ├── pynccl.py │ │ │ │ │ │ ├── pynccl_wrapper.py │ │ │ │ │ │ └── utils.py │ │ │ │ │ ├── forward_batch.py │ │ │ │ │ ├── model_runner.py │ │ │ │ │ ├── query_manager.py │ │ │ │ │ └── sampling/ │ │ │ │ │ ├── penaltylib/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── orchestrator.py │ │ │ │ │ │ └── penalizers/ │ │ │ │ │ │ ├── frequency_penalty.py │ │ │ │ │ │ ├── min_new_tokens.py │ │ │ │ │ │ ├── presence_penalty.py │ │ │ │ │ │ └── repetition_penalty.py │ │ │ │ │ └── sampler.py │ │ │ │ ├── sched_rpc.py │ │ │ │ └── settings.py │ │ │ ├── config/ │ │ │ │ ├── config.py │ │ │ │ ├── log.py │ │ │ │ └── singleton.py │ │ │ ├── crud/ │ │ │ │ ├── __init__.py │ │ │ │ └── assistants/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants.py │ │ │ │ ├── messages.py │ │ │ │ ├── runs.py │ │ │ │ └── threads.py │ │ │ ├── exceptions.py │ │ │ ├── main.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ └── assistants/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants.py │ │ │ │ ├── messages.py │ │ │ │ ├── run_steps.py │ │ │ │ ├── runs.py │ │ │ │ └── threads.py │ │ │ ├── schemas/ │ │ │ │ ├── __init__.py │ │ │ │ ├── assistants/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── assistants.py │ │ │ │ │ ├── messages.py │ │ │ │ │ ├── runs.py │ │ │ │ │ ├── streaming.py │ │ │ │ │ ├── threads.py │ │ │ │ │ └── tool.py │ │ │ │ ├── base.py │ │ │ │ ├── conversation.py │ │ │ │ ├── endpoints/ │ │ │ │ │ └── chat.py │ │ │ │ └── legacy/ │ │ │ │ ├── __init__.py │ │ │ │ └── completions.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── create_interface.py │ │ │ ├── multi_timer.py │ │ │ └── sql_utils.py │ │ ├── sft/ │ │ │ ├── __init__.py │ │ │ ├── flops_utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_profile.py │ │ │ │ └── lora_test_utils.py │ │ │ ├── lora.py │ │ │ ├── metrics.py │ │ │ ├── metrics_utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── constants.py │ │ │ │ ├── env.py │ │ │ │ ├── logging.py │ │ │ │ ├── misc.py │ │ │ │ ├── packages.py │ │ │ │ └── ploting.py │ │ │ ├── monkey_patch_torch_module.py │ │ │ ├── peft_utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── lora_layer.py │ │ │ │ ├── lora_model.py │ │ │ │ ├── mapping.py │ │ │ │ └── peft_model.py │ │ │ └── torchviz_test.py │ │ ├── tests/ │ │ │ ├── .gitignore │ │ │ ├── AIME_2024/ │ │ │ │ ├── eval_api.py │ │ │ │ ├── evaluation.py │ │ │ │ └── prompts.py │ │ │ ├── dequant_gpu.py │ │ │ ├── dequant_gpu_t.py │ │ │ ├── function_call_test.py │ │ │ ├── humaneval/ │ │ │ │ ├── eval_api.py │ │ │ │ ├── evaluation.py │ │ │ │ └── prompts.py │ │ │ ├── mmlu_pro_test.py │ │ │ ├── mmlu_test.py │ │ │ ├── mmlu_test_multi.py │ │ │ ├── score.py │ │ │ ├── test_client.py │ │ │ ├── test_pytorch_q8.py │ │ │ ├── test_speed.py │ │ │ └── triton_fp8gemm_test.py │ │ ├── util/ │ │ │ ├── cuda_graph_runner.py │ │ │ ├── custom_gguf.py │ │ │ ├── custom_loader.py │ │ │ ├── globals.py │ │ │ ├── grad_wrapper.py │ │ │ ├── inference_state.py │ │ │ ├── modeling_rope_utils.py │ │ │ ├── textstream.py │ │ │ ├── utils.py │ │ │ ├── vendors.py │ │ │ └── weight_loader.py │ │ └── website/ │ │ ├── .browserslistrc │ │ ├── .eslintrc.js │ │ ├── .gitignore │ │ ├── README.md │ │ ├── config.d.ts │ │ ├── jest.config.js │ │ ├── package.json │ │ ├── public/ │ │ │ ├── config.js │ │ │ ├── css/ │ │ │ │ └── reset.css │ │ │ └── index.html │ │ ├── src/ │ │ │ ├── App.vue │ │ │ ├── api/ │ │ │ │ ├── api-client.ts │ │ │ │ ├── assistant.ts │ │ │ │ ├── message.ts │ │ │ │ ├── run.ts │ │ │ │ └── thread.ts │ │ │ ├── assets/ │ │ │ │ ├── css/ │ │ │ │ │ └── mixins.styl │ │ │ │ └── iconfont/ │ │ │ │ ├── demo.css │ │ │ │ ├── demo_index.html │ │ │ │ ├── iconfont.css │ │ │ │ ├── iconfont.js │ │ │ │ └── iconfont.json │ │ │ ├── components/ │ │ │ │ └── chat/ │ │ │ │ └── index.vue │ │ │ ├── conf/ │ │ │ │ └── config.ts │ │ │ ├── locals/ │ │ │ │ ├── en.js │ │ │ │ ├── index.js │ │ │ │ └── zh.js │ │ │ ├── main.ts │ │ │ ├── router/ │ │ │ │ └── index.ts │ │ │ ├── shims-vue.d.ts │ │ │ ├── store/ │ │ │ │ └── index.ts │ │ │ ├── utils/ │ │ │ │ ├── copy.ts │ │ │ │ └── types.ts │ │ │ └── views/ │ │ │ └── home.vue │ │ ├── tests/ │ │ │ └── unit/ │ │ │ └── example.spec.ts │ │ ├── tsconfig.json │ │ └── vue.config.js │ ├── merge_tensors/ │ │ └── merge_safetensor_gguf.py │ ├── pyproject.toml │ ├── requirements-sft.txt │ ├── setup.py │ ├── test_adapter/ │ │ ├── data_transfer.py │ │ ├── infer_with_adapter.py │ │ ├── inspect_adapter.py │ │ ├── pred2metrics.py │ │ ├── test_grad.py │ │ └── time_test_lora_train.py │ └── withoutKT_PEFT.py ├── pyproject.toml ├── setup.py ├── third_party/ │ └── llamafile/ │ ├── README.md │ ├── bench.h │ ├── flags.cpp │ ├── flags.h │ ├── iqk_mul_mat.inc │ ├── iqk_mul_mat_amd_avx2.cpp │ ├── iqk_mul_mat_amd_zen4.cpp │ ├── iqk_mul_mat_arm.inc │ ├── iqk_mul_mat_arm82.cpp │ ├── macros.h │ ├── micros.h │ ├── numba.h │ ├── sgemm.cpp │ ├── sgemm.h │ ├── tinyblas_cpu.h │ ├── tinyblas_cpu_mixmul.inc │ ├── tinyblas_cpu_mixmul_amd_avx.cpp │ ├── tinyblas_cpu_mixmul_amd_avx2.cpp │ ├── tinyblas_cpu_mixmul_amd_avx512f.cpp │ ├── tinyblas_cpu_mixmul_amd_avxvnni.cpp │ ├── tinyblas_cpu_mixmul_amd_fma.cpp │ ├── tinyblas_cpu_mixmul_amd_zen4.cpp │ ├── tinyblas_cpu_mixmul_arm80.cpp │ ├── tinyblas_cpu_mixmul_arm82.cpp │ ├── tinyblas_cpu_sgemm.inc │ ├── tinyblas_cpu_sgemm_amd_avx.cpp │ ├── tinyblas_cpu_sgemm_amd_avx2.cpp │ ├── tinyblas_cpu_sgemm_amd_avx512f.cpp │ ├── tinyblas_cpu_sgemm_amd_avxvnni.cpp │ ├── tinyblas_cpu_sgemm_amd_fma.cpp │ ├── tinyblas_cpu_sgemm_amd_zen4.cpp │ ├── tinyblas_cpu_sgemm_arm80.cpp │ ├── tinyblas_cpu_sgemm_arm82.cpp │ └── tinyblas_cpu_unsupported.cpp └── version.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. [homepage]: https://www.contributor-covenant.org [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq [translations]: https://www.contributor-covenant.org/translations ================================================ FILE: .github/CONTRIBUTING.md ================================================ ## Before Commit! Your commit message must follow Conventional Commits (https://www.conventionalcommits.org/) and your code should be formatted. The Git hooks will do most of the work automatically: ### Tool Requirements You need a recent `clang-format` (>= 18). In a conda environment you can install: ```shell conda install -c conda-forge clang-format=18 ``` If you previously configured with an older version, remove the build directory and reconfigure: ```shell rm -rf kt-kernel/build ``` Install `black` for Python formatting: ```shell conda install black ``` ### Install hook: ```shell bash kt-kernel/scripts/install-git-hooks.sh #or just cmake the kt-kernel cmake -S kt-kernel -B kt-kernel/build ``` There are manual commands if you need format. ```shell cmake -S kt-kernel -B kt-kernel/build cmake --build kt-kernel/build --target format ``` ## Developer Note Formatting and commit message rules are enforced by Git hooks. After installing `clang-format` and `black`, just commit normally—the hooks will run formatting for you. > [!NOTE] > If formatting modifies files, the commit is aborted after staging those changes. Review them and run `git commit` again. Repeat until no further formatting changes appear. --- ### Conventional Commit Regex (Reference) The commit-msg hook enforces this pattern: ```text regex='^\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\](\([^\)]+\))?(!)?: .+' ``` Meaning (English): * `[type]` required — one of feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip * Optional scope: `(scope)` — any chars except `)` * Optional breaking change marker: `!` right after type or scope * Separator: `: ` (colon + space) * Subject: free text (at least one character) Examples: ```text [feat]: add adaptive batching [fix(parser)]: handle empty token list [docs]!: update API section for breaking rename ``` You can bypass locally (not recommended) with: ```shell git commit --no-verify ``` ## 提交前提醒 提交信息必须满足 Conventional Commits 规范 (https://www.conventionalcommits.org/),代码需要符合格式要求。Git 钩子已经集成了大部分工作: ### 软件要求 需要较新的 `clang-format` (>= 18),在 conda 环境中安装: ```shell conda install -c conda-forge clang-format=18 ``` 如果之前用老版本配置过,请删除构建目录重新配置: ```shell rm -rf kt-kernel/build ``` 安装 `black` 以进行 Python 文件格式化: ```shell conda install black ``` ### 安装钩子 ```shell bash kt-kernel/scripts/install-git-hooks.sh #or just cmake the kt-kernel cmake -S kt-kernel -B kt-kernel/build ``` 如果你需要手动格式化: ```shell cmake -S kt-kernel -B kt-kernel/build cmake --build kt-kernel/build --target format ``` ## 开发者说明 本仓库通过 Git hooks 自动执行代码格式化与提交信息规范检查。只需安装好 `clang-format` 与 `black` 后正常执行提交即可,钩子会自动格式化。 > [!NOTE] > 如果格式化修改了文件,钩子会终止提交并已暂存这些改动。请查看修改后再次执行 `git commit`,重复直到没有新的格式化变更。 ### 提交信息正则(参考) 钩子使用如下正则检查提交信息: ```text regex='^\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\](\([^\)]+\))?(!)?: .+' ``` 含义: * `[type]` 必填:feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip * 作用域可选:`(scope)`,不能包含右括号 * 可选的破坏性标记:`!` * 分隔符:冒号+空格 `: ` * 描述:至少一个字符 示例: ```text [feat]: 增加自适应 batch 功能 [fix(tokenizer)]: 修复空 token 列表处理 [docs]!: 更新接口文档(存在破坏性修改) ``` 跳过钩子(不推荐,仅紧急时): ```shell git commit --no-verify ``` ================================================ FILE: .github/ISSUE_TEMPLATE/-bug-.yaml ================================================ name: "\U0001F41B Bug / Help" description: Create a report to help us improve the ktransformers project labels: ["pending"] body: - type: markdown attributes: value: | Issues included in **[FAQs](https://github.com/kvcache-ai/ktransformers/issues/1608)** or those with **insufficient** information may be closed without a response. 已经包含在 **[常见问题](https://github.com/kvcache-ai/ktransformers/issues/1608)** 内或提供信息**不完整**的 issues 可能不会被回复。 - type: checkboxes id: reminder attributes: label: Reminder description: | Please ensure you have read the above rules carefully and searched the existing issues (including FAQs). 请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。 options: - label: I have read the above rules and searched the existing issues. required: true - type: textarea id: system-info validations: required: true attributes: label: System Info description: | Please share your system info with us. You can run the command **lscpu**, ** nvidia-smi ** etc. and copy-paste its output below. 请提供您的系统信息。您可以在命令行运行 **lscpu**, **nvidia-smi** 等命令,并将其输出复制到该文本框中。 placeholder: ktransformers version,sglang version, platform, python version, cpu info, GPU/NPU info ... - type: textarea id: reproduction validations: required: true attributes: label: Reproduction description: | Please provide entry arguments, error messages and stack traces that reproduces the problem. 请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。 value: | ```text Put your message here. ``` - type: textarea id: others validations: required: false attributes: label: Others ================================================ FILE: .github/ISSUE_TEMPLATE/-feature-.yaml ================================================ name: "\U0001F680 Feature request" description: Submit a request for a new feature labels: ["enhancement", "pending"] body: - type: markdown attributes: value: | Please do not create issues that are not related to new features under this category. 请勿在此分类下创建和新特性无关的 issues。 - type: checkboxes id: reminder attributes: label: Reminder description: | Please ensure you have read the above rules carefully and searched the existing issues. 请确保您已经认真阅读了上述规则并且搜索过现有的 issues。 options: - label: I have read the above rules and searched the existing issues. required: true - type: textarea id: description validations: required: true attributes: label: Description description: | A clear and concise description of the feature proposal. 请详细描述您希望加入的新功能特性。 - type: textarea id: contribution validations: required: false attributes: label: Pull Request description: | Have you already created the relevant PR and submitted the code? 您是否已经创建了相关 PR 并提交了代码? ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false contact_links: - name: 📚 FAQs | 常见问题 url: https://github.com/kvcache-ai/ktransformers/issues/1608 about: Reading in advance is recommended | 建议提前阅读 ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ # What does this PR do? Fixes # (issue) ## Before submitting - [ ] Did you read the [contributor guideline](https://github.com/kvcache-ai/ktransformers/blob/main/.github/CONTRIBUTING.md)? - [ ] Did you write any new necessary tests? ================================================ FILE: .github/SECURITY.md ================================================ # Reporting Security Issues To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/kvcache-ai/ktransformers/security/advisories/new) tab. We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance. Report security bugs in third-party modules to the person or team maintaining the module. ================================================ FILE: .github/workflows/book-ci.yml ================================================ name: Book-CI on: push: branches: - main # - server_support pull_request: branches: - main # - server_support jobs: test: name: test runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v4 - name: Install Rust run: | rustup set profile minimal rustup toolchain install stable rustup default stable - name: Setup mdBook uses: peaceiris/actions-mdbook@v2 with: mdbook-version: "latest" # - name: Run tests # run: mdbook test ================================================ FILE: .github/workflows/deploy.yml ================================================ name: Deploy on: push: branches: - main # - server_support pull_request: branches: - main # - server_support defaults: run: shell: bash permissions: contents: write jobs: deploy: runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v4 - name: Install Rust run: | rustup set profile minimal rustup toolchain install stable rustup default stable - name: Setup mdBook uses: peaceiris/actions-mdbook@v2 with: mdbook-version: "latest" - run: mdbook build # - name: Copy Assets # run: | # chmod +x ci/copy-assets.sh # ci/copy-assets.sh ${{ matrix.os }} - name: Deploy uses: peaceiris/actions-gh-pages@v3 # or || github.ref == 'refs/heads/server_support' if: ${{ github.ref == 'refs/heads/main' }} with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./book ================================================ FILE: .github/workflows/docker-image.yml ================================================ name: DockerHub CI on: release: types: [published] workflow_dispatch: inputs: push_to_dockerhub: description: 'Push image to DockerHub? (true/false)' required: true default: 'false' type: boolean cuda_version: description: 'CUDA version (e.g., 12.8.1)' required: false default: '12.8.1' type: string push_simplified_tag: description: 'Also push simplified tag? (true/false)' required: false default: 'true' type: boolean ubuntu_mirror: description: 'Use Tsinghua Ubuntu mirror? (0/1)' required: false default: '0' type: string # push: # branches: # - main env: DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers jobs: test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Run tests run: | if [ -f docker-compose.test.yml ]; then docker-compose --file docker-compose.test.yml build docker-compose --file docker-compose.test.yml run sut else docker build . --file docker/Dockerfile fi build-and-push: needs: test name: Build and Push Multi-Variant Docker Image runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v4 - name: Move Docker data directory run: | sudo systemctl stop docker sudo mkdir -p /mnt/docker sudo rsync -avz /var/lib/docker/ /mnt/docker sudo rm -rf /var/lib/docker sudo ln -s /mnt/docker /var/lib/docker sudo systemctl start docker - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Determine build parameters id: params run: | # Determine if we should push if [ "${{ github.event_name }}" = "release" ]; then echo "should_push=true" >> $GITHUB_OUTPUT echo "push_simplified=true" >> $GITHUB_OUTPUT elif [ "${{ github.event_name }}" = "workflow_dispatch" ]; then echo "should_push=${{ inputs.push_to_dockerhub }}" >> $GITHUB_OUTPUT echo "push_simplified=${{ inputs.push_simplified_tag }}" >> $GITHUB_OUTPUT else echo "should_push=false" >> $GITHUB_OUTPUT echo "push_simplified=false" >> $GITHUB_OUTPUT fi # Determine CUDA version if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ -n "${{ inputs.cuda_version }}" ]; then echo "cuda_version=${{ inputs.cuda_version }}" >> $GITHUB_OUTPUT else echo "cuda_version=12.8.1" >> $GITHUB_OUTPUT fi # Determine Ubuntu mirror setting if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ -n "${{ inputs.ubuntu_mirror }}" ]; then echo "ubuntu_mirror=${{ inputs.ubuntu_mirror }}" >> $GITHUB_OUTPUT else echo "ubuntu_mirror=0" >> $GITHUB_OUTPUT fi - name: Build and push Docker image run: | cd docker # Build command arguments BUILD_ARGS=( --cuda-version "${{ steps.params.outputs.cuda_version }}" --ubuntu-mirror "${{ steps.params.outputs.ubuntu_mirror }}" --repository "${{ env.DOCKERHUB_REPO }}" ) # Add simplified tag option if enabled if [ "${{ steps.params.outputs.push_simplified }}" = "true" ]; then BUILD_ARGS+=(--also-push-simplified) fi # Add HTTP proxy if available if [ -n "${{ secrets.HTTP_PROXY }}" ]; then BUILD_ARGS+=(--http-proxy "${{ secrets.HTTP_PROXY }}") fi # Add HTTPS proxy if available if [ -n "${{ secrets.HTTPS_PROXY }}" ]; then BUILD_ARGS+=(--https-proxy "${{ secrets.HTTPS_PROXY }}") fi # Dry run if not pushing if [ "${{ steps.params.outputs.should_push }}" != "true" ]; then BUILD_ARGS+=(--dry-run) fi # Execute build script ./push-to-dockerhub.sh "${BUILD_ARGS[@]}" - name: Display image information if: steps.params.outputs.should_push == 'true' run: | echo "::notice title=Docker Image::Image pushed successfully to ${{ env.DOCKERHUB_REPO }}" echo "Pull command: docker pull ${{ env.DOCKERHUB_REPO }}:v\$(VERSION)-cu\$(CUDA_SHORT)" ================================================ FILE: .github/workflows/kt-kernel-tests.yml ================================================ name: PR KT-Kernel Test on: pull_request: branches: - main - develop types: [synchronize, labeled] workflow_dispatch: concurrency: group: pr-kt-kernel-test-${{ github.ref }} cancel-in-progress: true jobs: # =============================================== check changes ==================================================== check-changes: runs-on: ubuntu-latest outputs: kt_kernel: ${{ steps.filter.outputs.kt_kernel }} steps: - name: Checkout code uses: actions/checkout@v4 - name: Fail if the PR does not have the 'run-ci' label if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'run-ci') run: | echo "This pull request does not have the 'run-ci' label. Failing the workflow." exit 1 - name: Fail if the PR is a draft if: github.event_name == 'pull_request' && github.event.pull_request.draft == true run: | echo "This pull request is a draft. Failing the workflow." exit 1 - name: Detect file changes id: filter uses: dorny/paths-filter@v3 with: filters: | kt_kernel: - "kt-kernel/**" - ".github/workflows/kt-kernel-tests.yml" # =============================================== KT-Kernel tests ==================================================== per-commit-kt-kernel-cpu: needs: [check-changes] if: always() && !failure() && !cancelled() && (needs.check-changes.outputs.kt_kernel == 'true' || github.event_name == 'workflow_dispatch') runs-on: kt-cpu continue-on-error: false steps: - name: Cleanup run: | sudo rm -rf $GITHUB_WORKSPACE/* || true - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive - name: Install KT-Kernel run: | cd kt-kernel bash install.sh build - name: Run KT-Kernel CPU tests timeout-minutes: 60 run: | cd kt-kernel/test python3 run_suite.py --hw cpu --suite default # =============================================== finish ==================================================== pr-test-kt-kernel-finish: needs: [check-changes, per-commit-kt-kernel-cpu] if: always() runs-on: ubuntu-latest steps: - name: Check all dependent job statuses run: | # Convert the 'needs' context to a JSON string json_needs='${{ toJson(needs) }}' # Get a list of all job names from the JSON keys job_names=$(echo "$json_needs" | jq -r 'keys_unsorted[]') for job in $job_names; do # For each job, extract its result result=$(echo "$json_needs" | jq -r --arg j "$job" '.[$j].result') # Print the job name and its result echo "$job: $result" # Check for failure or cancellation and exit if found if [[ "$result" == "failure" || "$result" == "cancelled" ]]; then echo "The above jobs failed." exit 1 fi done # If the loop completes, all jobs were successful echo "All jobs completed successfully" exit 0 ================================================ FILE: .github/workflows/release-fake-tag.yml ================================================ name: Release Fake Tag on: push: branches: - main paths: - "version.py" workflow_dispatch: permissions: contents: write jobs: publish: if: github.repository == 'kvcache-ai/ktransformers' runs-on: ubuntu-latest environment: 'prod' steps: - name: Checkout repository uses: actions/checkout@v4 with: token: ${{ secrets.GITHUB_TOKEN }} - name: Get version id: get_version run: | version=$(cat version.py | grep '__version__' | cut -d'"' -f2) echo "TAG=v$version" >> $GITHUB_OUTPUT - name: Create and push tag run: | git config user.name "ktransformers-bot" git config user.email "ktransformers-bot@users.noreply.github.com" git tag ${{ steps.get_version.outputs.TAG }} git push origin ${{ steps.get_version.outputs.TAG }} ================================================ FILE: .github/workflows/release-pypi.yml ================================================ name: Release to PyPI on: push: branches: - main paths: - "version.py" workflow_dispatch: inputs: test_pypi: description: 'Publish to TestPyPI instead of PyPI (for testing)' required: false default: 'false' type: choice options: - 'true' - 'false' permissions: contents: read jobs: # ── sglang-kt (must be on PyPI before users can pip install kt-kernel) ── build-and-publish-sglang-kt: name: Build & publish sglang-kt runs-on: [self-hosted, linux, x64] if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main' environment: prod permissions: id-token: write contents: read steps: - name: Checkout repository uses: actions/checkout@v4 with: submodules: recursive - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.12' - name: Install build tools run: | python -m pip install --upgrade pip pip install build wheel setuptools twine - name: Build sglang-kt wheel working-directory: third_party/sglang/python run: | KT_VERSION=$(python3 -c "exec(open('${{ github.workspace }}/version.py').read()); print(__version__)") export SGLANG_KT_VERSION="$KT_VERSION" echo "Building sglang-kt v${KT_VERSION} wheel..." python -m build --wheel -v ls dist/ | grep -q "sglang_kt" || (echo "ERROR: Wheel name does not contain sglang_kt" && exit 1) - name: Publish sglang-kt to PyPI if: github.event.inputs.test_pypi != 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | python -m twine upload --skip-existing --verbose third_party/sglang/python/dist/*.whl - name: Publish sglang-kt to TestPyPI (if requested) if: github.event.inputs.test_pypi == 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} run: | python -m twine upload --repository testpypi --skip-existing --verbose third_party/sglang/python/dist/*.whl # ── kt-kernel ── build-kt-kernel: name: Build kt-kernel (Python ${{ matrix.python-version }}) runs-on: [self-hosted, linux, x64, gpu] strategy: fail-fast: false matrix: python-version: ['3.11', '3.12'] steps: - name: Checkout repository uses: actions/checkout@v4 with: submodules: recursive - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Verify CUDA availability run: | nvidia-smi || (echo "ERROR: GPU not available" && exit 1) nvcc --version || (echo "ERROR: CUDA toolkit not found" && exit 1) - name: Install dependencies run: | apt-get update && apt-get install -y cmake libhwloc-dev pkg-config libnuma-dev python -m pip install --upgrade pip pip install build wheel setuptools torch --index-url https://download.pytorch.org/whl/cu118 - name: Build kt-kernel wheel working-directory: kt-kernel env: CPUINFER_BUILD_ALL_VARIANTS: '1' CPUINFER_USE_CUDA: '1' CPUINFER_CUDA_ARCHS: '80;86;89;90' CPUINFER_CUDA_STATIC_RUNTIME: '1' CPUINFER_BUILD_TYPE: 'Release' CPUINFER_PARALLEL: '4' CPUINFER_FORCE_REBUILD: '1' CUDA_HOME: '/usr/local/cuda-11.8' run: | echo "Building kt-kernel with:" echo " - CUDA support (SM 80, 86, 89, 90)" echo " - CPU multi-variant (AMX, AVX512, AVX2)" python -m build --wheel -v - name: Verify wheel working-directory: kt-kernel run: | echo "Generated wheel:" ls -lh dist/ # Install and test pip install dist/*.whl python -c "import kt_kernel; print(f'✓ Version: {kt_kernel.__version__}')" python -c "import kt_kernel; print(f'✓ CPU variant: {kt_kernel.__cpu_variant__}')" # Verify CUDA support python -c " from kt_kernel import kt_kernel_ext cpu_infer = kt_kernel_ext.CPUInfer(4) methods = dir(cpu_infer) has_cuda = 'submit_with_cuda_stream' in methods print(f'✓ CUDA support: {has_cuda}') " # Verify CPU multi-variant support echo "Checking CPU variants in wheel..." python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_" || echo "Warning: No variant .so files found" python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_amx.cpython" && echo "✓ AMX variant found" || echo "Note: AMX variant missing" python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_avx512" && echo "✓ AVX512 variants found" || echo "Note: AVX512 variants missing" python -m zipfile -l dist/*.whl | grep "_kt_kernel_ext_avx2.cpython" && echo "✓ AVX2 variant found" || echo "Note: AVX2 variant missing" # Verify static linking (should NOT depend on libcudart.so) rm -rf /tmp/check unzip -q dist/*.whl -d /tmp/check if ldd /tmp/check/kt_kernel/*.so 2>/dev/null | grep -q "libcudart.so"; then echo "ERROR: Dynamic cudart found, should be statically linked" exit 1 else echo "✓ CUDA runtime statically linked" fi - name: Repair wheel for manylinux working-directory: kt-kernel run: | pip install auditwheel patchelf mkdir -p wheelhouse for wheel in dist/*.whl; do auditwheel repair "$wheel" --plat manylinux_2_17_x86_64 --exclude libcuda.so.1 -w wheelhouse/ || \ cp "$wheel" wheelhouse/$(basename "$wheel" | sed 's/linux_x86_64/manylinux_2_17_x86_64/') done rm -f dist/*.whl && cp wheelhouse/*.whl dist/ - name: Upload artifact uses: actions/upload-artifact@v4 with: name: kt-kernel-wheels-py${{ matrix.python-version }} path: kt-kernel/dist/*.whl retention-days: 7 publish-pypi: name: Publish kt-kernel to PyPI needs: [build-and-publish-sglang-kt, build-kt-kernel] runs-on: [self-hosted, linux, x64] if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main' environment: prod permissions: id-token: write # For trusted publishing (OIDC) contents: read steps: - name: Download all wheel artifacts uses: actions/download-artifact@v4 with: path: artifacts/ - name: Organize wheels into dist/ run: | mkdir -p dist/ find artifacts/ -name "*.whl" -exec cp {} dist/ \; echo "Wheels to publish:" ls -lh dist/ - name: Get version from wheel id: get_version run: | # Extract version from first wheel filename wheel_name=$(ls dist/*.whl | head -1 | xargs basename) # Extract version (format: kt_kernel-X.Y.Z-...) version=$(echo "$wheel_name" | sed 's/kt_kernel-\([0-9.]*\)-.*/\1/') echo "VERSION=$version" >> $GITHUB_OUTPUT echo "Publishing version: $version" - name: Install twine run: | python -m pip install --upgrade pip pip install twine - name: Publish to TestPyPI (if requested) if: github.event.inputs.test_pypi == 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} run: | python -m twine upload \ --repository testpypi \ --skip-existing \ --verbose \ dist/*.whl - name: Publish to PyPI if: github.event.inputs.test_pypi != 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | python -m twine upload \ --skip-existing \ --verbose \ dist/*.whl - name: Create release summary run: | echo "## 🎉 kt-kernel v${{ steps.get_version.outputs.VERSION }} Published to PyPI" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "### Installation" >> $GITHUB_STEP_SUMMARY echo '```bash' >> $GITHUB_STEP_SUMMARY echo "pip install kt-kernel==${{ steps.get_version.outputs.VERSION }}" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "### Published Wheels" >> $GITHUB_STEP_SUMMARY echo "Total: $(ls -1 dist/*.whl | wc -l) wheels (Python 3.10, 3.11, 3.12)" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "### Features" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "**CPU Multi-Variant Support:**" >> $GITHUB_STEP_SUMMARY echo "- ✅ AMX (Intel Sapphire Rapids+, 2023)" >> $GITHUB_STEP_SUMMARY echo "- ✅ AVX512 Base/VNNI/VBMI/BF16 (Intel Skylake-X/Ice Lake/Cascade Lake, 2017+)" >> $GITHUB_STEP_SUMMARY echo "- ✅ AVX2 (Maximum compatibility, 2013+)" >> $GITHUB_STEP_SUMMARY echo "- 🔧 Runtime CPU detection: Automatically selects optimal variant" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "**CUDA Support:**" >> $GITHUB_STEP_SUMMARY echo "- ✅ SM 80 (Ampere: A100, RTX 3000 series)" >> $GITHUB_STEP_SUMMARY echo "- ✅ SM 86 (Ampere: RTX 3060-3090)" >> $GITHUB_STEP_SUMMARY echo "- ✅ SM 89 (Ada Lovelace: RTX 4000 series)" >> $GITHUB_STEP_SUMMARY echo "- ✅ SM 90 (Hopper: H100)" >> $GITHUB_STEP_SUMMARY echo "- 🔧 Static CUDA runtime: Compatible with CUDA 11.8+ and 12.x drivers" >> $GITHUB_STEP_SUMMARY echo "- 🔧 Works on CPU-only systems (CUDA features disabled gracefully)" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "**Requirements:**" >> $GITHUB_STEP_SUMMARY echo "- Python 3.10, 3.11, or 3.12" >> $GITHUB_STEP_SUMMARY echo "- Linux x86-64 (manylinux_2_17 compatible)" >> $GITHUB_STEP_SUMMARY echo "- For CUDA features: NVIDIA driver with CUDA 11.8+ or 12.x support" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "PyPI link: https://pypi.org/project/kt-kernel/${{ steps.get_version.outputs.VERSION }}/" >> $GITHUB_STEP_SUMMARY ================================================ FILE: .github/workflows/release-sglang-kt.yml ================================================ name: Release sglang-kt to PyPI on: push: branches: - main paths: - "third_party/sglang" - "version.py" workflow_dispatch: inputs: test_pypi: description: 'Publish to TestPyPI instead of PyPI (for testing)' required: false default: 'false' type: choice options: - 'true' - 'false' permissions: contents: read jobs: build-sglang-kt: name: Build sglang-kt wheel runs-on: [self-hosted, linux, x64] steps: - name: Checkout repository uses: actions/checkout@v4 with: submodules: recursive - name: Set up Python uses: actions/setup-python@v4 with: python-version: '3.12' - name: Install build tools run: | python -m pip install --upgrade pip pip install build wheel setuptools - name: Build sglang-kt wheel working-directory: third_party/sglang/python run: | # Read version from ktransformers version.py KT_VERSION=$(python3 -c "exec(open('${{ github.workspace }}/version.py').read()); print(__version__)") export SGLANG_KT_VERSION="$KT_VERSION" echo "Building sglang-kt v${KT_VERSION} wheel..." python -m build --wheel -v - name: Verify wheel working-directory: third_party/sglang/python run: | echo "Generated wheel:" ls -lh dist/ # Verify the wheel has the correct package name ls dist/ | grep -q "sglang_kt" || (echo "ERROR: Wheel name does not contain sglang_kt" && exit 1) echo "Wheel name verified." - name: Upload artifact uses: actions/upload-artifact@v4 with: name: sglang-kt-wheel path: third_party/sglang/python/dist/*.whl retention-days: 7 publish-pypi: name: Publish sglang-kt to PyPI needs: [build-sglang-kt] runs-on: [self-hosted, linux, x64] if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main' environment: prod permissions: id-token: write contents: read steps: - name: Download wheel artifact uses: actions/download-artifact@v4 with: name: sglang-kt-wheel path: dist/ - name: Display wheels run: | echo "Wheels to publish:" ls -lh dist/ - name: Install twine run: | python -m pip install --upgrade pip pip install twine - name: Publish to TestPyPI (if requested) if: github.event.inputs.test_pypi == 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} run: | python -m twine upload \ --repository testpypi \ --skip-existing \ --verbose \ dist/*.whl - name: Publish to PyPI if: github.event.inputs.test_pypi != 'true' env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | python -m twine upload \ --skip-existing \ --verbose \ dist/*.whl - name: Create release summary run: | echo "## sglang-kt Published to PyPI" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "### Installation" >> $GITHUB_STEP_SUMMARY echo '```bash' >> $GITHUB_STEP_SUMMARY echo "pip install sglang-kt" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY echo "This is the kvcache-ai fork of SGLang with kt-kernel support." >> $GITHUB_STEP_SUMMARY echo "PyPI link: https://pypi.org/project/sglang-kt/" >> $GITHUB_STEP_SUMMARY ================================================ FILE: .github/workflows/sync-sglang-submodule.yml ================================================ name: Sync sglang submodule on: schedule: # Run daily at 08:00 UTC - cron: "0 8 * * *" workflow_dispatch: permissions: contents: write pull-requests: write jobs: sync: name: Check for sglang-kt updates runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v4 with: submodules: true fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} - name: Update sglang submodule to latest main id: update run: | OLD_SHA=$(git -C third_party/sglang rev-parse HEAD) git submodule update --remote third_party/sglang NEW_SHA=$(git -C third_party/sglang rev-parse HEAD) echo "old_sha=$OLD_SHA" >> "$GITHUB_OUTPUT" echo "new_sha=$NEW_SHA" >> "$GITHUB_OUTPUT" if [ "$OLD_SHA" = "$NEW_SHA" ]; then echo "changed=false" >> "$GITHUB_OUTPUT" echo "sglang submodule is already up to date ($OLD_SHA)" else echo "changed=true" >> "$GITHUB_OUTPUT" # Collect commit log between old and new COMMITS=$(git -C third_party/sglang log --oneline "$OLD_SHA..$NEW_SHA" | head -20) echo "commits<> "$GITHUB_OUTPUT" echo "$COMMITS" >> "$GITHUB_OUTPUT" echo "EOF" >> "$GITHUB_OUTPUT" # sglang-kt version = ktransformers version (from version.py) VERSION=$(python3 -c "exec(open('version.py').read()); print(__version__)" 2>/dev/null || echo "unknown") echo "version=$VERSION" >> "$GITHUB_OUTPUT" echo "sglang submodule updated: $OLD_SHA -> $NEW_SHA (v$VERSION)" fi - name: Create pull request if: steps.update.outputs.changed == 'true' uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: | [build]: sync sglang submodule to ${{ steps.update.outputs.new_sha }} branch: auto/sync-sglang delete-branch: true title: "[build] Sync sglang-kt submodule (v${{ steps.update.outputs.version }})" body: | Automated sync of `third_party/sglang` submodule to latest `main`. **Old ref:** `${{ steps.update.outputs.old_sha }}` **New ref:** `${{ steps.update.outputs.new_sha }}` **sglang-kt version:** `${{ steps.update.outputs.version }}` ### Commits included ``` ${{ steps.update.outputs.commits }} ``` --- *This PR was created automatically by the [sync-sglang-submodule](${{ github.server_url }}/${{ github.repository }}/actions/workflows/sync-sglang-submodule.yml) workflow.* labels: | dependencies automated ================================================ FILE: .gitignore ================================================ __pycache__ build .vscode *.so *.cache server.db logs node_modules *.nsys-rep .vs/ *pycache* *build/ .DS_Store compile_commands.json *.egg-info* *dist/ ktransformers/server/local_store/ ktransformers/server_test1.db *.patch img/ tmp*.txt test.txt book ktransformers/tests/chat_txt.txt mmlu_result* ktransformers/ktransformers_ext/cuda_musa/ test_prompt.txt csrc/demo build* CMakeFiles/ kvc2/ sched/ *.png ================================================ FILE: .gitmodules ================================================ [submodule "third_party/llama.cpp"] path = third_party/llama.cpp url = https://github.com/ggerganov/llama.cpp.git [submodule "third_party/pybind11"] path = third_party/pybind11 url = https://github.com/pybind/pybind11.git [submodule "third_party/custom_flashinfer"] path = third_party/custom_flashinfer url = https://github.com/kvcache-ai/custom_flashinfer.git branch = fix-precision-mla-merge-main [submodule "third_party/sglang"] path = third_party/sglang url = https://github.com/kvcache-ai/sglang.git branch = main ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MAINTAINERS.md ================================================ # Maintainers This document lists the current maintainers and outlines their responsibilities. ## Current Maintainers | Name | GitHub | Role | Affiliation | Email | |------|--------|------|-------------|-------| | Weiyu Xie | [@ErvinXie](https://github.com/ErvinXie) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | xwy21@mails.tsinghua.edu.cn | | Hongtao Chen | [@chenht2022](https://github.com/chenht2022) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | cht22@mails.tsinghua.edu.cn | | Jianwei Dong | [@ovowei](https://github.com/ovowei) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | dongjw24@mails.tsinghua.edu.cn | | Ziwei Yuan | [@KMSorSMS](https://github.com/KMSorSMS) | Maintainer | [Approaching.AI](http://approaching.ai/) | 2022090910005@std.uestc.edu.cn | | Qingliang Ou | [@ouqingliang](https://github.com/ouqingliang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | oql@bupt.edu.cn | | Jiaqi Liao | [@SkqLiao](https://github.com/SkqLiao) | Maintainer | [Approaching.AI](http://approaching.ai/) | jiaqi.liao@bit.edu.cn | | Peilin Li | [@JimmyPeilinLi](https://github.com/JimmyPeilinLi) | Maintainer | [Approaching.AI](http://approaching.ai/) | lipeilin@mail.nwpu.edu.cn | | Xingxing Hao | [@mrhaoxx](https://github.com/mrhaoxx) | Maintainer | [Approaching.AI](http://approaching.ai/) | mr.haoxx@gmail.com | | Boxin Zhang | [@Atream](https://github.com/Atream) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | zhangbx24@mails.tsinghua.edu.cn | | Jingqi Tang | [@Azure-Tang](https://github.com/Azure-Tang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | tangjq25@mails.tsinghua.edu.cn | | Jiahao Wang | [@qiyuxinlin](https://github.com/qiyuxinlin) | Maintainer | [Approaching.AI](http://approaching.ai/) | 202241050020@hdu.edu.cn | ## Responsibilities Maintainers steward the project and keep it healthy for users and contributors. - Review and approve pull requests; ensure changes meet quality, testing, and documentation standards. - Triage issues, keep labels organized, and respond to questions in a timely manner. - Uphold the project’s code of conduct and report violations when needed. - Maintain CI reliability and address regressions promptly. - Oversee releases and keep compatibility with supported dependency versions. - Protect project security and follow the security disclosure process. ## Becoming a Maintainer We welcome contributors who show sustained, high-quality contributions and collaborative behavior. If you are interested, please contact an existing maintainer and share your recent contributions and areas of focus. ================================================ FILE: README.md ================================================

KTransformers

A Flexible Framework for Experiencing Cutting-edge LLM Inference/Fine-tune Optimizations

🎯 Overview | 🚀 kt-kernel | 🎓 kt-sft | 🔥 Citation | 🚀 Roadmap(2025Q4)
## 🎯 Overview KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel/) and [kt-sft](https://github.com/kvcache-ai/ktransformers/tree/main/kt-sft). ## 🔥 Updates * **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md)) * **Feb 12, 2026**: GLM-5 Day0 Support! ([Tutorial](./doc/en/kt-kernel/GLM-5-Tutorial.md)) * **Jan 27, 2026**: Kimi-K2.5 Day0 Support! ([Tutorial](./doc/en/Kimi-K2.5.md)) ([SFT Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.5.md)) * **Jan 22, 2026**: Support [CPU-GPU Expert Scheduling](./doc/en/kt-kernel/experts-sched-Tutorial.md), [Native BF16 and FP8 per channel Precision](./doc/en/kt-kernel/Native-Precision-Tutorial.md) and [AutoDL unified fine-tuning and inference](./doc/zh/【云端低价训推】%20KTransformers%2BAutoDL%2BLlamaFactory:随用随租的低成本超大模型「微调%2B推理」一体化流程.pdf) * **Dec 24, 2025**: Support Native MiniMax-M2.1 inference. ([Tutorial](./doc/en/kt-kernel/MiniMax-M2.1-Tutorial.md)) * **Dec 22, 2025**: Support RL-DPO fine-tuning with LLaMA-Factory. ([Tutorial](./doc/en/SFT/DPO_tutorial.md)) * **Dec 5, 2025**: Support Native Kimi-K2-Thinking inference ([Tutorial](./doc/en/kt-kernel/Kimi-K2-Thinking-Native.md)) * **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md)) * **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md)) * **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md)) * **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/)) * **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md)) * **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md)) * **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md)) * **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md)) * **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse. * **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)). * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md)) * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)). * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)). * **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. * **Aug 14, 2024**: Support llamfile as linear backend. * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 9, 2024**: Support windows native. --- ## 📦 Core Modules ### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels CPU-optimized kernel operations for heterogeneous LLM inference. image **Key Features:** - **AMX/AVX Acceleration**: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference - **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management - **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support - **Easy Integration**: Clean Python API for SGLang and other frameworks **Quick Start:** ```bash cd kt-kernel pip install . ``` **Use Cases:** - CPU-GPU hybrid inference for large MoE models - Integration with SGLang for production serving - Heterogeneous expert placement (hot experts on GPU, cold experts on CPU) **Performance Examples:** | Model | Hardware Configuration | Total Throughput | Output Throughput | |-------|------------------------|------------------|-------------------| | DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) | 👉 **[Full Documentation →](./kt-kernel/README.md)** --- ### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework KTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning. ![image-20251011010558909](https://raw.githubusercontent.com/kvcache-ai/ktransformers/main/doc/assets/image-20251011010558909.png) **Key Features:** - **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM - **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration - **LLaMA-Factory Integration**: Seamless integration with popular fine-tuning framework - **Production Ready**: Chat, batch inference, and metrics evaluation **Performance Examples:** | Model | Configuration | Throughput | GPU Memory | |-------|--------------|------------|------------| | DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) | | DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB | **Quick Start:** ```bash cd kt-sft # Install environment following kt-sft/README.md USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml ``` 👉 **[Full Documentation →](./kt-sft/README.md)** --- ## 🔥 Citation If you use KTransformers in your research, please cite our paper: ```bibtex @inproceedings{10.1145/3731569.3764843, title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models}, author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing}, booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles}, year = {2025} } ``` ## 👥 Contributors & Team Developed and maintained by: - [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University - [Approaching.AI](http://approaching.ai/) - [9#AISoft](https://github.com/aisoft9) - Community contributors We welcome contributions! Please feel free to submit issues and pull requests. ## 💬 Community & Support - **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues) - **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png) ## 📦 KT original Code The original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability. For the original documentation with full quick-start guides and examples, see: - [archive/README.md](./archive/README.md) (English) - [archive/README_ZH.md](./archive/README_ZH.md) (中文) ================================================ FILE: README_ZH.md ================================================

KTransformers

一个用于体验尖端 LLM 推理/微调优化的灵活框架

🎯 概览 | 🚀 kt-kernel | 🎓 kt-sft | 🔥 引用
## 🎯 概览 KTransformers 是一个专注于通过 CPU-GPU 异构计算实现大语言模型高效推理和微调的研究项目。该项目已发展为**两个核心模块**:[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。 ## 🔥 更新 * **2025 年 12 月 5 日**:支持原生 Kimi-K2-Thinking 推理([教程](./doc/en/Kimi-K2-Thinking-Native.md)) * **2025 年 11 月 6 日**:支持 Kimi-K2-Thinking 推理([教程](./doc/en/Kimi-K2-Thinking.md))和微调([教程](./doc/en/SFT_Installation_Guide_KimiK2.md)) * **2025 年 11 月 4 日**:KTransformers 微调 × LLaMA-Factory 集成([教程](./doc/en/KTransformers-Fine-Tuning_User-Guide.md)) * **2025 年 10 月 27 日**:支持昇腾 NPU([教程](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md)) * **2025 年 10 月 10 日**:集成到 SGLang([路线图](https://github.com/sgl-project/sglang/issues/11425),[博客](https://lmsys.org/blog/2025-10-22-KTransformers/)) * **2025 年 9 月 11 日**:支持 Qwen3-Next([教程](./doc/en/Qwen3-Next.md)) * **2025 年 9 月 5 日**:支持 Kimi-K2-0905([教程](./doc/en/Kimi-K2.md)) * **2025 年 7 月 26 日**:支持 SmallThinker 和 GLM4-MoE([教程](./doc/en/SmallThinker_and_Glm4moe.md)) * **2025 年 7 月 11 日**:支持 Kimi-K2([教程](./doc/en/Kimi-K2.md)) * **2025 年 6 月 30 日**:支持 3 层(GPU-CPU-磁盘)[前缀缓存](./doc/en/prefix_cache.md)复用 * **2025 年 5 月 14 日**:支持 Intel Arc GPU([教程](./doc/en/xpu.md)) * **2025 年 4 月 29 日**:支持 AMX-Int8、AMX-BF16 和 Qwen3MoE([教程](./doc/en/AMX.md)) * **2025 年 4 月 9 日**:实验性支持 LLaMA 4 模型([教程](./doc/en/llama4.md)) * **2025 年 4 月 2 日**:支持多并发([教程](./doc/en/balance-serve.md)) * **2025 年 3 月 15 日**:支持 AMD GPU 上的 ROCm([教程](./doc/en/ROCm.md)) * **2025 年 3 月 5 日**:支持 unsloth 1.58/2.51 位权重和 [IQ1_S/FP8 混合](./doc/en/fp8_kernel.md)权重。在 24GB VRAM 中支持 DeepSeek-V3 和 R1 的 139K [更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) * **2025 年 2 月 25 日**:为 DeepSeek-V3 和 R1 支持 [FP8 GPU 内核](./doc/en/fp8_kernel.md);[更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context) * **2025 年 2 月 15 日**:更长上下文(24GB VRAM 从 4K 到 8K)& 速度稍快(+15%,最高 16 Tokens/s),更新[文档](./doc/en/DeepseekR1_V3_tutorial.md)和[在线手册](https://kvcache-ai.github.io/ktransformers/) * **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单 GPU(24GB VRAM)/多 GPU 和 382GB DRAM 上运行,速度提升高达 3~28 倍。详细案例展示和复现教程请参见[这里](./doc/en/DeepseekR1_V3_tutorial.md) * **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21GB 降低到 11GB * **2024 年 8 月 15 日**:更新了关于注入和多 GPU 的详细[教程](doc/en/injection_tutorial.md) * **2024 年 8 月 14 日**:支持 llamfile 作为线性后端 * **2024 年 8 月 12 日**:支持多 GPU;支持新模型:mixtral 8\*7B 和 8\*22B;支持 GPU 上的 q2k、q3k、q5k 去量化 * **2024 年 8 月 9 日**:支持 Windows 原生环境 --- ## 📦 核心模块 ### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核 用于异构 LLM 推理的 CPU 优化内核操作。 ![image-20251011010558909](./doc/assets/heterogeneous_computing.png) **主要特性:** - **AMX/AVX 加速**:Intel AMX 和 AVX512/AVX2 优化的内核,用于 INT4/INT8 量化推理 - **MoE 优化**:高效的专家混合推理,具有 NUMA 感知内存管理 - **量化支持**:CPU 端 INT4/INT8 量化权重,GPU 端 GPTQ 支持 - **易于集成**:为 SGLang 和其他框架提供简洁的 Python API **快速开始:** ```bash cd kt-kernel pip install . ``` **使用场景:** - 大型 MoE 模型的 CPU-GPU 混合推理 - 与 SGLang 集成用于生产服务 - 异构专家放置(热专家在 GPU 上,冷专家在 CPU 上) **性能示例:** | 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 | |-------|------------------------|------------------|-------------------| | DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s(8 路并发)| 👉 **[完整文档 →](./kt-kernel/README.md)** --- ### 🎓 [kt-sft](./kt-sft/) - 微调框架 KTransformers × LLaMA-Factory 集成,用于超大型 MoE 模型微调。 ![image-20251011010558909](./doc/assets/image-20251011010558909.png) **主要特性:** - **资源高效**:仅需 **70GB GPU 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3 - **LoRA 支持**:完整的 LoRA 微调,带有异构加速 - **LLaMA-Factory 集成**:与流行的微调框架无缝集成 - **生产就绪**:聊天、批量推理和指标评估 **性能示例:** | 模型 | 配置 | 吞吐量 | GPU 显存 | |-------|--------------|------------|--------------| | DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB(多 GPU)| | DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB | **快速开始:** ```bash cd kt-sft # 按照 kt-sft/README.md 安装环境 USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml ``` 👉 **[完整文档 →](./kt-sft/README.md)** --- ## 🔥 引用 如果您在研究中使用了 KTransformers,请引用我们的论文: ```bibtex @inproceedings{10.1145/3731569.3764843, title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models}, author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing}, booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles}, year = {2025} } ``` ## 👥 贡献者与团队 由以下团队开发和维护: - 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/) - [Approaching.AI](http://approaching.ai/) - 社区贡献者 我们欢迎贡献!请随时提交问题和拉取请求。 ## 💬 社区与支持 - **GitHub Issues**:[报告问题或请求功能](https://github.com/kvcache-ai/ktransformers/issues) - **微信群**:请参见 [archive/WeChatGroup.png](./archive/WeChatGroup.png) ## 📦 KT原仓库 原始的集成 KTransformers 框架已归档到 [`archive/`](./archive/) 目录以供参考。该项目现在专注于上述两个核心模块,以获得更好的模块化和可维护性。 有关原始文档以及完整的快速入门指南和示例,请参见: - [archive/README.md](./archive/README.md)(英文) - [archive/README_ZH.md](./archive/README_ZH.md)(中文) ================================================ FILE: archive/.devcontainer/Dockerfile ================================================ FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server WORKDIR /workspace ENV CUDA_HOME /usr/local/cuda RUN <> ~/.bashrc && \ echo "conda activate ktransformers" >> ~/.bashrc WORKDIR /ktransformers/ CMD ["bash"] ================================================ FILE: archive/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: archive/MANIFEST.in ================================================ graft third_party graft ktransformers graft local_chat.py graft csrc include LICENSE README.md prune ktransformers/website prune ktransformers/logs prune ktransformers.egg-info prune third_party/llama.cpp/models graft ktransformers/website/dist global-exclude __pycache__ include KTransformersOps.*.so include cpuinfer_ext.*.so ================================================ FILE: archive/Makefile ================================================ flake_find: cd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' - format: @cd ktransformers && black . @black setup.py dev_install: # clear build dirs rm -rf build rm -rf *.egg-info rm -rf ktransformers/ktransformers_ext/build rm -rf ktransformers/ktransformers_ext/cuda/build rm -rf ktransformers/ktransformers_ext/cuda/dist rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info # install ktransformers echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt echo "Installing ktransformers" KTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation echo "Installation completed successfully" clean: rm -rf build rm -rf *.egg-info rm -rf ktransformers/ktransformers_ext/build rm -rf ktransformers/ktransformers_ext/cuda/build rm -rf ktransformers/ktransformers_ext/cuda/dist rm -rf ktransformers/ktransformers_ext/cuda/*.egg-info install_numa: USE_NUMA=1 make dev_install install_no_numa: env -u USE_NUMA make dev_install ================================================ FILE: archive/README.md ================================================

KTransformers

High-Performance CPU-GPU Hybrid Inference for Large Language Models

## 🎯 Overview KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](./kt-kernel/) and [kt-sft](./kt-sft/). ## 🔥 Updates * **Nov 6, 2025**: Support Kimi-K2-Thinking inference and fine-tune * **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration * **Oct 27, 2025**: Support Ascend NPU * **Oct 10, 2025**: Integrating into SGLang ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/)) * **Sept 11, 2025**: Support Qwen3-Next * **Sept 05, 2025**: Support Kimi-K2-0905 * **July 26, 2025**: Support SmallThinker and GLM4-MoE * **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) prefix cache reuse * **May 14, 2025**: Support Intel Arc GPU * **Apr 29, 2025**: Support AMX-Int8、AMX-BF16 and Qwen3MoE * **Apr 9, 2025**: Experimental support for LLaMA 4 models * **Apr 2, 2025**: Support Multi-concurrency * **Mar 15, 2025**: Support ROCm on AMD GPU * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and IQ1_S/FP8 hybrid weights; 139K longer context for DeepSeek-V3/R1 * **Feb 25, 2025**: Support FP8 GPU kernel for DeepSeek-V3 and R1 * **Feb 10, 2025**: Support Deepseek-R1 and V3, up to 3~28x speedup --- ## 📦 Core Modules ### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels CPU-optimized kernel operations for heterogeneous LLM inference. ![image-20251011010558909](./doc/assets/heterogeneous_computing.png) **Key Features:** - **AMX/AVX Acceleration**: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference - **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management - **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support - **Easy Integration**: Clean Python API for SGLang and other frameworks **Quick Start:** ```bash cd kt-kernel pip install . ``` **Use Cases:** - CPU-GPU hybrid inference for large MoE models - Integration with SGLang for production serving - Heterogeneous expert placement (hot experts on GPU, cold experts on CPU) **Performance Examples:** | Model | Hardware Configuration | Total Throughput | Output Throughput | |-------|------------------------|------------------|-------------------| | DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) | 👉 **[Full Documentation →](./kt-kernel/README.md)** --- ### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework KTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning. ![image-20251011010558909](./doc/assets/image-20251011010558909.png) **Key Features:** - **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM - **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration - **LLaMA-Factory Integration**: Seamless integration with popular fine-tuning framework - **Production Ready**: Chat, batch inference, and metrics evaluation **Performance Examples:** | Model | Configuration | Throughput | GPU Memory | |-------|--------------|------------|------------| | DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) | | DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB | **Quick Start:** ```bash cd kt-sft # Install environment following kt-sft/README.md USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml ``` 👉 **[Full Documentation →](./kt-sft/README.md)** --- ## 🔥 Citation If you use KTransformers in your research, please cite our paper: ```bibtex @inproceedings{10.1145/3731569.3764843, title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models}, author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing}, booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles}, year = {2025} } ``` ## 👥 Contributors & Team Developed and maintained by: - [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University - [Approaching.AI](http://approaching.ai/) - Community contributors We welcome contributions! Please feel free to submit issues and pull requests. ## 💬 Community & Support - **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues) - **GitHub Discussions**: [Ask questions and share ideas](https://github.com/kvcache-ai/ktransformers/discussions) - **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png) ## 📦 Legacy Code The original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability. For the original documentation with full quick-start guides and examples, see: - [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English) - [archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文) ================================================ FILE: archive/README_LEGACY.md ================================================

KTransformers

A Flexible Framework for Experiencing Cutting-edge LLM Inference Optimizations

🌟 Show Cases | 🚀 Quick Start | 📃 Tutorial | 🔥 Citation | 💬 Discussion | 🙋 FAQ

🎉 Introduction

KTransformers, pronounced as Quick Transformers, is designed to enhance your 🤗 Transformers experience with advanced kernel optimizations and placement/parallelism strategies.

KTransformers is a flexible, Python-centric framework designed with extensibility at its core. By implementing and injecting an optimized module with a single line of code, users gain access to a Transformers-compatible interface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified ChatGPT-like web UI.

Our vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.

🔥 Updates

* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md)) * **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md)) * **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md)) * **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425)) * **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md)) * **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md)) * **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md)) * **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md)) * **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse. * **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)). * **Apr 29, 2025**: Support AMX-Int8、 AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md)) https://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2 * **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)). * **Apr 2, 2025**: Support Multi-concurrency. ([Tutorial](./doc/en/balance-serve.md)). https://github.com/user-attachments/assets/faa3bda2-928b-45a7-b44f-21e12ec84b8a * **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)). * **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM. * **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context). * **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 Tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/). * **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup. For detailed show case and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md). * **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G. * **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU. * **Aug 14, 2024**: Support llamfile as linear backend. * **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu. * **Aug 9, 2024**: Support windows native.

🌟 Show Cases

GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM

https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)). - Prefill Speed (tokens/s): - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only) - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**. - Decode Speed (tokens/s): - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only) - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**. - Upcoming Open Source Release: - AMX optimizations and selective expert activation will be open-sourced in V0.3. - Currently available only in preview binary distribution, which can be downloaded [here](./doc/en/DeepseekR1_V3_tutorial.md). - **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).

DeepSeek-Coder-V2 Score

- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin). - **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.

https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c

More advanced features will coming soon, so stay tuned!

🚀 Quick Start

Getting started with KTransformers is simple! Follow the steps below to set up and start using it. we have already supported vendors: - Metax - Sanechips (ZhuFeng V1.0) - Intel - Ascend - Kunpeng - AMD ### 📥 Installation To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).

📃 Brief Injection Tutorial

At the heart of KTransformers is a user-friendly, template-based injection framework. This allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects.

Inject-Struction

Given that vLLM already serves as a great framework for large-scale deployment optimizations, KTransformers is particularly focused on local deployments that are constrained by limited resources. We pay special attention to heterogeneous computing opportunities, such as GPU/CPU offloading of quantized models. For example, we support the efficient Llamafile and Marlin kernels for CPU and GPU, respectively. More details can be found here.

Example Usage

To utilize the provided kernels, users only need to create a YAML-based injection template and add the call to `optimize_and_load_gguf` before using the Transformers model. ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` In this example, the AutoModel is first initialized on the meta device to avoid occupying any memory resources. Then, `optimize_and_load_gguf` iterates through all sub-modules of the model, matches rules specified in your YAML rule file, and replaces them with advanced modules as specified. After injection, the original `generate` interface is available, but we also provide a compatible `prefill_and_generate` method, which enables further optimizations like CUDAGraph to improve generation speed.

How to custom your model

A detailed tutorial of the injection and multi-GPU using DeepSeek-V2 as an example is given [here](doc/en/injection_tutorial.md). Below is an example of a YAML template for replacing all original Linear modules with Marlin, an advanced 4-bit quantization kernel. ```yaml - match: name: "^model\\.layers\\..*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types device: "cpu" # which devices to load this module when initializing kwargs: generate_device: "cuda" generate_linear_type: "QuantizedLinearMarlin" ``` Each rule in the YAML file has two parts: `match` and `replace`. The `match` part specifies which module should be replaced, and the `replace` part specifies the module to be injected into the model along with the initialization keywords. You can find example rule templates for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models, in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory. These templates are used to power the `local_chat.py` demo. If you are interested in our design principles and the implementation of the injection framework, please refer to the [design document](doc/en/deepseek-v2-injection.md).

🔥 Citation

If you use KTransformers for your research, please cite our [paper](https://madsys.cs.tsinghua.edu.cn/publication/ktransformers-unleashing-the-full-potential-of-cpu/gpu-hybrid-inference-for-moe-models/): ``` @inproceedings{10.1145/3731569.3764843, title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models}, author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing}, booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles}, year = {2025} } ```

Acknowledgment and Contributors

The development of KTransformers is based on the flexible and versatile framework provided by Transformers. We also benefit from advanced kernels such as GGUF/GGML, Llamafile, Marlin, sglang and flashinfer. We are planning to contribute back to the community by upstreaming our modifications. KTransformers is actively maintained and developed by contributors from the MADSys group at Tsinghua University and members from Approaching.AI. We welcome new contributors to join us in making KTransformers faster and easier to use.

Discussion

If you have any questions, feel free to open an issue. Alternatively, you can join our WeChat group for further discussion. QR Code: [WeChat Group](WeChatGroup.png)

🙋 FAQ

Some common questions are answered in the [FAQ](doc/en/FAQ.md). ================================================ FILE: archive/README_ZH.md ================================================

KTransformers

高性能 CPU-GPU 异构大语言模型推理

## 🎯 项目概述 KTransformers 是一个专注于大语言模型高效推理和微调的研究项目,通过 CPU-GPU 异构计算实现资源受限环境下的模型部署。项目已演进为**两个核心模块**:[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。 ## 🔥 更新 * **2025年11月6日**:支持 Kimi-K2-Thinking 推理和微调 * **2025年11月4日**:KTransformers 微调 × LLaMA-Factory 集成 * **2025年10月27日**:支持 Ascend NPU * **2025年10月10日**:集成到 SGLang ([路线图](https://github.com/sgl-project/sglang/issues/11425), [博客](https://lmsys.org/blog/2025-10-22-KTransformers/)) * **2025年9月11日**:支持 Qwen3-Next * **2025年9月5日**:支持 Kimi-K2-0905 * **2025年7月26日**:支持 SmallThinker 和 GLM4-MoE * **2025年6月30日**:支持 3层(GPU-CPU-磁盘)前缀缓存复用 * **2025年5月14日**:支持 Intel Arc GPU * **2025年4月29日**:支持 AMX-Int8、AMX-BF16 和 Qwen3MoE * **2025年4月9日**:实验性支持 LLaMA 4 模型 * **2025年4月2日**:支持多并发 * **2025年3月15日**:支持 AMD GPU 的 ROCm * **2025年3月5日**:支持 unsloth 1.58/2.51 bits 权重和 IQ1_S/FP8 混合权重;DeepSeek-V3/R1 支持 139K 长上下文 * **2025年2月25日**:支持 DeepSeek-V3 和 R1 的 FP8 GPU 内核 * **2025年2月10日**:支持 Deepseek-R1 和 V3,速度提升最高达 3~28 倍 --- ## 📦 核心模块 ### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核 面向异构 LLM 推理的 CPU 优化内核操作库。 ![image-20251011010558909](./doc/assets/heterogeneous_computing.png) **核心特性:** - **AMX/AVX 加速**:Intel AMX 和 AVX512/AVX2 优化内核,支持 INT4/INT8 量化推理 - **MoE 优化**:高效的专家混合推理,支持 NUMA 感知内存管理 - **量化支持**:CPU 端 INT4/INT8 量化权重,GPU 端 GPTQ 支持 - **易于集成**:简洁的 Python API,可集成到 SGLang 等框架 **快速开始:** ```bash cd kt-kernel pip install . ``` **应用场景:** - 大型 MoE 模型的 CPU-GPU 混合推理 - 与 SGLang 集成用于生产服务 - 异构专家放置(热门专家在 GPU,冷门专家在 CPU) **性能示例:** | 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 | |------|---------|---------|-----------| | DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s(8路并发)| 👉 **[完整文档 →](./kt-kernel/README.md)** --- ### 🎓 [kt-sft](./kt-sft/) - 微调框架 KTransformers × LLaMA-Factory 集成,支持超大 MoE 模型微调。 ![image-20251011010558909](./doc/assets/image-20251011010558909.png) **核心特性:** - **资源高效**:仅需 **70GB 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3 - **LoRA 支持**:完整的 LoRA 微调与异构加速 - **LLaMA-Factory 集成**:与流行微调框架无缝集成 - **生产就绪**:支持对话、批量推理和指标评估 **性能示例:** | 模型 | 配置 | 吞吐量 | GPU 显存 | |------|------|--------|----------| | DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (多卡) | | DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB | **快速开始:** ```bash cd kt-sft # 按照 kt-sft/README.md 安装环境 USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml ``` 👉 **[完整文档 →](./kt-sft/README.md)** --- ## 🔥 引用 如果您在研究中使用了 KTransformers,请引用我们的论文: ```bibtex @inproceedings{10.1145/3731569.3764843, title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models}, author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing}, booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles}, year = {2025} } ``` ## 👥 贡献者与团队 由以下团队开发和维护: - 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/) - [Approaching.AI](http://approaching.ai/) - 社区贡献者 我们欢迎贡献!请随时提交 issues 和 pull requests。 ## 💬 社区与支持 - **GitHub Issues**:[报告 bug 或请求功能](https://github.com/kvcache-ai/ktransformers/issues) - **GitHub Discussions**:[提问和分享想法](https://github.com/kvcache-ai/ktransformers/discussions) - **微信群**:查看 [archive/WeChatGroup.png](./archive/WeChatGroup.png) ## 📦 历史代码 原完整的 KTransformers 框架代码已归档至 [`archive/`](./archive/) 目录供参考。项目现专注于上述两个核心模块,以实现更好的模块化和可维护性。 关于原始完整文档(包含快速入门指南和示例),请查看: - [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English) - [archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文) ================================================ FILE: archive/README_ZH_LEGACY.md ================================================

KTransformers

一个用于体验尖端 LLM 推理优化的灵活框架

🌟 案例展示 | 🚀 快速入门 | 📃 教程 | 💬 讨论 | 🙋 常见问题

🎉 介绍

KTransformers(发音为 Quick Transformers)旨在通过先进的内核优化和放置/并行策略来增强您对 🤗 [Transformers](https://github.com/huggingface/transformers) 的体验。

KTransformers 是一个以 Python 为中心的灵活框架,其核心是可扩展性。通过用一行代码实现并注入优化模块,用户可以获得与 Transformers 兼容的接口、符合 OpenAI 和 Ollama 的 RESTful API,甚至是一个简化的类似 ChatGPT 的 Web 界面。

我们对 KTransformers 的愿景是成为一个用于实验创新 LLM 推理优化的灵活平台。如果您需要任何其他功能,请告诉我们。

🔥 更新

* **2025 年 2 月 15 日**:为DeepSeek-V3/R1支持[FP8 GPU内核](./doc/en/fp8_kernel.md); 支持更长的上下文([教程](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)). * **2025 年 2 月 15 日**:长上下文(从4K到8K,24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s),文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。 * **2025 年 2 月 10 日**:支持 Deepseek-R1 和 V3 在单个(24GB VRAM)/多 GPU 和 382G DRAM 上运行,速度提升高达 3~28 倍。详细教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)。 * **2024 年 8 月 28 日**:支持 InternLM2.5-7B-Chat-1M 模型下的 1M 上下文,使用 24GB 的 VRAM 和 150GB 的 DRAM。详细教程请参见 [这里](./doc/en/long_context_tutorial.md)。 * **2024 年 8 月 28 日**:将 DeepseekV2 所需的 VRAM 从 21G 降低到 11G。 * **2024 年 8 月 15 日**:更新了详细的 [教程](doc/en/injection_tutorial.md),介绍注入和多 GPU 的使用。 * **2024 年 8 月 14 日**:支持 llamfile 作为线性后端。 * **2024 年 8 月 12 日**:支持多 GPU;支持新模型:mixtral 8\*7B 和 8\*22B;支持 q2k、q3k、q5k 在 GPU 上的去量化。 * **2024 年 8 月 9 日**:支持 Windows。

🌟 案例展示

在仅 24GB VRAM 的桌面上运行 GPT-4/o1 级别的本地 VSCode Copilot

https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

- **[NEW!!!] 本地 671B DeepSeek-Coder-V3/R1**:使用其 Q4_K_M 版本,仅需 14GB VRAM 和 382GB DRAM 即可运行(教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md))。 - 预填充速度(tokens/s): - KTransformers:54.21(32 核)→ 74.362(双插槽,2×32 核)→ 255.26(优化的 AMX 基 MoE 内核,仅 V0.3)→ 286.55(选择性使用 6 个专家,仅 V0.3) - 与 llama.cpp 在 2×32 核下相比,达到 **27.79× 速度提升**。 - 解码速度(tokens/s): - KTransformers:8.73(32 核)→ 11.26(双插槽,2×32 核)→ 13.69(选择性使用 6 个专家,仅 V0.3) - 与 llama.cpp 在 2×32 核下相比,达到 **3.03× 速度提升**。 - 即将开源发布: - AMX 优化和选择性专家激活将在 V0.3 中开源。 - 目前仅在预览二进制分发中可用,可从 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 下载。 - **本地 236B DeepSeek-Coder-V2**:使用其 Q4_K_M 版本,仅需 21GB VRAM 和 136GB DRAM 即可运行,甚至在 [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench) 中得分超过 GPT4-0613。

DeepSeek-Coder-V2 Score

- **更快的速度**:通过 MoE 卸载和注入来自 [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) 和 [Marlin](https://github.com/IST-DASLab/marlin) 的高级内核,实现了 2K 提示预填充 126 tokens/s 和生成 13.6 tokens/s 的速度。 - **VSCode 集成**:封装成符合 OpenAI 和 Ollama 的 API,可无缝集成到 [Tabby](https://github.com/TabbyML/tabby) 和其他前端的后端。

https://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c

更多高级功能即将推出,敬请期待!

🚀 快速入门

KTransformers 的入门非常简单!请参考我们的[安装指南]((https://kvcache-ai.github.io/ktransformers/))进行安装。

📃 简要注入教程

KTransformers 的核心是一个用户友好的、基于模板的注入框架。这使得研究人员可以轻松地将原始 torch 模块替换为优化的变体。它还简化了多种优化的组合过程,允许探索它们的协同效应。

Inject-Struction

鉴于 vLLM 已经是一个用于大规模部署优化的优秀框架,KTransformers 特别关注受资源限制的本地部署。我们特别关注异构计算时机,例如量化模型的 GPU/CPU 卸载。例如,我们支持高效的 LlamafileMarlin 内核,分别用于 CPU 和 GPU。 更多详细信息可以在 这里找到。

示例用法

要使用提供的内核,用户只需创建一个基于 YAML 的注入模板,并在使用 Transformers 模型之前添加对 `optimize_and_load_gguf` 的调用。 ```python with torch.device("meta"): model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) ... generated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000) ``` 在这个示例中,首先在 meta 设备上初始化 AutoModel,以避免占用任何内存资源。然后,`optimize_and_load_gguf` 遍历模型的所有子模块,匹配您的 YAML 规则文件中指定的规则,并将它们替换为指定的高级模块。 注入后,原始的 `generate` 接口仍然可用,但我们还提供了一个兼容的 `prefill_and_generate` 方法,这使得可以进一步优化,例如使用 CUDAGraph 提高生成速度。

如何自定义您的模型

一个详细的使用 DeepSeek-V2 作为示例的注入和 multi-GPU 教程在 [这里](doc/en/injection_tutorial.md)。 以下是一个将所有原始 Linear 模块替换为 Marlin 的 YAML 模板示例,Marlin 是一个高级的 4 位量化内核。 ```yaml - match: name: "^model\\.layers\\..*$" # 正则表达式 class: torch.nn.Linear # 仅匹配同时符合名称和类的模块 replace: class: ktransformers.operators.linear.KTransformerLinear # 量化数据类型的优化内核 device: "cpu" # 初始化时加载该模块的 device kwargs: generate_device: "cuda" generate_linear_type: "QuantizedLinearMarlin" ``` YAML 文件中的每个规则都有两部分:`match` 和 `replace`。`match` 部分指定应替换的模块,`replace` 部分指定要注入到模型中的模块以及初始化关键字。 您可以在 [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) 目录中找到用于优化 DeepSeek-V2 和 Qwen2-57B-A14 的示例规则模板。这些模板用于为 `local_chat.py` 示例提供支持。 如果您对我们的设计原则和注入框架的实现感兴趣,请参考 [设计文档](doc/en/deepseek-v2-injection.md)。

致谢和贡献者

KTransformers 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 、 Marlin、sglang和flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。 KTransformers 由清华大学 MADSys group 小组的成员以及 Approaching.AI 的成员积极维护和开发。我们欢迎新的贡献者加入我们,使 KTransformers 更快、更易于使用。

讨论

如果您有任何问题,欢迎随时提出 issue。或者,您可以加入我们的微信群进行进一步讨论。二维码: [微信群](WeChatGroup.png)

🙋 常见问题

一些常见问题的答案可以在 [FAQ](doc/en/FAQ.md) 中找到。 ================================================ FILE: archive/SECURITY.md ================================================ # Security Policy ## Supported Versions Use this section to tell people about which versions of your project are currently being supported with security updates. | Version | Supported | | ------- | ------------------ | | 5.1.x | :white_check_mark: | | 5.0.x | :x: | | 4.0.x | :white_check_mark: | | < 4.0 | :x: | ## Reporting a Vulnerability Use this section to tell people how to report a vulnerability. Tell them where to go, how often they can expect to get an update on a reported vulnerability, what to expect if the vulnerability is accepted or declined, etc. ================================================ FILE: archive/book.toml ================================================ [book] authors = ["kvcache-ai"] language = "zh-CN" title = "Ktransformers" src = "doc" [output.html] git-repository-url = "https://github.com/kvcache-ai/ktransformers" edit-url-template = "https://github.com/kvcache-ai/ktransformers/edit/main/{path}" [output.html.playground] editable = true copy-js = true # line-numbers = true [output.html.fold] enable = true level = 0 ================================================ FILE: archive/config.json ================================================ ================================================ FILE: archive/csrc/balance_serve/CMakeLists.txt ================================================ option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF) if(KTRANSFORMERS_USE_NPU) add_definitions(-DKTRANSFORMERS_USE_NPU=1) endif() if(KTRANSFORMERS_USE_NPU) set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}") message(STATUS "ASCEND_HOME_PATH is ${ASCEND_HOME_PATH}") include_directories(${ASCEND_HOME_PATH}/include) link_directories(${TORCH_INSTALL_PREFIX}/../torch.libs) # find torch_npu execute_process( COMMAND python -c "import torch; import torch_npu; print(torch_npu.__path__[0])" OUTPUT_VARIABLE TORCH_NPU_PATH OUTPUT_STRIP_TRAILING_WHITESPACE ) message(STATUS "Found PTA at: ${TORCH_NPU_PATH}") find_library(PTA_LIBRARY torch_npu PATH "${TORCH_NPU_PATH}/lib") endif() cmake_minimum_required(VERSION 3.21) find_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 g++ REQUIRED) set(CMAKE_CXX_COMPILER ${GCC_COMPILER}) # 显示选定的编译器 message(STATUS "Using compiler: ${CMAKE_CXX_COMPILER}") project(balance_serve VERSION 0.1.0) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "-Og -march=native -Wall -Wextra -g -fPIC") set(CMAKE_BUILD_TYPE "Debug") # set(CMAKE_CXX_FLAGS "-O3 -march=native -Wall -Wextra -fPIC") # set(CMAKE_BUILD_TYPE "Release") if(NOT DEFINED _GLIBCXX_USE_CXX11_ABI) find_package(Python3 REQUIRED COMPONENTS Interpreter) execute_process( COMMAND ${Python3_EXECUTABLE} -c "import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')" OUTPUT_VARIABLE ABI_FLAG OUTPUT_STRIP_TRAILING_WHITESPACE ) set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING "C++11 ABI setting from PyTorch" FORCE) endif() # 无论是否是自动检测,都传给编译器 add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}) message(STATUS "_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}") file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h") add_custom_target( format COMMAND clang-format -i -style=file ${FMT_SOURCES} COMMENT "Running clang-format on all source files" ) set(BUILD_SHARED_LIBS ON) set(ENABLE_PUSH OFF) set(ENABLE_COMPRESSION OFF) # set(CMAKE_BUILD_TYPE "Release") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) set(THIRD_PARTY_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party) add_subdirectory(${THIRD_PARTY_DIR}/prometheus-cpp ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp EXCLUDE_FROM_ALL) add_subdirectory(${THIRD_PARTY_DIR}/xxHash/cmake_unofficial ${THIRD_PARTY_BUILD_DIR}/xxHash EXCLUDE_FROM_ALL) set_target_properties(xxhash PROPERTIES POSITION_INDEPENDENT_CODE ON) # add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/prometheus-cpp) set(SPDLOG_DIR ${THIRD_PARTY_DIR}/spdlog) set(FMT_DIR ${THIRD_PARTY_DIR}/fmt) set(KVC2_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/kvc2/src) include_directories(${THIRD_PARTY_DIR}) add_subdirectory(${THIRD_PARTY_DIR}/pybind11 ${THIRD_PARTY_BUILD_DIR}/pybind11) execute_process( COMMAND python3 -c "import torch; print(torch.__path__[0])" OUTPUT_VARIABLE TORCH_INSTALL_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ) message(STATUS "Found PyTorch at: ${TORCH_INSTALL_PREFIX}") # set(TORCH_INSTALL_PREFIX "/home/xwy/.conda/envs/kvc/lib/python3.12/site-packages/torch") find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") find_package(Torch REQUIRED PATHS "${TORCH_INSTALL_PREFIX}/share/cmake/Torch" NO_DEFAULT_PATH) add_subdirectory(kvc2) add_subdirectory(sched) # add_subdirectory(test) ================================================ FILE: archive/csrc/custom_marlin/__init__.py ================================================ ================================================ FILE: archive/csrc/custom_marlin/binding.cpp ================================================ /** * @Description : * @Author : Azure-Tang * @Date : 2024-07-25 13:38:30 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-12 03:05:04 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "gptq_marlin/ops.h" // Python bindings #include #include #include #include #include // namespace py = pybind11; PYBIND11_MODULE(vLLMMarlin, m) { /*m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));*/ m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m_tensor"), py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("sms"), py::arg("is_k_full")); m.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); } ================================================ FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu ================================================ /* * Modified by Neural Magic * Copyright (C) Marlin.2024 Elias Frantar * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Adapted from https://github.com/IST-DASLab/marlin */ /* * Adapted from * https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin */ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" #include #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ std::is_same::value, \ "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } namespace gptq_marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __global__ void Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) {} } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({ 1, 1 }); } #else // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template __device__ inline void mma(const typename ScalarType::FragA& a_frag, const typename ScalarType::FragB& frag_b, typename ScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template __device__ inline void ldsm4(typename ScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); } // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. template __device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); return res; } // Constructs destination register by taking bytes from 2 sources (based on // mask) template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); return res; } // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 template __device__ inline typename ScalarType::FragB dequant_4bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } template <> __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); return frag_b; } template <> __device__ inline typename ScalarType::FragB dequant_4bit(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; // Guarantee that the `(a & b) | c` operations are LOP3s. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); typename ScalarType::FragB frag_b; static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC308C308; frag_b[0] = __hfma2(*reinterpret_cast(&lo), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); return frag_b; } // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or // bf16 Reference: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 template __device__ inline typename ScalarType::FragB dequant_8bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; uint32_t lo = prmt(q); uint32_t hi = prmt(q); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); fp32_intermediates[0] -= 8388736.f; fp32_intermediates[1] -= 8388736.f; fp32_intermediates[2] -= 8388736.f; fp32_intermediates[3] -= 8388736.f; uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template __device__ inline void scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = ScalarType::num2num2( reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) template __device__ inline void scale4(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s_1, typename ScalarType::FragS& frag_s_2, typename ScalarType::FragS& frag_s_3, typename ScalarType::FragS& frag_s_4, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } // Given 2 floats multiply by 2 scales (halves) template __device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { scalar_t* s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do // Guarantee that subsequent writes by this threadblock will be // visible globally. asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); while (state != count); } __syncthreads(); } // Release barrier and increment visitation count. __device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { lock[0] = 0; return; } int val = 1; // Make sure that all writes since acquiring this barrier are visible // globally, while releasing the barrier. asm volatile("fence.acq_rel.gpu;\n"); asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); } } // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; if (finish_row > size_m) { finish_row = size_m; } int cur_block_rows = finish_row - start_row; int row_stride = size_k * sizeof(half) / 16; auto permute_row = [&](int row) { int iters = size_k / default_threads; int rest = size_k % default_threads; int offset = row * row_stride; half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); half* out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; for (int i = 0; i < iters; i++) { int cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; base_k += default_threads; } if (rest) { if (threadIdx.x < rest) { int cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; } } }; for (int i = 0; i < cur_block_rows; i++) { int cur_row = start_row + i; if (cur_row < size_m) { permute_row(cur_row); } } } template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __device__ void Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM // example: // 0 1 3 // 0 2 3 // 1 2 4 // While this kind of partitioning makes things somewhat more complicated, it // ensures good utilization of all SMs for many kinds of shape and GPU // configurations, while requiring as few slow global cross-threadblock // reductions as possible. using Dtype = ScalarType; using scalar_t2 = typename ScalarType::scalar_t2; using FragA = typename ScalarType::FragA; using FragB = typename ScalarType::FragB; using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; constexpr int pack_factor = 32 / num_bits; // int prob_m = *prob_m_ptr; // const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks); // constexpr int thread_m_blocks = template_thread_m_blocks; // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions int parallel = 1; if (prob_m > 16 * thread_m_blocks) { parallel = prob_m / (16 * thread_m_blocks); prob_m = 16 * thread_m_blocks; } int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts // in the middle of group. iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); } } int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; int slice_iters; // number of threadblock tiles in the current slice int slice_count = 0; // total number of active threadblocks in the current slice int slice_idx; // index of threadblock in current slice; numbered bottom to // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers if (slice_col_par >= n_tiles) { A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; locks += (slice_col_par / n_tiles) * n_tiles; slice_col = slice_col_par % n_tiles; } // Compute all information about the current slice which is required for // synchronization. auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { A += 16 * thread_m_blocks * prob_k / 8; C += 16 * thread_m_blocks * prob_n / 8; locks += n_tiles; slice_col = 0; } }; init_slice(); // A sizes/strides // stride of the A matrix in global memory int a_gl_stride = prob_k / 8; // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory tile reads constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // number of shared write iterations for a tile constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides int b_gl_stride = 16 * prob_n / (pack_factor * 4); constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); constexpr int b_sh_wr_delta = threads * b_thread_vecs; constexpr int b_sh_rd_delta = threads * b_thread_vecs; constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; // constexpr int act_s_row_stride = 1; // int act_s_col_stride = act_s_row_stride * num_groups; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); a_gl_rd += a_gl_rd_delta_o * slice_row; // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; int b_sh_wr = threadIdx.x * b_thread_vecs; int b_sh_rd = threadIdx.x * b_thread_vecs; // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; int slice_n_offset = act_s_col_tb_stride * slice_col; // No act_order int s_gl_rd; if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; } // To ensure that writing and reading A tiles to/from shared memory, the // latter in fragment format, is fully bank conflict free, we need to use a // rather fancy XOR-based layout. The key here is that neither reads nor // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the // same shared memory banks. Further, it seems (based on NSight-Compute) that // each warp must also write a consecutive memory segment? auto transform_a = [&](int i) { int row = i / a_gl_rd_delta_o; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); } int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) { a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. const int4* B_ptr[b_sh_wr_iters]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_s = sh_g_idx + (stages * g_idx_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; FragS frag_s[2][4]; // No act-order FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) { reinterpret_cast(frag_c)[i] = 0; } }; int sh_first_group_id = -1; int sh_num_groups = -1; constexpr int sh_max_num_groups = 32; auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; if (sh_num_groups < sh_max_num_groups) { sh_num_groups = sh_max_num_groups; } if (sh_first_group_id + sh_num_groups > num_groups) { sh_num_groups = num_groups - sh_first_group_id; } int row_offset = first_group_id * s_gl_stride; if (is_async) { for (int i = 0; i < sh_num_groups; i++) { if (threadIdx.x < s_sh_stride) { cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); } } } else { for (int i = 0; i < sh_num_groups; i++) { if (threadIdx.x < s_sh_stride) { sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; } } } }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); } } } else { if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group if (pipe % (group_blocks / thread_k_blocks) == 0) { if (s_sh_wr_pred) { cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } s_gl_rd += s_gl_rd_delta; } } else { for (int i = 0; i < s_tb_groups; i++) { if (s_sh_wr_pred) { cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); } s_gl_rd += s_gl_rd_delta; } } } } } // Insert a fence even when we are winding down the pipeline to ensure that // waiting is also correct at this point. cp_async_fence(); }; // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering // and can only issue the next fetch when it is guaranteed that the previous // shared memory load is fully complete (as it may otherwise be // overwritten). cp_async_wait(); __syncthreads(); }; // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { if constexpr (!has_act_order) { is_same_group[pipe] = false; same_group_id[pipe] = 0; return; } int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; is_same_group[pipe] = group_id_1 == group_id_2; same_group_id[pipe] = group_id_1; }; auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; int cur_k = warp_row * 16; cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } return; } // Act-order case // Determine K of the "current" thread-block int cur_k = slice_k_start + tb_k * full_pipe; if (cur_k >= prob_k || cur_k >= slice_k_finish) { return; } // Reset (to current thread-block) since we read g_idx portion from the // shared memory cur_k = 0; // Progress to current iteration cur_k += k_iter_size * (k % b_sh_wr_iters); // Determine "position" inside the thread-block (based on warp and // thread-id) int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; cur_k += warp_row * 16; int th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; if (is_same_group[pipe]) { if (k % 2 == 0) { *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = { 0, 1, 8, 9 }; // Tensor core offsets per thread #pragma unroll for (int i = 0; i < 4; i++) { int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; if constexpr (num_bits == 4) { int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; frag_b0 = dequant_4bit(b_quant); frag_b1 = dequant_4bit(b_quant_shift); } else { int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; frag_b0 = dequant_8bit(b_quant_0); frag_b1 = dequant_8bit(b_quant_1); } // Apply scale to frag_b0 if constexpr (has_act_order) { scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); } } // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { scale(frag_b1, frag_s[k % 2][j], 1); } } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; // Since we slice across the k dimension of a tile in order to increase the // number of warps while keeping the n dimension of a tile reasonable, we have // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { int red_idx = threadIdx.x / b_sh_stride_threads; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); float* c_wr = reinterpret_cast(&sh[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll for (int i = 0; i < 4 * 2; i++) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); } } }; // Since multiple threadblocks may process parts of the same column slice, we // finally have to globally reduce over the results. As the striped // partitioning minimizes the number of such reductions and our outputs are // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). constexpr int active_threads = 32 * thread_n_blocks / 4; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; int row = (threadIdx.x % 32) / 4; if (!first) { // Interestingly, doing direct global accesses here really seems to mess up // the compiler and lead to slowdowns, hence we also use async-copies even // though these fetches are not actually asynchronous. #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred( &sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast(&c)[j] = Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; } } } } }; // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. auto write_result = [&]() { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; c_sh_wr += 32 * (threadIdx.x / 32); int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { res = __hmul2(res, s[0]); } ((scalar_t2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } } __syncthreads(); #pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { if (c_gl_wr < c_gl_wr_end) { C[c_gl_wr] = sh[c_sh_rd]; c_gl_wr += c_gl_wr_delta; c_sh_rd += c_sh_rd_delta; } } }; // Start global fetch and register load pipelines. auto start_pipes = [&]() { #pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } fetch_to_shared(i, i, i < slice_iters); } zero_accums(); wait_for_stage(); init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; if (slice_iters) { start_pipes(); } // Main loop. while (slice_iters) { // We unroll over both the global fetch and the register load pipeline to // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); pipe++; wait_for_stage(); init_same_group(pipe % stages); } matmul(k); } slice_iters--; if (slice_iters == 0) { break; } } a_gl_rd += a_gl_rd_delta_o * stages; slice_k_start += tb_k * stages; slice_k_start_shared_fetch += tb_k * stages; if constexpr (has_act_order) { int first_group_id = g_idx[slice_k_start]; int last_g_idx = slice_k_start + stages * tb_k * 2; if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } int last_group_id = g_idx[last_g_idx]; if (last_group_id >= sh_first_group_id + sh_num_groups) { fetch_scales_to_shared(false, first_group_id, last_group_id); __syncthreads(); } } // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { if constexpr (num_bits == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } else { if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { if constexpr (num_bits == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } } // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } } } if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; slice_col++; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } start_pipes(); } } } } template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __global__ void Marlin_wrapper(const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel const int* __restrict__ prob_m_ptr, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) { int prob_m = *prob_m_ptr; prob_m = min(prob_m, 1024); const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks); if(prob_m > 16 * thread_m_blocks) prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks)); /*if (blockIdx.x == 0 && threadIdx.x == 0) printf("marlin prob_m %d\n", prob_m);*/ if (thread_m_blocks == 1) { Marlin( A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks); } else if (thread_m_blocks == 2) { Marlin( A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks); } else if (thread_m_blocks == 3) { Marlin( A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks); } else if (thread_m_blocks == 4) { Marlin( A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks); } } #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ Marlin_wrapper, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin_wrapper<<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \ prob_k, locks); \ } typedef struct { int thread_k; int thread_n; int num_threads; } thread_config_t; typedef struct { int max_m_blocks; thread_config_t tb_cfg; } exec_config_t; thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {128, 128, 256}, {64, 128, 128}, {128, 64, 128}, }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {64, 256, 256}, // {128, 128, 256}, {64, 128, 128}, {128, 64, 128}, }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; int tb_k = th_config.thread_k; // Get max scale groups per thread-block int tb_groups; if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { int tb_scales = tb_groups * tb_n * 2; return tb_scales * pipe_stages; } } bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; int b_size = (tb_k * tb_n / pack_factor) * 4; // Get A size int m_blocks = div_ceil(prob_m, 16); int tb_max_m = 16; // zbx: too ugly // origin /*while (true) { if (m_blocks >= max_m_blocks) { tb_max_m *= max_m_blocks; break; } max_m_blocks--; if (max_m_blocks == 0) { TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); } }*/ // refactor tb_max_m *= std::min(m_blocks, max_m_blocks); int a_size = (tb_max_m * tb_k) * 2; float pipe_size = (a_size + b_size) * pipe_stages; TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { return false; } // Verify K/N are divisible by thread K/N if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { return false; } // Verify min for thread K/N if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { return false; } // num_threads must be at least 128 (= 4 warps) if (th_config.num_threads < 128) { return false; } // Determine cache for scales int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); // Check that pipeline fits into cache if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, scales_cache_size, max_shared_mem)) { return false; } return true; } exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem)) { return exec_config_t{ max_m_blocks, th_config }; } } } else { for (auto th_config : large_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem)) { return exec_config_t{ max_m_blocks, th_config }; } } } max_m_blocks--; // Process less M blocks per invocation to reduce cache // usage } return exec_config_t{ 0, {-1, -1, -1} }; } #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) template void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr, int prob_m, int prob_n, int prob_k, void* workspace, int num_bits, bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); int tot_m = prob_m; int tot_m_blocks = div_ceil(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; if (sms == -1) { cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } int max_shared_mem = 0; cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); // Set thread config exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config exec_cfg = exec_config_t{ 4, thread_config_t{thread_k, thread_n, default_threads} }; } else { // Auto config exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem); } TORCH_CHECK( exec_cfg.max_m_blocks > 0 && is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem), "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, ", thread_k = ", exec_cfg.tb_cfg.thread_k, ", thread_n = ", exec_cfg.tb_cfg.thread_n, ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, ", group_size = ", group_size, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem); int num_threads = exec_cfg.tb_cfg.num_threads; thread_k = exec_cfg.tb_cfg.thread_k; thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; int blocks = sms; TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); int group_blocks = 0; if (has_act_order) { if (is_k_full) { TORCH_CHECK(group_size != -1); group_blocks = group_size / 16; TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); } else { TORCH_CHECK(group_size == 0); group_blocks = 0; } } else { if (group_size == -1) { group_blocks = -1; } else { group_blocks = group_size / 16; TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); } } const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; int* locks = (int*)workspace; if (has_act_order) { // Permute A columns int block_rows = div_ceil(prob_m, blocks); permute_cols_kernel << > > ( A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); A_ptr = a_tmp_ptr; } // If we have a full K, then we can run the non-act-order version of Marlin // (since the weight rows are reordered by increasing group ids, and by // having a full K, we have full original groups) if (is_k_full) { has_act_order = false; } // Main loop for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; if (thread_m_blocks > exec_cfg.max_m_blocks) { // Note that parallel > 1 currently only works for inputs without // any padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); if (par > max_par) par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; } // Define kernel configurations #define undefined_error \ TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \ str(prob_n) + ", " + str(prob_k) + "]" + \ ", has_act_order = " + str(has_act_order) + \ ", num_groups = " + str(num_groups) + \ ", group_size = " + str(group_size) + \ ", thread_m_blocks = " + str(thread_m_blocks) + \ ", thread_n_blocks = " + str(thread_n_blocks) + \ ", thread_k_blocks = " + str(thread_k_blocks)); /* std::cout << "MNK = [" + str(prob_m) + ", " + \ str(prob_n) + ", " + str(prob_k) + "]" + \ ", has_act_order = " + str(has_act_order) + \ ", num_groups = " + str(num_groups) + \ ", group_size = " + str(group_size) + \ ", thread_m_blocks = " + str(thread_m_blocks) + \ ", thread_n_blocks = " + str(thread_n_blocks) + \ ", thread_k_blocks = " + str(thread_k_blocks) << std::endl;*/ /*if (false) { } // CALL_IF(4, 32, 2, 256) // CALL_IF(4, 16, 4, 256) __CALL_IF(4, 1, 16, 4, false, 4, 256) __CALL_IF(4, 2, 16, 4, false, 4, 256) // CALL_IF(4, 8, 8, 256) __CALL_IF(4, 1, 8, 8, false, 4, 256) __CALL_IF(4, 2, 8, 8, false, 4, 256) // CALL_IF(4, 16, 4, 128) __CALL_IF(4, 1, 16, 4, false, 4, 128) __CALL_IF(4, 2, 16, 4, false, 4, 128) // CALL_IF(4, 8, 8, 128) __CALL_IF(4, 1, 8, 8, false, 4, 128) __CALL_IF(4, 2, 8, 8, false, 4, 128) else {undefined_error}*/ if (num_bits == 4 && num_threads == 256) { if (false) { } CALL_IF(4, 32, 2, 256) CALL_IF(4, 16, 4, 256) CALL_IF(4, 8, 8, 256) else { undefined_error } } else if (num_bits == 4 && num_threads == 128) { if (false) { } CALL_IF(4, 8, 4, 128) CALL_IF(4, 16, 4, 128) CALL_IF(4, 4, 8, 128) else { undefined_error } } // else if (num_bits == 8 && num_threads == 256) // { // if (false) { // } // CALL_IF(8, 32, 2, 256) // CALL_IF(8, 16, 4, 256) // CALL_IF(8, 8, 8, 256) // else { // undefined_error // } // } // else if (num_bits == 8 && num_threads == 128) // { // if (false) { // } // CALL_IF(8, 8, 4, 128) // CALL_IF(8, 16, 4, 128) // CALL_IF(8, 4, 8, 128) // else { // undefined_error // } // } else { undefined_error } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; } } } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n, int64_t size_k, int sms, bool is_k_full) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); int pack_factor = 32 / num_bits; // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m); TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k); // Verify B TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, " is not divisible by tile_size = ", gptq_marlin::tile_size); TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not divisible by tile_size = ", gptq_marlin::tile_size); int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); // Alloc buffers auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c = torch::empty({ size_m, size_n }, options); torch::Tensor a_tmp = torch::empty({ size_m, size_k }, options); // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; // thread_n: `n` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_n = -1; // sms: number of SMs to use for the kernel (can usually be left as auto -1) // int sms = -1; //zbx // Verify g_idx and perm TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || (g_idx.size(0) == size_k && perm.size(0) == size_k), "Unexpected g_idx.size(0) = ", g_idx.size(0), " and perm.size(0) = ", perm.size(0), ", where size_k = ", size_k); // Detect groupsize and act_order int num_groups = -1; int group_size = -1; bool has_act_order = g_idx.size(0) != 0; int b_rank = b_scales.sizes().size(); TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); num_groups = b_scales.size(0); if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); group_size = size_k / num_groups; } else { group_size = 0; } } else { if (num_groups > 1) { TORCH_CHECK( size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); group_size = size_k / num_groups; } else { group_size = -1; } } // Verify workspace size TORCH_CHECK( size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m_tensor.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, gptq_marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m_tensor.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, gptq_marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } return c; } #endif ================================================ FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh ================================================ // Adapted from // https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin // Copyrigth 2024 The vLLM team. // Copyright (c) 2024 by KVCache.AI, All Rights Reserved. #pragma once #include #include #include #include #include #include #include namespace gptq_marlin { // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; template struct Vec { T elems[n]; __device__ T &operator[](int i) { return elems[i]; } }; using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // No support for async #else __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } template __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } #endif } // namespace gptq_marlin ================================================ FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh ================================================ // Adapted from // https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin // Copyrigth 2024 The vLLM team. // Copyright (c) 2024 by KVCache.AI, All Rights Reserved. #ifndef _data_types_cuh #define _data_types_cuh #include "gptq_marlin.cuh" #include #include namespace gptq_marlin { template class ScalarType {}; template <> class ScalarType { public: using scalar_t = half; using scalar_t2 = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); } static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } }; template <> class ScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } #endif }; } // namespace gptq_marlin #endif ================================================ FILE: archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu ================================================ #include "gptq_marlin.cuh" namespace gptq_marlin { static constexpr int repack_stages = 8; static constexpr int repack_threads = 256; static constexpr int tile_k_size = tile_size; static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template __global__ void marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {} } // namespace gptq_marlin torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); } #else template __global__ void marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); int start_k_tile = blockIdx.x * block_k_tiles; if (start_k_tile >= k_tiles) { return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering // and can only issue the next fetch when it is guaranteed that the previous // shared memory load is fully complete (as it may otherwise be // overwritten). cp_async_wait(); __syncthreads(); }; extern __shared__ int4 sh[]; constexpr int perm_size = tile_k_size / 4; int4* sh_perm_ptr = sh; int4* sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { sh_pipe_ptr += perm_size; } constexpr int tile_ints = tile_k_size / pack_factor; constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; } __syncthreads(); }; auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { if (n_tile_id >= n_tiles) { cp_async_fence(); return; } int first_n = n_tile_id * tile_n_size; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } } else { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } } cp_async_fence(); }; auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { if (n_tile_id >= n_tiles) { return; } int warp_id = threadIdx.x / 32; int th_id = threadIdx.x % 32; if (warp_id >= 4) { return; } int tc_col = th_id / 4; int tc_row = (th_id % 4) * 2; constexpr int tc_offsets[4] = {0, 1, 8, 9}; int cur_n = warp_id * 16 + tc_col; constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; if constexpr (has_perm) { for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; uint32_t src_k = sh_perm_int_ptr[k_idx]; uint32_t src_k_pos = src_k % pack_factor; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; vals[i] = b1_cur_val; vals[4 + i] = b2_cur_val; } } else { uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; #pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } #pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; int cur_pos = cur_elem % pack_factor; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h if constexpr (num_bits == 4) { constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; #pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } out_ptr[out_offset + th_id * 4 + warp_id] = res; } else { constexpr int pack_idx[4] = {0, 2, 1, 3}; uint32_t res1 = 0; uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; } }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; #pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; if constexpr (has_perm) { load_perm_to_shared(k_tile_id); } start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { #pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); repack_tile(pipe, k_tile_id, n_tile_id + pipe); wait_for_stage(); } n_tile_id += repack_stages; } } } } // namespace gptq_marlin #define CALL_IF(NUM_BITS, HAS_PERM) \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ cudaFuncSetAttribute( \ gptq_marlin::marlin_repack_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ gptq_marlin::marlin_repack_kernel \ <<>>( \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); int const pack_factor = 32 / num_bits; // Verify B TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), ", size_k = ", size_k, ", pack_factor = ", pack_factor); TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); // Verify device and strides TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); torch::Tensor out = torch::empty({size_k / gptq_marlin::tile_size, size_n * gptq_marlin::tile_size / pack_factor}, options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; // Get ptrs uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); // Get dev info int dev = b_q_weight.get_device(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); int max_shared_mem = 0; cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); if (false) { } CALL_IF(4, false) CALL_IF(4, true) CALL_IF(8, false) CALL_IF(8, true) else { TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); } return out; } #endif ================================================ FILE: archive/csrc/custom_marlin/gptq_marlin/ops.h ================================================ /** * @Description : * @Author : Azure * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 * @LastEditors : Azure * @LastEditTime : 2024-07-26 08:35:00 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once #include #include #include torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n, int64_t size_k, int sms, bool is_k_full); torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor&perm, int64_t size_k, int64_t size_n, int64_t num_bits); ================================================ FILE: archive/csrc/custom_marlin/setup.py ================================================ from setuptools import setup, Extension from torch.utils import cpp_extension from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( name='vLLMMarlin', ext_modules=[ CUDAExtension( 'vLLMMarlin', [ #'custom_gguf/dequant.cu', 'binding.cpp', 'gptq_marlin/gptq_marlin.cu', 'gptq_marlin/gptq_marlin_repack.cu', ], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', '--use_fast_math', '-Xcompiler', '-fPIC', ] }, ) ], cmdclass={'build_ext': BuildExtension} ) ================================================ FILE: archive/csrc/custom_marlin/test_cuda_graph.py ================================================ import csv import torch import torch.nn as nn import vLLMMarlin torch.set_grad_enabled(False) from utils.marlin_utils import ( MarlinWorkspace, marlin_quantize, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MIN_THREAD_K, GPTQ_MARLIN_MAX_PARALLEL, ) def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) setup_seed(20241223) torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) global_dtype=torch.bfloat16 global_device=torch.device("cuda",0) global_num_cases:int=int(50) torch.cuda.set_device(0) torch.backends.cudnn.enabled =True torch.backends.cudnn.benchmark = True max_batch_size = 512 max_tp = 8 L2_size = 73728 * 1024 def get_usable_mem(): properties = torch.cuda.get_device_properties(global_device) #print(f"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB") allocated_memory = torch.cuda.memory_allocated(global_device) #print(f"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB") reserved_memory = torch.cuda.memory_reserved(global_device) #print(f"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB") return properties.total_memory - 512 * 1024 ** 2 - allocated_memory# - reserved_memory def exp_range(start, stop, step = 2): now = start while now <= stop: yield now now *= step def timing(func, iters, epochs=100): #warmup for idx in range(iters): func(idx) torch.cuda.synchronize() cuda_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cuda_graph): for idx in range(iters): func(idx) for _ in range(2000): cuda_graph.replay() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) stream = torch.cuda.Stream() torch.cuda.synchronize() #with torch.cuda.stream(stream): start_event.record() for _ in range(10): cuda_graph.replay() end_event.record() torch.cuda.synchronize() elapsed_time_ms0 = start_event.elapsed_time(end_event) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) torch.cuda.synchronize() #with torch.cuda.stream(stream): start_event.record() for _ in range(epochs+10): cuda_graph.replay() end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0 #print(elapsed_time_ms0, elapsed_time_ms) return elapsed_time_ms/iters/epochs class LinearMarlin(nn.Linear): marlin_q_w: torch.Tensor marlin_s: torch.Tensor g_idx: torch.Tensor sort_indices: torch.Tensor has_bias: bool def __init__( self, in_features, out_features, bias = False, device: str = "cuda", num_bits: int = 4, # 4-bit/8-bit is supported group_size: int = 64, # -1, 32, 64, 128 act_order: bool = False, is_k_full=True, sms = -1, # sms in GPU **kwargs, ): self.padding = False assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" if in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") self.padding = True self.orin_in_features = in_features self.orin_out_features = out_features in_features = (in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K out_features = (out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N #print(f"After padding: in_features={in_features}, out_features={out_features}") super().__init__(in_features, out_features, bias, device) self.has_bias = bias self.device = device self.num_bits = num_bits self.group_size = group_size self.act_order = act_order # TODO: optimize every shape GEMM blocks_k, blocks_n = in_features//128, out_features//128 self.sms = sms self.is_k_full = is_k_full self.weight.requires_grad = False self.weight.t_() # Pack Marlin linear #w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( # self.weight, self.num_bits, self.group_size, self.act_order #) marlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features//16, out_features*2), device=device, dtype=torch.int) marlin_s = torch.randn((in_features//64, out_features), device=device) self.workspace = MarlinWorkspace( self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device ) self.marlin_q_w = marlin_q_w self.marlin_s = marlin_s self.g_idx = torch.empty((0), dtype=torch.int32, device=self.device) self.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device) self.k = self.weight.shape[0] self.n = self.weight.shape[1] self.weight = None """ print(in_features, out_features) print(marlin_q_w.shape) print(marlin_q_w.dtype) print(marlin_s.shape) print(marlin_s.dtype) print(self.workspace.scratch.shape) print(self.workspace.scratch.dtype) print(self.g_idx.shape) print(self.g_idx.dtype) print(self.sort_indices.shape) print(self.sort_indices.dtype) #print(w_ref.shape) #print(w_ref.dtype) """ #w_ref = None def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor: # Only support input x as BF16 and FP16 x = x.to(self.device) orig_shape = list(x.shape) orig_dtype = x.dtype x = x.reshape(-1, x.shape[-1]) if self.padding: padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype) padding_input[:,:self.orin_in_features] = x x = padding_input marlin_s = self.marlin_s.to(x.dtype) #print(self.sms * ((orig_shape[0]+63)//64)) sms = self.sms x = vLLMMarlin.gptq_marlin_gemm( x, self.marlin_q_w, marlin_s, self.g_idx, self.sort_indices, self.workspace.scratch, self.num_bits, bsz_tensor, x.shape[0], self.n, x.shape[-1], sms, self.is_k_full, ) # TODO: don't padding bias if self.has_bias: x = x + self.bias if self.padding: x = x[:,:self.orin_out_features] orig_shape[-1] = self.orin_out_features else: orig_shape[-1] = self.out_features return x.reshape(orig_shape).to(orig_dtype) def benchLinearMarlin(input_dim, output_dim):#, out_file print("benchmarking MLP Marlin") print("-----------------------------------------------------------") headers = ["batch_size", "tp", "used_time", "bandwidth GB/s", "TFLOPS", "cases", "padding", "sms"] print(" | ".join(headers) + "\n") rows = [] for batch_size in exp_range(1, 64): for tp in exp_range(1, max_tp): torch.cuda.empty_cache() if output_dim % tp != 0: continue cur_output_dim = output_dim // tp modules = [] inputs = [] data_size = int(0.53125*input_dim*cur_output_dim) input_size = int(2*batch_size*input_dim) output_size = int(2*batch_size*cur_output_dim) usable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim min_cases = max(global_num_cases, (2*L2_size) // (data_size+input_size)) cases = int(min(min_cases, (usable_mem * 0.8) // (data_size+input_size))) #print(usable_mem, data_size, input_size, cases) bsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32) if cases == 0: row = [f"{batch_size}", "OOM", "OOM", "OOM", "0", "False"] rows.append(row) break for _ in range(cases): modules.append(LinearMarlin(input_dim, cur_output_dim, sms=56, non_equal_division=False).to(device=global_device).eval()) inputs.append(torch.randn(batch_size, 1, input_dim, device=global_device)) def forward(case_id): modules[case_id](inputs[case_id], bsz_tensor) used_time = timing(forward, iters=cases) bandwidth = (data_size+input_size+output_size)/used_time/1e6 flops = 2*batch_size*input_dim*cur_output_dim tflops = flops/used_time/1e9 cur_sms = modules[0].sms row = [f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms] rows.append(row) print(f"{batch_size}", f"{tp}", f"{used_time}", f"{bandwidth}", f"{tflops}", f"{cases}", modules[0].padding, cur_sms) """ with open(out_file, 'w', newline='') as csvfile: csvwriter = csv.writer(csvfile) csvwriter.writerow(headers) for row in rows: csvwriter.writerow(row) """ """ markdown_table = " | ".join(headers) + "\n" markdown_table += " | ".join(["---"] * len(headers)) + "\n" for row in rows: markdown_table += " | ".join(row) + "\n" print(markdown_table) """ #print("finish write file", out_file) #print("-------------------------------------------------------------") if __name__ == "__main__": benchLinearMarlin(5120, 3584) exit(0) max_batch = 1 cur_batch = 1 marlin_linear = LinearMarlin(5120, 3584) input_tensor = torch.randn(max_batch, 1, 5120, device="cuda", dtype=torch.bfloat16) bsz_tensor = torch.tensor([max_batch], device="cuda", dtype=torch.int32) out_truth = marlin_linear(input_tensor, bsz_tensor) print(out_truth) g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): out_buf = marlin_linear(input_tensor, bsz_tensor) for i in range(10000): g.replay() #torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3) marlin_linear = LinearMarlin(5120, 3584) g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): out_buf = marlin_linear(input_tensor, bsz_tensor) new_input = torch.randn(cur_batch, 1, 5120, device="cuda", dtype=torch.bfloat16) bsz_tensor.copy_(torch.tensor([cur_batch], device="cuda", dtype=torch.int32)) new_out_truth = marlin_linear(new_input, bsz_tensor) input_tensor[:cur_batch].copy_(new_input) input_tensor[cur_batch:] = 0 g.replay() torch.cuda.synchronize() def printMinMax(tensor): abs_tensor = torch.abs(tensor) min_val = torch.min(abs_tensor) max_val = torch.max(abs_tensor) min_indices = (abs_tensor == min_val).nonzero(as_tuple=True) max_indices = (abs_tensor == max_val).nonzero(as_tuple=True) print(f"min: {min_val.item()}") print(f"min idx: {min_indices}") print(f"max: {max_val.item()}") print(f"max idx: {max_indices}") print(out_buf[:cur_batch].shape) print(new_out_truth.shape) printMinMax(out_buf[:cur_batch]) printMinMax(new_out_truth) #torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3) ================================================ FILE: archive/csrc/custom_marlin/utils/__init__.py ================================================ ================================================ FILE: archive/csrc/custom_marlin/utils/format24.py ================================================ # # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). # import torch # This is PyTorch implementation of main part of reorder_meta() # function, from tools/util/include/cutlass/util/host_reorder.h file # of CUTLASS source tree. Furthermore, CUTLASS template for sparse # GEMM decides upon layout of this matrix, and at the moment for the # sparse GEMM executed on tensor cores, this is layout described by # ColumnMajorInterleaved<2> data structure, in # include/cutlass/layout/matrix.h of CUTLASS source tree. The # reordering of meta matrix into meta_reordered matrix calculated # according to these segments of CUTLASS code is re-implemented here. # Note that this calculation produces offsets for scattering metadata # matrix elements into reordered metadata matrix elements (or, # equivalently, for gathering reordered metadata matrix element back # into metadata matrix elements). def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) # Reorder the rows, then swizzle the 2x2 blocks. group_x = 64 group_y = 32 if meta_dtype.itemsize == 2 else 16 dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + ((dst_rows % group_x) // 8) * 4) topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) dst_rows += topright - bottomleft dst_cols -= topright - bottomleft # Assumed that meta tensor is to be stored in CUTLASS # InterleavedColumnMajor layout, and reverse engineered # corresponding code to store values into this tensor. interleave = 2 cols_maj = dst_cols // interleave cols_min = dst_cols % interleave return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) # This function converts dense matrix into sparse semi-structured # representation, producing "compressed" matrix, in the layout used by # CUTLASS backend, and corresponding metadata matrix. def sparse_semi_structured_from_dense_cutlass(dense): if dense.dim() != 2: raise RuntimeError( f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 ) m, k = dense.shape device = dense.device meta_dtype = torch.int8 if dense.dtype == torch.int8: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 else: raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 if quadbits_per_meta_elem not in (4, 8): raise RuntimeError( "Invalid number of elements per meta element calculated") if meta_dtype == torch.int32: if m % 16 != 0: raise RuntimeError( f"Number of rows of dense matrix {m} must be divisible by 16") else: if m % 32 != 0: raise RuntimeError( f"Number of rows of dense matrix {m} must be divisible by 32") if k % (4 * quadbits_per_meta_elem) != 0: raise RuntimeError( f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 ) if dense.dtype != torch.float: ksparse = 4 dense_4 = dense.view(-1, k // ksparse, ksparse) m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) else: ksparse = 2 dense_2 = dense.view(-1, k // ksparse, ksparse) m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) meta_ncols = k // (ksparse * quadbits_per_meta_elem) # Encoding quadruples of True/False values as follows: # [True, True, False, False] -> 0b0100 # [True, False, True, False] -> 0b1000 # [False, True, True, False] -> 0b1001 # [True, False, False, True ] -> 0b1100 # [False, True, False, True ] -> 0b1101 # [False, False, True, True ] -> 0b1110 # Thus, lower two bits in the encoding are index of the True value # at the lowest index in the quadruple, and the higher two bits in # the encoding are index of the other True value in the quadruple. # In case there are less than two True values, than False value or # values at some index or indices are considered True for the # encoding. In case there are more than two True values, then the # excess True value(s) at some indices are considered False for # the encoding. The exact encodings used for these cases are as # follows: # [False, False, False, False] -> 0b1110 # [False, False, False, True ] -> 0b1110 # [False, False, True, False] -> 0b1110 # [False, True, False, False] -> 0b1001 # [False, True, True, True ] -> 0b1101 # [True, False, False, False] -> 0b1000 # [True, False, True, True ] -> 0b1100 # [True, True, False, True ] -> 0b0100 # [True, True, True, False] -> 0b0100 # [True, True, True, True ] -> 0b0100 # These particular encodings are chosen, with the help of Espresso # logic minimizer software, for the purpose of minimization of # corresponding Boolean functions, that translate non-zero flags # into encoding bits. Note also possible choices for the first # and last of these encodings were limited only to (0b0100, # 0b1110), in order to produce valid encodings for 1:2 sparsity # case. expr0 = m0 & m1 expr1 = ~m0 & m1 expr2 = ~m0 & ~m1 bit0 = expr1 bit1 = expr2 bit2 = expr0 | expr2 | m3 bit3 = expr1 | ~m1 idxs0 = bit0 | (bit1.to(torch.int64) << 1) idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) meta_n = meta_4.view( (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)) elif quadbits_per_meta_elem == 8: meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) | (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) | (meta_n[:, :, 6] << 24) | (meta_n[:, :, 7] << 28)) # Reorder meta tensor elements. meta_reordered = meta.new_empty( (m * meta_ncols, )) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) # This function performs reverse of the function above - it # reconstructs dense matrix from a pair of "compressed" matrix, given # in the layout used by CUTLASS backend, and accompanying metadata # matrix. def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if sparse.dim() != 2: raise RuntimeError( f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 ) m, k = sparse.shape device = sparse.device if meta_reordered.dim() != 2: raise RuntimeError( f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 ) if meta_reordered.device != device: raise RuntimeError( f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 ) meta_dtype = meta_reordered.dtype if meta_dtype not in (torch.int16, torch.int32): raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 ksparse = 4 if sparse.dtype != torch.float else 2 meta_nrows, meta_ncols = meta_reordered.shape if meta_nrows != m: raise RuntimeError( f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 ) if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: raise RuntimeError( f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 "expected according to the number of columns of meta matrix") # Undo meta tensor elements reordering. meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device) meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) # Unpack sparse tensor back to original dense tensor, using # information provided by meta tensor. Note that torch.float # datatype is handled pretty much the same as # torch.half/torch.bfloat16, as metadata for a pair of torch.float # value is encoded as if underlying 8 bytes contain four # torch.half/torch.bfloat16 values, where either first two or last # two are zeros. meta_2 = torch.empty( (m, meta_ncols, 2 * quadbits_per_meta_elem), dtype=meta_dtype, device=device, ) if quadbits_per_meta_elem == 4: meta_2[:, :, 0] = meta & 0b11 meta_2[:, :, 1] = (meta >> 2) & 0b11 meta_2[:, :, 2] = (meta >> 4) & 0b11 meta_2[:, :, 3] = (meta >> 6) & 0b11 meta_2[:, :, 4] = (meta >> 8) & 0b11 meta_2[:, :, 5] = (meta >> 10) & 0b11 meta_2[:, :, 6] = (meta >> 12) & 0b11 meta_2[:, :, 7] = (meta >> 14) & 0b11 elif quadbits_per_meta_elem == 8: meta_2[:, :, 0] = meta & 0b11 meta_2[:, :, 1] = (meta >> 2) & 0b11 meta_2[:, :, 2] = (meta >> 4) & 0b11 meta_2[:, :, 3] = (meta >> 6) & 0b11 meta_2[:, :, 4] = (meta >> 8) & 0b11 meta_2[:, :, 5] = (meta >> 10) & 0b11 meta_2[:, :, 6] = (meta >> 12) & 0b11 meta_2[:, :, 7] = (meta >> 14) & 0b11 meta_2[:, :, 8] = (meta >> 16) & 0b11 meta_2[:, :, 9] = (meta >> 18) & 0b11 meta_2[:, :, 10] = (meta >> 20) & 0b11 meta_2[:, :, 11] = (meta >> 22) & 0b11 meta_2[:, :, 12] = (meta >> 24) & 0b11 meta_2[:, :, 13] = (meta >> 26) & 0b11 meta_2[:, :, 14] = (meta >> 28) & 0b11 meta_2[:, :, 15] = (meta >> 30) & 0b11 dense_offsets = meta_2.view(-1) + ( torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( -1, 1).repeat(1, 2).view(-1) dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: dense.view(torch.half).scatter_(0, dense_offsets, sparse.view(torch.half).view(-1)) return dense.view(m, 2 * k) def mask_creator(tensor): """ Class for creating N:M sparsity masks. Masks will be created using the N:M ratio, where for every block of M weights, N will be pruned based on ranked weight value. Each mask will correspond to the given tensor. :param N: The number of weights in a group to keep :param M: The size of a weight group """ N = 2 M = 4 mask = None # for i, tensor in enumerate(tensors): if tensor.numel() % M != 0: raise ValueError( f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups") num_groups = tensor.numel() // M # N:M sparsity for linear layers tensor_temp = tensor.detach().abs().reshape(num_groups, M) index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) return mask ================================================ FILE: archive/csrc/custom_marlin/utils/marlin_24_perms.py ================================================ ''' Date: 2024-11-08 02:46:07 LastEditors: djw LastEditTime: 2024-11-08 02:46:41 ''' """This file is used for /tests and /benchmarks""" from typing import Dict, List import numpy import torch # Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 # # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 # with the tensor-core format that is described here: # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 # # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501 def get_perms_24(num_bits: int): perm_list: List[int] = [] for i in range(32): perm1: List[int] = [] col = i // 4 col_o = col // 2 for block in [0, 1]: for row in [ 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) perm = numpy.array(perm_list) if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: interleave = numpy.array([0, 2, 1, 3]) else: raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = torch.from_numpy(perm) scale_perm: List[int] = [] for i in range(8): scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) scale_perm_single: List[int] = [] for i in range(8): scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) return perm, scale_perm, scale_perm_single marlin_24_perm: Dict[int, torch.Tensor] = {} marlin_24_scale_perm: Dict[int, List[int]] = {} marlin_24_scale_perm_single: Dict[int, List[int]] = {} for num_bits in [4, 8]: perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) marlin_24_perm[num_bits] = perm_24 marlin_24_scale_perm[num_bits] = scale_perm_24 marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 ================================================ FILE: archive/csrc/custom_marlin/utils/marlin_perms.py ================================================ ''' Date: 2024-11-08 02:46:47 LastEditors: djw LastEditTime: 2024-11-08 02:46:55 ''' """This file is used for /tests and /benchmarks""" from typing import Dict, List import numpy import torch # Precompute permutations for Marlin weight and scale shuffling # noqa: E501 # # Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 # with the tensor-core format that is described here: # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 # # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501 def get_perms(num_bits: int): perm_list: List[int] = [] for i in range(32): perm1: List[int] = [] col = i // 4 for block in [0, 1]: for row in [ 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col + 8 * block) for j in range(4): perm_list.extend([p + 256 * j for p in perm1]) perm = numpy.array(perm_list) if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: interleave = numpy.array([0, 2, 1, 3]) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = torch.from_numpy(perm) scale_perm: List[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single: List[int] = [] for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return perm, scale_perm, scale_perm_single marlin_perm: Dict[int, torch.Tensor] = {} marlin_scale_perm: Dict[int, List[int]] = {} marlin_scale_perm_single: Dict[int, List[int]] = {} for num_bits in [4, 8]: perm, scale_perm, scale_perm_single = get_perms(num_bits) marlin_perm[num_bits] = perm marlin_scale_perm[num_bits] = scale_perm marlin_scale_perm_single[num_bits] = scale_perm_single ================================================ FILE: archive/csrc/custom_marlin/utils/marlin_utils.py ================================================ """This file is used for /tests and /benchmarks""" import random import numpy import torch from .format24 import ( mask_creator, sparse_semi_structured_from_dense_cutlass) from .marlin_24_perms import ( marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) from .marlin_perms import ( marlin_perm, marlin_scale_perm, marlin_scale_perm_single) from .quant_utils import ( get_pack_factor, quantize_weights, sort_weights, dequantize_weights) __cuda_arch = torch.cuda.get_device_capability() MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] def is_marlin_supported(): return __cuda_arch[0] >= 8 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" # Permute weights to 16x64 marlin tiles q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) return q_w def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) # Pack pack_factor = get_pack_factor(num_bits) orig_device = q_w.device q_w = q_w.cpu().numpy().astype(numpy.uint32) q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) return q_packed def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, scale_perm_single): if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s def marlin_quantize( w: torch.Tensor, num_bits: int, group_size: int, act_order: bool, ): size_k, size_n = w.shape # Normalize group_size if group_size == -1: group_size = size_k assert group_size <= size_k # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing sort_indices = torch.empty(0, dtype=torch.int, device=w.device) if act_order: q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, marlin_perm[num_bits]) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, marlin_scale_perm[num_bits], marlin_scale_perm_single[num_bits]) # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) return res_list def inject_24(w, size_k, size_n): assert w.shape == (size_k, size_n) mask = mask_creator(w.t()).t().cuda().bool() return (mask * w).contiguous(), mask.contiguous() def check_24(w, num_rows_to_sample=50, _verbose=False): BLOCK_SIZE = 4 MAX_NON_ZEROS = 2 w = w.t().contiguous() print("check_24: w.shape = {}".format(w.shape)) num_rows, num_cols = w.shape sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) if _verbose: print(f"Sampled row idxs = {sampled_row_idxs}") total_segments = 0 non_24_segments = 0 for i in sampled_row_idxs: for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): total_segments += 1 block = w[i, j:j + BLOCK_SIZE] num_nonzero = torch.count_nonzero(block) if num_nonzero > MAX_NON_ZEROS: print("i = {} j = {} block = {}".format(i, j, block)) non_24_segments += 1 print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): assert q_24.shape == (size_k, size_n) # Remove zp to normalize over 0 max_q_val = (1 << num_bits) - 1 zp = (max_q_val + 1) // 2 q_24_no_zp = q_24 - zp # Compress q_24_no_zp = q_24_no_zp.t().contiguous() q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() # Restore zp q_24_comp = q_24_no_zp_comp + zp # Resize meta to its actual shape (without moving any data) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) return q_24_comp, meta def marlin_24_quantize( w: torch.Tensor, num_bits: int, group_size: int, ): size_k, size_n = w.shape # Normalize group_size if group_size == -1: group_size = size_k assert group_size <= size_k # Inject 2:4 sparsity w_24, mask_24 = inject_24(w, size_k, size_n) # Quantize w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, num_bits, group_size, act_order=False) # Compress quantized weight q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, num_bits) size_k_comp = size_k // 2 # Reformat to marlin marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, num_bits, marlin_24_perm[num_bits]) marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, marlin_24_scale_perm[num_bits], marlin_24_scale_perm_single[num_bits]) # Create result res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) return res_list def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref)) class MarlinWorkspace: def __init__(self, out_features, min_thread_n, max_parallel, device): assert (out_features % min_thread_n == 0), ( "out_features = {} is undivisible by min_thread_n = {}".format( out_features, min_thread_n)) max_workspace_size = ((out_features // min_thread_n) * max_parallel) self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device=device) ================================================ FILE: archive/csrc/custom_marlin/utils/quant_utils.py ================================================ """This file is used for /tests and /benchmarks""" import numpy import torch SUPPORTED_NUM_BITS = [4, 8] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] def get_pack_factor(num_bits): assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" return 32 // num_bits def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): assert q_w.shape == w_ref.shape orig_device = q_w.device k_size, _ = q_w.shape g_idx = torch.zeros((k_size, ), dtype=torch.int32) for i in range(k_size): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K rand_perm = torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() w_ref = w_ref[rand_perm, :].contiguous() return ( w_ref.to(device=orig_device), q_w.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), ) # Function: Dequantize quantized weights def dequantize_weights(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'): # Create a tensor for bitwise right shift operation wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=device).unsqueeze(0) # Apply bitwise right shift and convert qzeros to the appropriate type zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8) torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros) # Reshape the zeros tensor zeros = zeros + 1 zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) # Reshape the scales tensor scales = scales.reshape(-1, 1, scales.shape[-1]) # Similar bitwise right shift operation for qweight and reshape weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8) torch.bitwise_and(weight, (2 ** bits) - 1, out=weight) weight = weight.reshape(-1, group_size, weight.shape[2]) # Apply dequantization formula and reshape the final weight weight = (scales * (weight - zeros)) weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) # Return the transposed weight return weight.transpose(0, 1) def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, act_order: bool): orig_device = w.device size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], f"Unsupported groupsize = {group_size}" if group_size == -1: group_size = size_k assert group_size <= size_k max_q_val = 2**num_bits - 1 half_q_val = (max_q_val + 1) // 2 # Reshape to [groupsize, -1] if group_size < size_k: w = w.view((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) # Compute scale for each group s = torch.max(torch.abs(w), 0, keepdim=True)[0] s *= 2 / max_q_val # 2 => symmetric # Quantize q_w = torch.round(w / s).int() q_w += half_q_val q_w = torch.clamp(q_w, 0, max_q_val) # Compute ref (dequantized) w_ref = (q_w - half_q_val).half() * s # Restore original shapes if group_size < size_k: def reshape_w(w): w = w.reshape((group_size, -1, size_n)) w = w.permute(1, 0, 2) w = w.reshape((size_k, size_n)).contiguous() return w q_w = reshape_w(q_w) w_ref = reshape_w(w_ref) s = s.reshape((-1, size_n)).contiguous() # Apply act_order g_idx = torch.empty(0, dtype=torch.int, device=w.device) rand_perm = torch.empty(0, dtype=torch.int, device=w.device) if act_order: assert ( group_size < size_k ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) return ( w_ref.to(device=orig_device), q_w.to(device=orig_device), s.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), ) def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device sort_indices = torch.argsort(g_idx).to( dtype=torch.int32) # Sort based on g_idx g_idx = g_idx[sort_indices].contiguous() q_w = q_w[sort_indices, :].contiguous() return ( q_w.to(device=orig_device), g_idx.to(device=orig_device), sort_indices.to(device=orig_device), ) def gptq_pack( q_w: torch.Tensor, num_bits: int, size_k: int, size_n: int, ): assert q_w.shape == (size_k, size_n) pack_factor = get_pack_factor(num_bits) assert size_k % pack_factor == 0 orig_device = q_w.device q_w = q_w.cpu().numpy().astype(numpy.uint32) q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) for i in range(pack_factor): q_res |= q_w[i::pack_factor, :] << num_bits * i q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) return q_res def gptq_unpack( q_res: torch.Tensor, num_bits: int, size_k: int, size_n: int, ): pack_factor = 32 // num_bits assert size_k % pack_factor == 0 orig_device = q_res.device q_res = q_res.cpu().numpy() q_w = numpy.zeros((size_k, size_n), dtype=numpy.uint32) for i in range(pack_factor): q_w[i::pack_factor, :] = (q_res >> (num_bits * i)) & ((1 << num_bits) - 1) q_w = torch.from_numpy(q_w.astype(numpy.int32)).to(orig_device) return q_w ================================================ FILE: archive/csrc/ktransformers_ext/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.16) project(cpuinfer_ext VERSION 0.1.0) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp") add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}) set(CMAKE_BUILD_TYPE "Release") # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp") # set(CMAKE_BUILD_TYPE "Debug") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) include(CheckCXXCompilerFlag) set(CMAKE_POSITION_INDEPENDENT_CODE ON) option(LLAMA_NATIVE "llama: enable -march=native flag" ON) # instruction set specific if (LLAMA_NATIVE) set(INS_ENB OFF) else() set(INS_ENB ON) endif() option(LLAMA_AVX "llama: enable AVX" OFF) option(LLAMA_AVX2 "llama: enable AVX2" OFF) option(LLAMA_AVX512 "llama: enable AVX512" OFF) option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF) option(LLAMA_FMA "llama: enable FMA" OFF) # in MSVC F16C is implied with AVX2/AVX512 if (NOT MSVC) option(LLAMA_F16C "llama: enable F16C" OFF) endif() option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON) option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF) option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF) option(KTRANSFORMERS_USE_XPU "ktransformers: use XPU" OFF) option(KTRANSFORMERS_USE_NPU "ktransformers: use NPU" OFF) if(KTRANSFORMERS_USE_NPU) add_definitions(-DKTRANSFORMERS_USE_NPU=1) endif() # Architecture specific # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") if (MSVC) string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") else () set(CMAKE_GENERATOR_PLATFORM_LWR "") endif () if (NOT MSVC) if (LLAMA_STATIC) add_link_options(-static) if (MINGW) add_link_options(-static-libgcc -static-libstdc++) endif() endif() if (LLAMA_GPROF) add_compile_options(-pg) endif() endif() set(ARCH_FLAGS "") if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) message(STATUS "ARM detected") if (MSVC) add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead add_compile_definitions(__ARM_NEON) add_compile_definitions(__ARM_FEATURE_FMA) set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS}) string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2") check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD) if (GGML_COMPILER_SUPPORT_DOTPROD) add_compile_definitions(__ARM_FEATURE_DOTPROD) endif () check_cxx_source_compiles("#include \nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC) if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC) add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) endif () set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV}) else() if(KTRANSFORMERS_USE_NPU) list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma) endif() check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") list(APPEND ARCH_FLAGS -mfp16-format=ieee) endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access) endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") # Android armeabi-v7a list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations) else() # Raspberry Pi 2 list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) endif() endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Android arm64-v8a # Raspberry Pi 3, 4, Zero 2 (32-bit) list(APPEND ARCH_FLAGS -mno-unaligned-access) endif() endif() elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$")) message(STATUS "x86 detected") if(NOT KTRANSFORMERS_USE_NPU) set(HOST_IS_X86 TRUE) set(HAS_AVX512 TRUE) set(__HAS_AMX__ TRUE) add_compile_definitions(__x86_64__) # check AVX512 execute_process( COMMAND lscpu OUTPUT_VARIABLE LSCPU_OUTPUT OUTPUT_STRIP_TRAILING_WHITESPACE ) # message(STATUS "LSCPU_OUTPUT: ${LSCPU_OUTPUT}") string(FIND "${LSCPU_OUTPUT}" "avx512" COMPILER_SUPPORTS_AVX512F) if (COMPILER_SUPPORTS_AVX512F GREATER -1) message(STATUS "Compiler and CPU support AVX512F (tested by compiling a program)") add_compile_definitions(__HAS_AVX512F__) else() message(STATUS "Compiler and/or CPU do NOT support AVX512F") set(HAS_AVX512 False) endif() # check AMX string(FIND "${LSCPU_OUTPUT}" "amx" COMPILER_SUPPORTS_AMX) if(COMPILER_SUPPORTS_AMX GREATER -1) message(STATUS "Compiler supports AMX") add_compile_definitions(__HAS_AMX__) else() message(STATUS "Compiler does NOT support AMX") endif() endif() if (MSVC) # instruction set detection for MSVC only if (LLAMA_NATIVE) include(cmake/FindSIMD.cmake) endif () if (LLAMA_AVX512) list(APPEND ARCH_FLAGS /arch:AVX512) # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the # macros corresponding to the extensions. # Do it manually. if (LLAMA_AVX512_VBMI) add_compile_definitions($<$:__AVX512VBMI__>) add_compile_definitions($<$:__AVX512VBMI__>) endif() if (LLAMA_AVX512_VNNI) add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) endif() if (LLAMA_AVX512_FANCY_SIMD) add_compile_definitions($<$:__AVX512VL__>) add_compile_definitions($<$:__AVX512VL__>) add_compile_definitions($<$:__AVX512BW__>) add_compile_definitions($<$:__AVX512BW__>) add_compile_definitions($<$:__AVX512DQ__>) add_compile_definitions($<$:__AVX512DQ__>) add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) endif() if (LLAMA_AVX512_BF16) add_compile_definitions($<$:__AVX512BF16__>) add_compile_definitions($<$:__AVX512BF16__>) endif() elseif (LLAMA_AVX2) list(APPEND ARCH_FLAGS /arch:AVX2) elseif (LLAMA_AVX) list(APPEND ARCH_FLAGS /arch:AVX) endif() else() if (LLAMA_NATIVE) list(APPEND ARCH_FLAGS -mfma -mavx -mavx2) list(APPEND ARCH_FLAGS -march=native) endif() if (LLAMA_F16C) list(APPEND ARCH_FLAGS -mf16c) endif() if (LLAMA_FMA) list(APPEND ARCH_FLAGS -mfma) endif() if (LLAMA_AVX) list(APPEND ARCH_FLAGS -mavx) endif() if (LLAMA_AVX2) list(APPEND ARCH_FLAGS -mavx2) endif() if (LLAMA_AVX512) list(APPEND ARCH_FLAGS -mavx512f) list(APPEND ARCH_FLAGS -mavx512bw) endif() if (LLAMA_AVX512_VBMI) list(APPEND ARCH_FLAGS -mavx512vbmi) endif() if (LLAMA_AVX512_VNNI) list(APPEND ARCH_FLAGS -mavx512vnni) endif() if (LLAMA_AVX512_FANCY_SIMD) message(STATUS "AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled") list(APPEND ARCH_FLAGS -mavx512vl) list(APPEND ARCH_FLAGS -mavx512bw) list(APPEND ARCH_FLAGS -mavx512dq) list(APPEND ARCH_FLAGS -mavx512vnni) list(APPEND ARCH_FLAGS -mavx512vpopcntdq) endif() if (LLAMA_AVX512_BF16) list(APPEND ARCH_FLAGS -mavx512bf16) endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") list(APPEND ARCH_FLAGS -mcpu=powerpc64le) else() list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) endif() else() message(STATUS "Unknown architecture") endif() # message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}") # find_package(FindCUDAToolkit REQUIRED) # if(CUDAToolkit_FOUND) # message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}") # else() # message(STATUS "Can't found CUDA lib") # endif() if (NOT EXISTS $ENV{ROCM_PATH}) if (NOT EXISTS /opt/rocm) set(ROCM_PATH /usr) else() set(ROCM_PATH /opt/rocm) endif() else() set(ROCM_PATH $ENV{ROCM_PATH}) endif() list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") if (NOT EXISTS $ENV{MUSA_PATH}) if (NOT EXISTS /opt/musa) set(MUSA_PATH /usr/local/musa) else() set(MUSA_PATH /opt/musa) endif() else() set(MUSA_PATH $ENV{MUSA_PATH}) endif() list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") add_compile_options("$<$:${ARCH_FLAGS}>") add_compile_options("$<$:${ARCH_FLAGS}>") add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) if (WIN32) include_directories("$ENV{CUDA_PATH}/include") add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) elseif (UNIX) if (KTRANSFORMERS_USE_ROCM) find_package(HIP REQUIRED) if(HIP_FOUND) include_directories("${HIP_INCLUDE_DIRS}") add_compile_definitions(KTRANSFORMERS_USE_ROCM=1) endif() elseif (KTRANSFORMERS_USE_MUSA) if (NOT EXISTS $ENV{MUSA_PATH}) if (NOT EXISTS /opt/musa) set(MUSA_PATH /usr/local/musa) else() set(MUSA_PATH /opt/musa) endif() else() set(MUSA_PATH $ENV{MUSA_PATH}) endif() list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") find_package(MUSAToolkit) if (MUSAToolkit_FOUND) message(STATUS "MUSA Toolkit found") add_compile_definitions(KTRANSFORMERS_USE_MUSA=1) endif() elseif (KTRANSFORMERS_USE_XPU) add_compile_definitions(KTRANSFORMERS_USE_XPU=1) elseif (KTRANSFORMERS_USE_CUDA) find_package(CUDA REQUIRED) include_directories("${CUDA_INCLUDE_DIRS}") include(CheckLanguage) check_language(CUDA) if(CMAKE_CUDA_COMPILER) message(STATUS "CUDA detected") find_package(CUDAToolkit REQUIRED) include_directories(${CUDAToolkit_INCLUDE_DIRS}) endif() message(STATUS "enabling CUDA") enable_language(CUDA) add_compile_definitions(KTRANSFORMERS_USE_CUDA=1) endif() endif() aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3) # aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4) file(GLOB LLAMAFILE_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/*.cpp") list(REMOVE_ITEM LLAMAFILE_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_arm.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_x86.cpp" ) set(SOURCE_DIR4 ${LLAMAFILE_SOURCES}) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5) if (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6) endif() set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6}) file(GLOB_RECURSE FMT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.h") add_custom_target( format COMMAND clang-format -i -style=file ${FMT_SOURCES} COMMENT "Running clang-format on all source files" ) add_library(llamafile STATIC ${SOURCE_DIR4}) message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}") pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES}) target_link_libraries(${PROJECT_NAME} PRIVATE llama) if(WIN32) target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart elseif(UNIX) if (KTRANSFORMERS_USE_ROCM) add_compile_definitions(USE_HIP=1) target_link_libraries(${PROJECT_NAME} PRIVATE "${ROCM_PATH}/lib/libamdhip64.so") message(STATUS "Building for HIP") elseif(KTRANSFORMERS_USE_MUSA) target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart) elseif(KTRANSFORMERS_USE_XPU) elseif(KTRANSFORMERS_USE_CUDA AND NOT KTRANSFORMERS_USE_MUSA) target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so") endif() endif() # Define the USE_NUMA option option(USE_NUMA "Disable NUMA support" OFF) # Check if the USE_NUMA environment variable is set if(DEFINED ENV{USE_NUMA}) set(USE_NUMA ON) endif() if(USE_NUMA) message(STATUS "NUMA support is enabled") else() message(STATUS "NUMA support is disabled") endif() find_library(NUMA_LIBRARY NAMES numa) if(NUMA_LIBRARY AND USE_NUMA) message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support") target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY}) target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA) else() if(USE_NUMA) message(FATAL_ERROR "NUMA library not found - maybe sudo apt install libnuma-dev") else() message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support") endif() endif() ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_attention.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : Jianwei Dong Date : 2024-08-28 10:32:05 Version : 1.0.0 LastEditors : Jianwei Dong LastEditTime : 2024-08-28 10:32:05 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import os, sys import time sys.path.append(os.path.dirname(__file__) + "/../build") import cpuinfer_ext import torch layer_num = 10 kv_head_num = 8 q_head_num = 32 head_dim = 128 block_len = 128 anchor_num = 1 anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC kv_type = cpuinfer_ext.kvcache.ggml_type.FP16 retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER layer_step: int = 1 token_step: int = 1 layer_offset: int = 0 max_thread_num: int = 64 max_batch_size: int = 1 max_block_num: int = 1024 CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num) warm_up_iter = 1000 test_iter = 10000 def bench_linear(cache_seqlen: int): with torch.inference_mode(mode=True): cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu") seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu") config = cpuinfer_ext.kvcache.KVCacheConfig( layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num, anchor_type, kv_type, retrieval_type, layer_step, token_step, layer_offset, max_block_num, max_batch_size, max_thread_num, ) local_kvcache = cpuinfer_ext.kvcache.KVCache(config) block_table = ( torch.arange(max_block_num, dtype=torch.int32, device="cpu") .contiguous() .view(1, -1) ) for layer_idx in range(layer_num): k_cache = torch.randn( (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu", ).contiguous() v_cache = torch.randn( (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu", ).contiguous() CPUInfer.submit( local_kvcache.update_kvcache_fp16( k_cache.data_ptr(), v_cache.data_ptr(), layer_idx, block_table.data_ptr(), 1, max_block_num, seqlens_zero.data_ptr(), cache_seqlen, ) ) CPUInfer.sync() input = torch.randn( (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() output = torch.empty( (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() # attn_lse: (bsz, q_len, q_head_num) attn_lse = torch.empty( (1, 1, q_head_num), dtype=torch.float32, device="cpu" ).contiguous() input = input / 100 # warm up for i in range(warm_up_iter): CPUInfer.submit( local_kvcache.attn( input.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), i % layer_num, 0, 1, 1, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), -1, -1, -1, ) ) CPUInfer.sync() # test start = time.perf_counter() for i in range(test_iter): CPUInfer.submit( local_kvcache.attn( input.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), i % layer_num, 0, 1, 1, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), -1, -1, -1, ) ) CPUInfer.sync() end = time.perf_counter() total_time = end - start print("cache sequence length: ", cache_seqlen) print("Time(s): ", total_time) print("Iteration: ", test_iter) print("Time(us) per iteration: ", total_time / test_iter * 1000000) print( "Bandwidth: ", cache_seqlen * kv_head_num * head_dim * 2 * 2 * test_iter / total_time / 1000 / 1000 / 1000, "GB/s", ) print("") bench_linear(1024) bench_linear(4096) bench_linear(16384) bench_linear(32768) bench_linear(65536) ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_attention_torch.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : Jianwei Dong Date : 2024-08-28 10:32:05 Version : 1.0.0 LastEditors : Jianwei Dong LastEditTime : 2024-08-28 10:32:05 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import os, sys import time sys.path.append(os.path.dirname(__file__) + "/../build") import cpuinfer_ext import torch layer_num = 10 kv_head_num = 8 q_head_num = 32 head_dim = 128 block_len = 128 anchor_num = 1 warm_up_iter = 1000 test_iter = 10000 def bench_linear(cache_seqlen: int, device): with torch.inference_mode(mode=True): kvcaches = [] for layer_idx in range(layer_num): k_cache = torch.randn( (1, 32, cache_seqlen, head_dim), dtype=torch.float16, device=device, ).contiguous() v_cache = torch.randn( (1, 32, cache_seqlen, head_dim), dtype=torch.float16, device=device, ).contiguous() kvcaches.append((k_cache, v_cache)) input = torch.randn( (1, q_head_num, 1, head_dim), dtype=torch.float16, device=device ).contiguous() input = input / 100 # warm up for i in range(warm_up_iter): k_cache = kvcaches[i % layer_num][0] v_cache = kvcaches[i % layer_num][1] torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache) # test start = time.perf_counter() for i in range(test_iter): k_cache = kvcaches[i % layer_num][0] v_cache = kvcaches[i % layer_num][1] torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache) end = time.perf_counter() total_time = end - start print("cache sequence length: ", cache_seqlen) print("Time(s): ", total_time) print("Iteration: ", test_iter) print("Time(us) per iteration: ", total_time / test_iter * 1000000) print( "Bandwidth: ", cache_seqlen * q_head_num * head_dim * 2 * 2 * test_iter / total_time / 1000 / 1000 / 1000, "GB/s", ) print("") bench_linear(1024, "cpu") bench_linear(4096, "cpu") bench_linear(1024, "cuda") bench_linear(4096, "cuda") bench_linear(16384, "cuda") bench_linear(32768, "cuda") bench_linear(65536, "cuda") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_linear.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:31:59 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:35:35 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch input_size = 16384 output_size = 5120 stride = 16 group_max_len = 1024 layer_num = 10 qlen = 1 CPUInfer = cpuinfer_ext.CPUInfer(64) warm_up_iter = 1000 test_iter = 10000 def bench_linear(quant_mode: str): with torch.inference_mode(mode=True): hidden_type = 30 # ggml_type::GGML_TYPE_BF16 if quant_mode == "fp32": proj_type = 0 # ggml_type::GGML_TYPE_F32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": proj_type = 1 # ggml_type::GGML_TYPE_F16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": proj_type = 30 # ggml_type::GGML_TYPE_BF16 bytes_per_elem = 2.000000 elif quant_mode == "q8_0": proj_type = 8 # ggml_type::GGML_TYPE_Q8_0 bytes_per_elem = 1.062500 elif quant_mode == "q6_k": proj_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.820312 elif quant_mode == "q5_k_m": proj_type = 13 # ggml_type::GGML_TYPE_Q5_K bytes_per_elem = 0.687500 elif quant_mode == "q4_k_m": proj_type = 12 # ggml_type::GGML_TYPE_Q4_K bytes_per_elem = 0.562500 elif quant_mode == "q3_k_m": proj_type = 11 # ggml_type::GGML_TYPE_Q3_K bytes_per_elem = 0.429688 elif quant_mode == "q2_k": proj_type = 10 # ggml_type::GGML_TYPE_Q2_K bytes_per_elem = 0.328125 elif quant_mode == "iq3_xs": proj_type = 21 # ggml_type::GGML_TYPE_IQ3_S bytes_per_elem = 0.429688 elif quant_mode == "iq2_xxs": proj_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS bytes_per_elem = 0.257812 else: assert(False) linears = [] projs = [] for _ in range(layer_num): proj = torch.randn((output_size, input_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type) linear = cpuinfer_ext.linear.Linear(config) projs.append(proj) linears.append(linear) input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() output = torch.empty((layer_num, qlen, output_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): CPUInfer.submit( linears[i % layer_num].forward( qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() # test start = time.perf_counter() for i in range(test_iter): CPUInfer.submit( linears[i % layer_num].forward( qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_linear("fp32") bench_linear("fp16") bench_linear("bf16") bench_linear("q8_0") bench_linear("q6_k") bench_linear("q5_k_m") bench_linear("q4_k_m") bench_linear("q3_k_m") bench_linear("q2_k") # Not supported on __x86_64__ # bench_linear("iq3_xs") # bench_linear("iq2_xxs") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_linear_torch.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:31:59 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-07-25 10:32:48 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time import torch import torch.nn.quantized as nnq scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset input_size = 16384 output_size = 5120 layer_num = 10 qlen = 1 warm_up_iter = 1000 test_iter = 10000 def bench_linear(quant_mode: str): with torch.inference_mode(mode=True): if quant_mode == "fp32": proj_type = torch.float32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": proj_type = torch.float16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": proj_type = torch.bfloat16 bytes_per_elem = 2.000000 elif quant_mode == "qint8": proj_type = torch.qint8 bytes_per_elem = 1.000000 else: assert(False) projs = [] for _ in range(layer_num): proj = torch.randn((output_size, input_size), dtype = torch.float32, device = "cuda").to("cpu").contiguous() if quant_mode == "qint8": proj_q = torch.quantize_per_tensor(proj, scale, zero_point, torch.qint8) quantized_layer = nnq.Linear(input_size, output_size) quantized_layer.set_weight_bias(proj_q, None) projs.append(quantized_layer) else: projs.append(proj.to(proj_type)) input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): if isinstance(projs[i % layer_num], nnq.Linear): input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8) t_output = projs[i % layer_num](input_q) else: t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t()) # test start = time.perf_counter() for i in range(test_iter): if isinstance(projs[i % layer_num], nnq.Linear): input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8) t_output = projs[i % layer_num](input_q) else: t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t()) end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_linear("fp32") bench_linear("fp16") bench_linear("bf16") bench_linear("qint8") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_mlp.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-16 10:43:18 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:36:04 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch hidden_size = 5120 intermediate_size = 3072 stride = 16 group_max_len = 1024 layer_num = 10 qlen = 1 CPUInfer = cpuinfer_ext.CPUInfer(64) warm_up_iter = 1000 test_iter = 10000 def bench_mlp(quant_mode: str): with torch.inference_mode(mode=True): hidden_type = 30 # ggml_type::GGML_TYPE_BF16 if quant_mode == "fp32": gate_type = 0 # ggml_type::GGML_TYPE_F32 up_type = 0 # ggml_type::GGML_TYPE_F32 down_type = 0 # ggml_type::GGML_TYPE_F32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": gate_type = 1 # ggml_type::GGML_TYPE_F16 up_type = 1 # ggml_type::GGML_TYPE_F16 down_type = 1 # ggml_type::GGML_TYPE_F16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": gate_type = 30 # ggml_type::GGML_TYPE_BF16 up_type = 30 # ggml_type::GGML_TYPE_BF16 down_type = 30 # ggml_type::GGML_TYPE_BF16 bytes_per_elem = 2.000000 elif quant_mode == "q8_0": gate_type = 8 # ggml_type::GGML_TYPE_Q8_0 up_type = 8 # ggml_type::GGML_TYPE_Q8_0 down_type = 8 # ggml_type::GGML_TYPE_Q8_0 bytes_per_elem = 1.062500 elif quant_mode == "q6_k": gate_type = 14 # ggml_type::GGML_TYPE_Q6_K up_type = 14 # ggml_type::GGML_TYPE_Q6_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.820312 elif quant_mode == "q5_k_m": gate_type = 13 # ggml_type::GGML_TYPE_Q5_K up_type = 13 # ggml_type::GGML_TYPE_Q5_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.731771 elif quant_mode == "q4_k_m": gate_type = 12 # ggml_type::GGML_TYPE_Q4_K up_type = 12 # ggml_type::GGML_TYPE_Q4_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.648437 elif quant_mode == "q3_k_m": gate_type = 11 # ggml_type::GGML_TYPE_Q3_K up_type = 11 # ggml_type::GGML_TYPE_Q3_K down_type = 13 # ggml_type::GGML_TYPE_Q5_K bytes_per_elem = 0.515625 elif quant_mode == "q2_k": gate_type = 10 # ggml_type::GGML_TYPE_Q2_K up_type = 10 # ggml_type::GGML_TYPE_Q2_K down_type = 11 # ggml_type::GGML_TYPE_Q3_K bytes_per_elem = 0.328125 elif quant_mode == "iq3_xs": gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S up_type = 21 # ggml_type::GGML_TYPE_IQ3_S down_type = 21 # ggml_type::GGML_TYPE_IQ3_S bytes_per_elem = 0.429688 elif quant_mode == "iq2_xxs": gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS bytes_per_elem = 0.257812 else: assert(False) mlps = [] gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type) mlp = cpuinfer_ext.mlp.MLP(config) gate_projs.append(gate_proj) up_projs.append(up_proj) down_projs.append(down_proj) mlps.append(mlp) input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): CPUInfer.submit( mlps[i % layer_num].forward( qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() # test start = time.perf_counter() for i in range(test_iter): CPUInfer.submit( mlps[i % layer_num].forward( qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_mlp("fp32") bench_mlp("fp16") bench_mlp("bf16") bench_mlp("q8_0") bench_mlp("q6_k") bench_mlp("q5_k_m") bench_mlp("q4_k_m") bench_mlp("q3_k_m") bench_mlp("q2_k") # Not supported on __x86_64__ # bench_linear("iq3_xs") # bench_linear("iq2_xxs") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_mlp_torch.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-16 10:43:18 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-07-25 10:32:53 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time import torch import torch.nn.quantized as nnq scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset hidden_size = 5120 intermediate_size = 3072 layer_num = 10 qlen = 1 warm_up_iter = 1000 test_iter = 10000 def act_fn(x): return x / (1.0 + torch.exp(-x)) def mlp_torch(input, gate_proj, up_proj, down_proj): if isinstance(gate_proj, nnq.Linear): input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8) gate_buf = gate_proj(input_q) up_buf = up_proj(input_q) gate_buf = gate_buf.dequantize() up_buf = up_buf.dequantize() intermediate = act_fn(gate_buf) * up_buf intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8) expert_output = down_proj(intermediate_q) ret = expert_output.dequantize() else: gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t()) up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t()) intermediate = act_fn(gate_buf) * up_buf ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t()) return ret def bench_mlp(quant_mode: str): with torch.inference_mode(mode=True): if quant_mode == "fp32": proj_type = torch.float32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": proj_type = torch.float16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": proj_type = torch.bfloat16 bytes_per_elem = 2.000000 elif quant_mode == "qint8": proj_type = torch.qint8 bytes_per_elem = 1.000000 else: assert(False) gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() if quant_mode == "qint8": gate_proj_q = torch.quantize_per_tensor(gate_proj, scale, zero_point, torch.qint8) quantized_gate = nnq.Linear(hidden_size, intermediate_size) quantized_gate.set_weight_bias(gate_proj_q, None) up_proj_q = torch.quantize_per_tensor(up_proj, scale, zero_point, torch.qint8) quantized_up = nnq.Linear(hidden_size, intermediate_size) quantized_up.set_weight_bias(up_proj_q, None) down_proj_q = torch.quantize_per_tensor(down_proj, scale, zero_point, torch.qint8) quantized_down = nnq.Linear(intermediate_size, hidden_size) quantized_down.set_weight_bias(down_proj_q, None) gate_projs.append(quantized_gate) up_projs.append(quantized_up) down_projs.append(quantized_down) else: gate_projs.append(gate_proj.to(proj_type)) up_projs.append(up_proj.to(proj_type)) down_projs.append(down_proj.to(proj_type)) input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num]) # test start = time.perf_counter() for i in range(test_iter): mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num]) end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_mlp("fp32") bench_mlp("fp16") bench_mlp("bf16") bench_mlp("qint8") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_moe.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:41:28 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch expert_num = 160 hidden_size = 5120 intermediate_size = 1536 stride = 16 group_min_len = 10 group_max_len = 1024 n_routed_experts = 6 layer_num = 10 qlen = 1 CPUInfer = cpuinfer_ext.CPUInfer(64) warm_up_iter = 1000 test_iter = 10000 def bench_moe(quant_mode: str): with torch.inference_mode(mode=True): hidden_type = 30 # ggml_type::GGML_TYPE_BF16 if quant_mode == "fp32": gate_type = 0 # ggml_type::GGML_TYPE_F32 up_type = 0 # ggml_type::GGML_TYPE_F32 down_type = 0 # ggml_type::GGML_TYPE_F32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": gate_type = 1 # ggml_type::GGML_TYPE_F16 up_type = 1 # ggml_type::GGML_TYPE_F16 down_type = 1 # ggml_type::GGML_TYPE_F16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": gate_type = 30 # ggml_type::GGML_TYPE_BF16 up_type = 30 # ggml_type::GGML_TYPE_BF16 down_type = 30 # ggml_type::GGML_TYPE_BF16 bytes_per_elem = 2.000000 elif quant_mode == "q8_0": gate_type = 8 # ggml_type::GGML_TYPE_Q8_0 up_type = 8 # ggml_type::GGML_TYPE_Q8_0 down_type = 8 # ggml_type::GGML_TYPE_Q8_0 bytes_per_elem = 1.062500 elif quant_mode == "q6_k": gate_type = 14 # ggml_type::GGML_TYPE_Q6_K up_type = 14 # ggml_type::GGML_TYPE_Q6_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.820312 elif quant_mode == "q5_k_m": gate_type = 13 # ggml_type::GGML_TYPE_Q5_K up_type = 13 # ggml_type::GGML_TYPE_Q5_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.731771 elif quant_mode == "q4_k_m": gate_type = 12 # ggml_type::GGML_TYPE_Q4_K up_type = 12 # ggml_type::GGML_TYPE_Q4_K down_type = 14 # ggml_type::GGML_TYPE_Q6_K bytes_per_elem = 0.648437 elif quant_mode == "q3_k_m": gate_type = 11 # ggml_type::GGML_TYPE_Q3_K up_type = 11 # ggml_type::GGML_TYPE_Q3_K down_type = 13 # ggml_type::GGML_TYPE_Q5_K bytes_per_elem = 0.515625 elif quant_mode == "q2_k": gate_type = 10 # ggml_type::GGML_TYPE_Q2_K up_type = 10 # ggml_type::GGML_TYPE_Q2_K down_type = 11 # ggml_type::GGML_TYPE_Q3_K bytes_per_elem = 0.328125 elif quant_mode == "iq3_xs": gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S up_type = 21 # ggml_type::GGML_TYPE_IQ3_S down_type = 21 # ggml_type::GGML_TYPE_IQ3_S bytes_per_elem = 0.429688 elif quant_mode == "iq2_xxs": gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS bytes_per_elem = 0.257812 else: assert(False) moes = [] gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type) moe = cpuinfer_ext.moe.MOE(config) gate_projs.append(gate_proj) up_projs.append(up_proj) down_projs.append(down_proj) moes.append(moe) expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous() weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous() input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): CPUInfer.submit( moes[i % layer_num].forward( qlen, n_routed_experts, expert_ids[i % layer_num].data_ptr(), weights[i % layer_num].data_ptr(), input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() # test start = time.perf_counter() for i in range(test_iter): CPUInfer.submit( moes[i % layer_num].forward( qlen, n_routed_experts, expert_ids[i % layer_num].data_ptr(), weights[i % layer_num].data_ptr(), input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr() ) ) CPUInfer.sync() end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_moe("fp32") bench_moe("fp16") bench_moe("bf16") bench_moe("q8_0") bench_moe("q6_k") bench_moe("q5_k_m") bench_moe("q4_k_m") bench_moe("q3_k_m") bench_moe("q2_k") # Not supported on __x86_64__ # bench_linear("iq3_xs") # bench_linear("iq2_xxs") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_moe_amx.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2025-04-25 18:28:12 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2025-04-25 18:28:12 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch expert_num = 8 hidden_size = 7168 intermediate_size = 2048 max_len = 25600 n_routed_experts = 8 layer_num = 10 qlen = 1024 CPUInfer = cpuinfer_ext.CPUInfer(65) warm_up_iter = 100 test_iter = 100 def bench_moe(quant_mode: str): with torch.inference_mode(mode=True): if quant_mode == "bf16": bytes_per_elem = 2.000000 elif quant_mode == "int8": bytes_per_elem = 1.000000 else: assert(False) moes = [] gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr()) if quant_mode == "bf16": moe = cpuinfer_ext.moe.AMXBF16_MOE(config) CPUInfer.submit(moe.load_weights()) CPUInfer.sync() elif quant_mode == "int8": moe = cpuinfer_ext.moe.AMXInt8_MOE(config) CPUInfer.submit(moe.load_weights()) CPUInfer.sync() gate_projs.append(gate_proj) up_projs.append(up_proj) down_projs.append(down_proj) moes.append(moe) expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous() weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous() input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() qlen_tensor = torch.tensor([qlen], dtype=torch.int32) # warm up for i in range(warm_up_iter): CPUInfer.submit( moes[i % layer_num].forward( qlen, n_routed_experts, expert_ids[i % layer_num].data_ptr(), weights[i % layer_num].data_ptr(), input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr(), qlen_tensor.data_ptr() ) ) CPUInfer.sync() # test start = time.perf_counter() for i in range(test_iter): CPUInfer.submit( moes[i % layer_num].forward( qlen, n_routed_experts, expert_ids[i % layer_num].data_ptr(), weights[i % layer_num].data_ptr(), input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr(), qlen_tensor.data_ptr() ) ) CPUInfer.sync() end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS') print('') bench_moe("bf16") bench_moe("int8") ================================================ FILE: archive/csrc/ktransformers_ext/bench/bench_moe_torch.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-07-25 10:32:57 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time import torch import torch.nn.quantized as nnq scale, zero_point = 0.1, 0 # Adjust scale and zero_point based on your dataset expert_num = 160 hidden_size = 5120 intermediate_size = 1536 n_routed_experts = 6 layer_num = 10 qlen = 1 warm_up_iter = 1000 test_iter = 10000 def act_fn(x): return x / (1.0 + torch.exp(-x)) def mlp_torch(input, gate_proj, up_proj, down_proj): if isinstance(gate_proj, nnq.Linear): input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8) gate_buf = gate_proj(input_q) up_buf = up_proj(input_q) gate_buf = gate_buf.dequantize() up_buf = up_buf.dequantize() intermediate = act_fn(gate_buf) * up_buf intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8) expert_output = down_proj(intermediate_q) ret = expert_output.dequantize() else: gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t()) up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t()) intermediate = act_fn(gate_buf) * up_buf ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t()) return ret def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) cnts.scatter_(1, expert_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = expert_ids.view(-1).argsort() sorted_tokens = input[idxs // expert_ids.shape[1]] outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs t_output = ( new_x.view(*expert_ids.shape, -1) .type(weights.dtype) .mul_(weights.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return t_output def bench_moe(quant_mode: str): with torch.inference_mode(mode=True): if quant_mode == "fp32": proj_type = torch.float32 bytes_per_elem = 4.000000 elif quant_mode == "fp16": proj_type = torch.float16 bytes_per_elem = 2.000000 elif quant_mode == "bf16": proj_type = torch.bfloat16 bytes_per_elem = 2.000000 elif quant_mode == "qint8": proj_type = torch.qint8 bytes_per_elem = 1.000000 else: assert(False) gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = "cuda").to("cpu").contiguous() if quant_mode == "qint8": quantized_gate_proj = [] quantized_up_proj = [] quantized_down_proj = [] for i in range(expert_num): gate_proj_q = torch.quantize_per_tensor(gate_proj[i], scale, zero_point, torch.qint8) quantized_gate = nnq.Linear(hidden_size, intermediate_size) quantized_gate.set_weight_bias(gate_proj_q, None) quantized_gate_proj.append(quantized_gate) up_proj_q = torch.quantize_per_tensor(up_proj[i], scale, zero_point, torch.qint8) quantized_up = nnq.Linear(hidden_size, intermediate_size) quantized_up.set_weight_bias(up_proj_q, None) quantized_up_proj.append(quantized_up) down_proj_q = torch.quantize_per_tensor(down_proj[i], scale, zero_point, torch.qint8) quantized_down = nnq.Linear(intermediate_size, hidden_size) quantized_down.set_weight_bias(down_proj_q, None) quantized_down_proj.append(quantized_down) gate_projs.append(quantized_gate_proj) up_projs.append(quantized_up_proj) down_projs.append(quantized_down_proj) else: gate_projs.append(gate_proj.to(proj_type)) up_projs.append(up_proj.to(proj_type)) down_projs.append(down_proj.to(proj_type)) expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = "cuda")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to("cpu").contiguous() weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = "cuda").to("cpu").contiguous() input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = "cuda").to("cpu").contiguous() # warm up for i in range(warm_up_iter): moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num]) # test start = time.perf_counter() for i in range(test_iter): moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num]) end = time.perf_counter() total_time = end - start print('Quant mode: ', quant_mode) print('Time(s): ', total_time) print('Iteration: ', test_iter) print('Time(us) per iteration: ', total_time / test_iter * 1000000) print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s') print('') bench_moe("fp32") bench_moe("fp16") bench_moe("bf16") bench_moe("qint8") ================================================ FILE: archive/csrc/ktransformers_ext/cmake/FindSIMD.cmake ================================================ include(CheckCSourceRuns) set(AVX_CODE " #include int main() { __m256 a; a = _mm256_set1_ps(0); return 0; } ") set(AVX512_CODE " #include int main() { __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); __m512i b = a; __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); return 0; } ") set(AVX2_CODE " #include int main() { __m256i a = {0}; a = _mm256_abs_epi16(a); __m256i x; _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code return 0; } ") set(FMA_CODE " #include int main() { __m256 acc = _mm256_setzero_ps(); const __m256 d = _mm256_setzero_ps(); const __m256 p = _mm256_setzero_ps(); acc = _mm256_fmadd_ps( d, p, acc ); return 0; } ") macro(check_sse type flags) set(__FLAG_I 1) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) foreach (__FLAG ${flags}) if (NOT ${type}_FOUND) set(CMAKE_REQUIRED_FLAGS ${__FLAG}) check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I}) if (HAS_${type}_${__FLAG_I}) set(${type}_FOUND TRUE CACHE BOOL "${type} support") set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags") endif() math(EXPR __FLAG_I "${__FLAG_I}+1") endif() endforeach() set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) if (NOT ${type}_FOUND) set(${type}_FOUND FALSE CACHE BOOL "${type} support") set(${type}_FLAGS "" CACHE STRING "${type} flags") endif() mark_as_advanced(${type}_FOUND ${type}_FLAGS) endmacro() # flags are for MSVC only! check_sse("AVX" " ;/arch:AVX") if (NOT ${AVX_FOUND}) set(LLAMA_AVX OFF) else() set(LLAMA_AVX ON) endif() check_sse("AVX2" " ;/arch:AVX2") check_sse("FMA" " ;/arch:AVX2") if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND})) set(LLAMA_AVX2 OFF) else() set(LLAMA_AVX2 ON) endif() check_sse("AVX512" " ;/arch:AVX512") if (NOT ${AVX512_FOUND}) set(LLAMA_AVX512 OFF) else() set(LLAMA_AVX512 ON) endif() ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/backend.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-22 02:03:05 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:33:34 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "backend.h" #ifdef USE_NUMA #include #include thread_local int Backend::numa_node = -1; #endif thread_local int Backend::thread_local_id = -1; Backend::Backend(int max_thread_num) { max_thread_num_ = max_thread_num; thread_state_.resize(max_thread_num_); for (int i = 0; i < max_thread_num_; i++) { thread_state_[i].curr = std::make_unique>(); thread_state_[i].status = std::make_unique>(ThreadStatus::WAITING); } workers_.resize(max_thread_num_); for (int i = 1; i < max_thread_num_; i++) { workers_[i] = std::thread(&Backend::worker_thread, this, i); } } Backend::~Backend() { for (int i = 0; i < max_thread_num_; i++) { thread_state_[i].status->store(ThreadStatus::EXIT, std::memory_order_release); } for (int i = 1; i < max_thread_num_; i++) { if (workers_[i].joinable()) { workers_[i].join(); } } } int Backend::get_thread_num() { return max_thread_num_; } void Backend::do_work_stealing_job(int task_num, std::function init_func, std::function compute_func, std::function finalize_func) { init_func_ = init_func; compute_func_ = compute_func; finalize_func_ = finalize_func; #ifdef USE_NUMA // numa node location will be calculated based on the number of threads thread_num_ = max_thread_num_; #else thread_num_ = std::min(max_thread_num_, task_num); #endif int base = task_num / thread_num_; int remain = task_num % thread_num_; thread_state_[0].end = base + (0 < remain); // 为主线程设置 thread_local_id thread_local_id = 0; for (int i = 1; i < thread_num_; i++) { thread_state_[i].curr->store(thread_state_[i - 1].end, std::memory_order_relaxed); thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain); thread_state_[i].status->store(ThreadStatus::WORKING, std::memory_order_release); } thread_state_[0].curr->store(0, std::memory_order_relaxed); thread_state_[0].status->store(ThreadStatus::WORKING, std::memory_order_release); process_tasks(0); for (int i = 1; i < thread_num_; i++) { while (thread_state_[i].status->load(std::memory_order_acquire) == ThreadStatus::WORKING) { } } } void Backend::process_tasks(int thread_id) { #ifdef USE_NUMA if(numa_node == -1){ numa_node = thread_id * numa_num_configured_nodes() / thread_num_; struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes()); numa_bitmask_setbit(mask, numa_node); numa_bind(mask); } #endif if (init_func_ != nullptr) { init_func_(thread_id); } while (true) { int task_id = thread_state_[thread_id].curr->fetch_add( 1, std::memory_order_acq_rel); if (task_id >= thread_state_[thread_id].end) { break; } compute_func_(task_id); } for (int t_offset = 1; t_offset < thread_num_; t_offset++) { int t_i = (thread_id + t_offset) % thread_num_; if (thread_state_[t_i].status->load(std::memory_order_acquire) != ThreadStatus::WORKING) { continue; } while (true) { int task_id = thread_state_[t_i].curr->fetch_add( 1, std::memory_order_acq_rel); if (task_id >= thread_state_[t_i].end) { break; } compute_func_(task_id); } } if (finalize_func_ != nullptr) { finalize_func_(thread_id); } thread_state_[thread_id].status->store(ThreadStatus::WAITING, std::memory_order_release); } void Backend::worker_thread(int thread_id) { auto start = std::chrono::steady_clock::now(); thread_local_id = thread_id; // 设置线程本地变量 while (true) { ThreadStatus status = thread_state_[thread_id].status->load(std::memory_order_acquire); if (status == ThreadStatus::WORKING) { process_tasks(thread_id); start = std::chrono::steady_clock::now(); } else if (status == ThreadStatus::WAITING) { auto now = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(now - start) .count(); if (duration > 50) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } else if (status == ThreadStatus::EXIT) { return; } } } ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/backend.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-22 02:03:05 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:33:38 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_BACKEND_H #define CPUINFER_BACKEND_H #include #include #include #include #include #include #include enum ThreadStatus { WORKING, WAITING, EXIT, }; struct ThreadState { std::unique_ptr> status; std::unique_ptr> curr; int end; }; class Backend { public: Backend(int); ~Backend(); int get_thread_num(); void do_work_stealing_job(int, std::function, std::function, std::function); #ifdef USE_NUMA static thread_local int numa_node; #endif static thread_local int thread_local_id; private: int thread_num_; int max_thread_num_; std::vector thread_state_; // [thread_num] std::function init_func_; std::function compute_func_; std::function finalize_func_; std::vector workers_; void process_tasks(int); void worker_thread(int); }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/cpuinfer.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-16 10:43:18 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-08-07 09:47:43 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_CPUINFER_H #define CPUINFER_CPUINFER_H #include #include #include #include #include #include #include #include #ifdef KTRANSFORMERS_USE_CUDA #include "vendors/cuda.h" #elif KTRANSFORMERS_USE_MUSA #include "vendors/musa.h" #elif KTRANSFORMERS_USE_ROCM #define __HIP_PLATFORM_AMD__ #include "vendors/hip.h" #endif #include "backend.h" #include "task_queue.h" #include "./vendors/vendor.h" #include "llama.cpp/ggml-impl.h" class CPUInfer { public: CPUInfer(int thread_num) { backend_ = new Backend(thread_num - 1); task_queue_ = new TaskQueue(); for (int i = 0; i < (1 << 16); ++i) { ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i); } } ~CPUInfer() { delete backend_; delete task_queue_; } template void enqueue(Func f, Obj* obj, Args... args) { task_queue_->enqueue([=]() { std::invoke(f, *obj, args..., backend_); }); } void submit(std::pair params) { void (*func)(void*) = (void (*)(void*))params.first; void* args = (void*)params.second; *((CPUInfer**)args) = this; func(args); } void sync() { task_queue_->sync(); } void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair params) { #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM) void (*func)(void*) = (void (*)(void*))params.first; void* args = (void*)params.second; *((CPUInfer**)args) = this; cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args); #else throw std::runtime_error("submit_with_cuda_stream is not supported on this platforma"); #endif } static void sync_(void* cpu_infer_ptr) { CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr; cpuinfer->sync(); } void sync_with_cuda_stream(intptr_t user_cuda_stream) { #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM) cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this); #else throw std::runtime_error("sync_with_cuda_stream is not supported on this platforma"); #endif } public: Backend* backend_; TaskQueue* task_queue_; }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-08-05 04:49:08 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-08-05 09:21:29 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "shared_mem_buffer.h" #include SharedMemBuffer::SharedMemBuffer() { buffer_ = nullptr; size_ = 0; } SharedMemBuffer::~SharedMemBuffer() { if (buffer_) { free(buffer_); } } void SharedMemBuffer::alloc(void* object, std::vector> requests) { uint64_t size = 0; for (auto& request : requests) { size += request.second; } if (size > size_) { if (buffer_) { free(buffer_); } buffer_ = std::aligned_alloc(64, size); size_ = size; for (auto& obj_requests : hist_requests_) { for (auto& requests : obj_requests.second) { arrange(requests); } } } arrange(requests); hist_requests_[object].push_back(requests); } void SharedMemBuffer::dealloc(void* object) { hist_requests_.erase(object); } void SharedMemBuffer::arrange(std::vector> requests) { uint64_t offset = 0; for (auto& request : requests) { *(request.first) = (uint8_t*)buffer_ + offset; offset += request.second; } } ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-08-05 04:49:08 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-08-05 06:36:41 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_SHAREDMEMBUFFER_H #define CPUINFER_SHAREDMEMBUFFER_H #include #include #include #include class SharedMemBuffer { public: SharedMemBuffer(); ~SharedMemBuffer(); void alloc(void* object, std::vector> requests); void dealloc(void* object); private: void* buffer_; uint64_t size_; std::map>>> hist_requests_; void arrange(std::vector> requests); }; static SharedMemBuffer shared_mem_buffer; #endif ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/task_queue.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-17 12:25:51 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-10-09 11:08:10 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "task_queue.h" TaskQueue::TaskQueue() { worker = std::thread(&TaskQueue::processTasks, this); sync_flag.store(true, std::memory_order_seq_cst); exit_flag.store(false, std::memory_order_seq_cst); } TaskQueue::~TaskQueue() { { mutex.lock(); exit_flag.store(true, std::memory_order_seq_cst); mutex.unlock(); } cv.notify_all(); if (worker.joinable()) { worker.join(); } } void TaskQueue::enqueue(std::function task) { { mutex.lock(); tasks.push(task); sync_flag.store(false, std::memory_order_seq_cst); mutex.unlock(); } cv.notify_one(); } void TaskQueue::sync() { while (!sync_flag.load(std::memory_order_seq_cst)) ; } void TaskQueue::processTasks() { while (true) { std::function task; { mutex.lock(); cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); }); if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) { return; } task = tasks.front(); tasks.pop(); mutex.unlock(); } task(); { mutex.lock(); if (tasks.empty()) { sync_flag.store(true, std::memory_order_seq_cst); } mutex.unlock(); } } } ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/task_queue.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-16 10:43:18 * @Version : 1.0.0 * @LastEditors : chenht * @LastEditTime : 2024-10-09 11:08:07 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_TASKQUEUE_H #define CPUINFER_TASKQUEUE_H #include #include #include #include #include #include #include #ifdef _WIN32 #include #endif class custom_mutex { private: #ifdef _WIN32 CRITICAL_SECTION cs; #else std::mutex mtx; #endif public: custom_mutex() { #ifdef _WIN32 InitializeCriticalSection(&cs); #else // No initialization required for std::mutex #endif } ~custom_mutex() { #ifdef _WIN32 DeleteCriticalSection(&cs); #endif } void lock() { #ifdef _WIN32 EnterCriticalSection(&cs); #else mtx.lock(); #endif } void unlock() { #ifdef _WIN32 LeaveCriticalSection(&cs); #else mtx.unlock(); #endif } #ifdef _WIN32 CRITICAL_SECTION* get_handle() { return &cs; } #else std::mutex* get_handle() { return &mtx; } #endif }; class custom_condition_variable { private: #ifdef _WIN32 CONDITION_VARIABLE cond_var; #else std::condition_variable cond_var; #endif public: custom_condition_variable() { #ifdef _WIN32 InitializeConditionVariable(&cond_var); #endif } template void wait(custom_mutex& mutex, Predicate pred) { #ifdef _WIN32 while (!pred()) { SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE); } #else std::unique_lock lock(*mutex.get_handle(), std::adopt_lock); cond_var.wait(lock, pred); lock.release(); #endif } void notify_one() { #ifdef _WIN32 WakeConditionVariable(&cond_var); #else cond_var.notify_one(); #endif } void notify_all() { #ifdef _WIN32 WakeAllConditionVariable(&cond_var); #else cond_var.notify_all(); #endif } }; class TaskQueue { public: TaskQueue(); ~TaskQueue(); void enqueue(std::function); void sync(); private: void processTasks(); std::queue> tasks; custom_mutex mutex; custom_condition_variable cv; std::thread worker; std::atomic sync_flag; std::atomic exit_flag; }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/vendors/README.md ================================================ ## TODO This directory can be removed after updating the version of `llama.cpp`. ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/vendors/cuda.h ================================================ #pragma once #include #include #include #include #include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F #define cublasComputeType_t cudaDataType_t #endif // CUDART_VERSION < 11020 ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/vendors/hip.h ================================================ #pragma once #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 #include #include #include #include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N #define CUBLAS_OP_T HIPBLAS_OP_T #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended #define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx #define cublasGemmBatchedEx hipblasGemmBatchedEx #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx #define cublasHandle_t hipblasHandle_t #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags #define cudaEventDisableTiming hipEventDisableTiming #define cudaEventRecord hipEventRecord #define cudaEventSynchronize hipEventSynchronize #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount #define cudaGetDeviceProperties hipGetDeviceProperties #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaHostRegister hipHostRegister #define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync #define cudaMemcpy2DAsync hipMemcpy2DAsync #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind #define cudaMemset hipMemset #define cudaMemsetAsync hipMemsetAsync #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cuDeviceGet hipDeviceGet #define CUdevice hipDevice_t #define CUdeviceptr hipDeviceptr_t #define cuMemUnmap hipMemUnmap #define CUmemAccessDesc hipMemAccessDesc #define cuMemAddressFree hipMemAddressFree #define cuMemRelease hipMemRelease #define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t #define cuMemCreate hipMemCreate #define cuMemAddressReserve hipMemAddressReserve #define cuMemMap hipMemMap #define cuMemSetAccess hipMemSetAccess #define cuMemGetAllocationGranularity hipMemGetAllocationGranularity #define CUmemAllocationProp hipMemAllocationProp #define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams #define cudaKernelNodeParams hipKernelNodeParams #define cudaGraphExecDestroy hipGraphExecDestroy #define cudaGraphLaunch hipGraphLaunch #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult #define cudaGraphNodeType hipGraphNodeType #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel #define cudaGraphInstantiate hipGraphInstantiate #define cudaStreamEndCapture hipStreamEndCapture #define cudaGraphDestroy hipGraphDestroy #define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams #define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction #define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams #define cudaGraphNodeGetType hipGraphNodeGetType #define cudaGraphGetNodes hipGraphGetNodes #define cudaGraphExecUpdate hipGraphExecUpdate #define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed #define cudaStreamBeginCapture hipStreamBeginCapture #define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define cudaHostFn_t hipHostFn_t #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED #define __CUDA_ARCH__ 1300 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) #define GCN #endif #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) #define CDNA #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3 #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) #define RDNA2 #endif #if defined(__gfx1010__) || defined(__gfx1012__) #define RDNA1 #endif #ifndef __has_builtin #define __has_builtin(x) 0 #endif typedef hip_bfloat16 nv_bfloat16; ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/vendors/musa.h ================================================ #pragma once #include #include #include #include #include #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT #define CUDA_R_16F MUSA_R_16F #define CUDA_R_32F MUSA_R_32F #define cublasComputeType_t cudaDataType_t #define cublasCreate mublasCreate #define cublasDestroy mublasDestroy #define cublasGemmEx mublasGemmEx #define cublasGemmBatchedEx mublasGemmBatchedEx #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx #define cublasHandle_t mublasHandle_t #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasStatus_to_string #define cudaDataType_t musaDataType_t #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags #define cudaEventDisableTiming musaEventDisableTiming #define cudaEventRecord musaEventRecord #define cudaEventSynchronize musaEventSynchronize #define cudaEvent_t musaEvent_t #define cudaEventDestroy musaEventDestroy #define cudaFree musaFree #define cudaFreeHost musaFreeHost #define cudaGetDevice musaGetDevice #define cudaGetDeviceCount musaGetDeviceCount #define cudaGetDeviceProperties musaGetDeviceProperties #define cudaGetErrorString musaGetErrorString #define cudaGetLastError musaGetLastError #define cudaHostRegister musaHostRegister #define cudaHostRegisterPortable musaHostRegisterPortable #define cudaHostRegisterReadOnly musaHostRegisterReadOnly #define cudaHostUnregister musaHostUnregister #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaMalloc musaMalloc #define cudaMallocHost musaMallocHost #define cudaMallocManaged musaMallocManaged #define cudaMemcpy musaMemcpy #define cudaMemcpyAsync musaMemcpyAsync #define cudaMemcpyPeerAsync musaMemcpyPeerAsync #define cudaMemcpy2DAsync musaMemcpy2DAsync #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost #define cudaMemcpyHostToDevice musaMemcpyHostToDevice #define cudaMemcpyKind musaMemcpyKind #define cudaMemset musaMemset #define cudaMemsetAsync musaMemsetAsync #define cudaMemGetInfo musaMemGetInfo #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize #define cudaSetDevice musaSetDevice #define cudaStreamCreateWithFlags musaStreamCreateWithFlags #define cudaStreamDestroy musaStreamDestroy #define cudaStreamFireAndForget musaStreamFireAndForget #define cudaStreamNonBlocking musaStreamNonBlocking #define cudaStreamPerThread musaStreamPerThread #define cudaStreamSynchronize musaStreamSynchronize #define cudaStreamWaitEvent musaStreamWaitEvent #define cudaStream_t musaStream_t #define cudaSuccess musaSuccess // Additional mappings for MUSA virtual memory pool #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE #define CUdevice MUdevice #define CUdeviceptr MUdeviceptr #define CUmemAccessDesc MUmemAccessDesc #define CUmemAllocationProp MUmemAllocationProp #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle #define cuDeviceGet muDeviceGet #define cuDeviceGetAttribute muDeviceGetAttribute #define cuMemAddressFree muMemAddressFree #define cuMemAddressReserve muMemAddressReserve #define cuMemCreate muMemCreate #define cuMemGetAllocationGranularity muMemGetAllocationGranularity #define cuMemMap muMemMap #define cuMemRelease muMemRelease #define cuMemSetAccess muMemSetAccess #define cuMemUnmap muMemUnmap #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize #define cudaFuncSetAttribute musaFuncSetAttribute #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms #define make_cudaExtent make_musaExtent #define make_cudaPitchedPtr make_musaPitchedPtr // Additional mappings for MUSA graphs #define CUDA_SUCCESS MUSA_SUCCESS #define CUresult MUresult #define cuGetErrorString muGetErrorString #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction #define cudaGraphDestroy musaGraphDestroy #define cudaGraphExecDestroy musaGraphExecDestroy #define cudaGraphExec_t musaGraphExec_t #define cudaGraphExecUpdate musaGraphExecUpdate #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult #define cudaGraphGetNodes musaGraphGetNodes #define cudaGraphInstantiate musaGraphInstantiate #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams #define cudaGraphLaunch musaGraphLaunch #define cudaGraphNodeGetType musaGraphNodeGetType #define cudaGraphNode_t musaGraphNode_t #define cudaGraphNodeType musaGraphNodeType #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel #define cudaGraph_t musaGraph_t #define cudaKernelNodeParams musaKernelNodeParams #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamEndCapture musaStreamEndCapture typedef mt_bfloat16 nv_bfloat16; ================================================ FILE: archive/csrc/ktransformers_ext/cpu_backend/vendors/vendor.h ================================================ #ifndef CPUINFER_VENDOR_VENDOR_H #define CPUINFER_VENDOR_VENDOR_H #ifdef USE_CUDA #include "cuda.h" #elif USE_HIP #define __HIP_PLATFORM_AMD__ #include "hip.h" #elif USE_MUSA #include "musa.h" #endif #endif // CPUINFER_VENDOR_VENDOR_H ================================================ FILE: archive/csrc/ktransformers_ext/cuda/binding.cpp ================================================ /** * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 * @Version : 0.2.2 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "custom_gguf/ops.h" #ifdef KTRANSFORMERS_USE_CUDA #include "gptq_marlin/ops.h" #endif // Python bindings #include #include #include #include #include // namespace py = pybind11; PYBIND11_MODULE(KTransformersOps, m) { m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q8_0 data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q6_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q5_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q4_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q3_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize q2_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); }, "Function to dequantize iq4_xs data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); #ifdef KTRANSFORMERS_USE_CUDA m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); #endif } ================================================ FILE: archive/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu ================================================ /* * @Description : * @Author : Azure-Tang, Boxin Zhang * @Date : 2024-07-25 13:38:30 * @Version : 0.2.2 * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c * Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. */ #include #include #include #include #include #include #include #include #ifdef __HIP_PLATFORM_AMD__ typedef __hip_bfloat16 nv_bfloat16; #endif __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; for (int i = 0; i < ele_per_blk; i++){ output_blk[i] = scale * cur_block[i]; } } } __global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; for (int i = 0; i < ele_per_blk; i++) { output_blk[i] = __float2half(scale * cur_block[i]); } } } __global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk); const int8_t* cur_block = data + block_id * blk_size; float scale = __half2float(*((half*)cur_block)); cur_block += 2; for (int i = 0; i < ele_per_blk; i++) { output_blk[i] = __float2bfloat16(scale * cur_block[i]); } } } // __device__ void get_scale_min_k4(int j, const uint8_t * __restrict__ q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) { __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) { if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; } else { *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } __global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); int is = 0; float dl, ml; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); uint8_t sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; scales = (uint8_t*)(data + block_id * blk_size + (is++)); sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; shift += 2; } q += 32; } } } __global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); int is = 0; float dl, ml; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); uint8_t sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); scales = (uint8_t*)(data + block_id * blk_size + (is++)); sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); shift += 2; } q += 32; } } } __global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 80))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 82))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16); int is = 0; float dl, ml; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++)); uint8_t sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml); scales = (uint8_t*)(data + block_id * blk_size + (is++)); sc = *scales; dl = d * (sc & 0xF); ml = min * (sc >> 4); for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml); shift += 2; } q += 32; } } } __global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; const uint32_t kmask2 = 0x0f0f0f0f; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); uint8_t m = 1; uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); for (int i = 0; i < 3; i++) { aux[i] = 0; for (int j = 0; j < 4; j++) { aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); } } uint32_t tmp = aux[2]; aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); int is = 0; float dl; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); } dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); } shift += 2; m <<= 1; } q += 32; } } } __global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; const uint32_t kmask2 = 0x0f0f0f0f; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); uint8_t m = 1; uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); for (int i = 0; i < 3; i++) { aux[i] = 0; for (int j = 0; j < 4; j++) { aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); } } uint32_t tmp = aux[2]; aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); int is = 0; float dl; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); } dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4))); } shift += 2; m <<= 1; } q += 32; } } } __global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t kmask1 = 0x03030303; const uint32_t kmask2 = 0x0f0f0f0f; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 108))); const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32); const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0); uint8_t m = 1; uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96); for (int i = 0; i < 3; i++) { aux[i] = 0; for (int j = 0; j < 4; j++) { aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8); } } uint32_t tmp = aux[2]; aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); int is = 0; float dl; for (int n = 0; n < 256; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4))); } dl = d_all * (scales[is++] - 32); for (int l = 0; l < 16; ++l) { *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4))); } shift += 2; m <<= 1; } q += 32; } } } __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = d1 * (q[l] & 0xF) - m1; for (int l = 0; l < 32; ++l) *output_blk++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } } } __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1); for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l] >> 4) - m2); q += 32; is += 2; } } } __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * 144 + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * 144 + 2))); int is = 0; uint8_t sc, m; for (int j = 0; j < ele_per_blk; j += 64) { uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1); for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l] >> 4) - m2); q += 32; is += 2; } } } __global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); int is = 0; uint8_t sc, m; uint8_t u1 = 1, u2 = 2; uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); for (int j = 0; j < 256; j += 64) { get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } } } __global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); int is = 0; uint8_t sc, m; uint8_t u1 = 1, u2 = 2; uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); for (int j = 0; j < 256; j += 64) { get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2); ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } } } __global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk); const float d = __half2float(*(reinterpret_cast(data + block_id * blk_size + 0))); const float min = __half2float(*(reinterpret_cast(data + block_id * blk_size + 2))); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48); int is = 0; uint8_t sc, m; uint8_t u1 = 1, u2 = 2; uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4); for (int j = 0; j < 256; j += 64) { get_scale_min_k4(is + 0, scales, &sc, &m); const float d1 = d * sc; const float m1 = min * m; get_scale_min_k4(is + 1, scales, &sc, &m); const float d2 = d * sc; const float m2 = min * m; for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1); for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2); ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } } } __global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); for (int n = 0; n < ele_per_blk; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = d * sc[is + 0] * q1; output_blk[l + 32] = d * sc[is + 2] * q2; output_blk[l + 64] = d * sc[is + 4] * q3; output_blk[l + 96] = d * sc[is + 6] * q4; } output_blk += 128; ql += 64; qh += 32; sc += 8; } } } __global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); for (int n = 0; n < ele_per_blk; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); output_blk[l + 96] = __float2half(d * sc[is + 6] * q4); } output_blk += 128; ql += 64; qh += 32; sc += 8; } } } __global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size + 208))); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); for (int n = 0; n < ele_per_blk; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4); } output_blk += 128; ql += 64; qh += 32; sc += 8; } } } static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; __global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); for (int ib = 0; ib < 8; ++ib) { const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); const float dl = d * (ls - 32); for (int j = 0; j < 16; ++j) { output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4]; } output_blk += 32; qs += 16; } } } __global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); for (int ib = 0; ib < 8; ++ib) { const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); const float dl = d * (ls - 32); for (int j = 0; j < 16; ++j) { output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]); output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]); } output_blk += 32; qs += 16; } } } __global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) { long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (long long block_id=global_idx; block_id(data + block_id * blk_size))); const uint16_t scales_h = *(reinterpret_cast(data + block_id * blk_size + 2)); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4); for (int ib = 0; ib < 8; ++ib) { const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4); const float dl = d * (ls - 32); for (int j = 0; j < 16; ++j) { output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]); output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]); } output_blk += 32; qs += 16; } } } torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({ num_bytes }, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { // data.numel%blk_size should be 0, else raise err int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) { int num_blocks = num_bytes / blk_size; const at::cuda::OptionalCUDAGuard device_guard(device); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto data_gpu = torch::empty({num_bytes}, options); cudaMemcpy(data_gpu.data_ptr(), data, num_bytes, cudaMemcpyHostToDevice); //data_gpu.copy_(data, false); // Create output tensor auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kBFloat16: dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; case torch::kFloat32: dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr(), output.data_ptr(), blk_size, ele_per_blk, num_blocks); break; default: printf("target type not support\n"); exit(0); } cudaDeviceSynchronize(); return output; } ================================================ FILE: archive/csrc/ktransformers_ext/cuda/custom_gguf/ops.h ================================================ /** * @Description : * @Author : Azure-Tang * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-12 03:48:46 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once #include #include #include torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype); ================================================ FILE: archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu ================================================ /* * Modified by Neural Magic * Copyright (C) Marlin.2024 Elias Frantar * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Adapted from https://github.com/IST-DASLab/marlin */ /* * Adapted from https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin */ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" #include #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ std::is_same::value, \ "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } namespace gptq_marlin { #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) {} } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); } #else // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template __device__ inline void mma(const typename ScalarType::FragA& a_frag, const typename ScalarType::FragB& frag_b, typename ScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template __device__ inline void ldsm4(typename ScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); } // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. template __device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); return res; } // Constructs destination register by taking bytes from 2 sources (based on // mask) template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); return res; } // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 template __device__ inline typename ScalarType::FragB dequant_4bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } template <> __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); return frag_b; } template <> __device__ inline typename ScalarType::FragB dequant_4bit(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; // Guarantee that the `(a & b) | c` operations are LOP3s. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); typename ScalarType::FragB frag_b; static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC308C308; frag_b[0] = __hfma2(*reinterpret_cast(&lo), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); return frag_b; } // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or // bf16 Reference: // - FP16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 // - BF16: // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 template __device__ inline typename ScalarType::FragB dequant_8bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; uint32_t lo = prmt(q); uint32_t hi = prmt(q); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); fp32_intermediates[0] -= 8388736.f; fp32_intermediates[1] -= 8388736.f; fp32_intermediates[2] -= 8388736.f; fp32_intermediates[3] -= 8388736.f; uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template __device__ inline void scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) template __device__ inline void scale4(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s_1, typename ScalarType::FragS& frag_s_2, typename ScalarType::FragS& frag_s_3, typename ScalarType::FragS& frag_s_4, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } // Given 2 floats multiply by 2 scales (halves) template __device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { scalar_t* s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do // Guarantee that subsequent writes by this threadblock will be visible // globally. asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); while (state != count); } __syncthreads(); } // Release barrier and increment visitation count. __device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { lock[0] = 0; return; } int val = 1; // Make sure that all writes since acquiring this barrier are visible // globally, while releasing the barrier. asm volatile("fence.acq_rel.gpu;\n"); asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); } } // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; if (finish_row > size_m) { finish_row = size_m; } int cur_block_rows = finish_row - start_row; int row_stride = size_k * sizeof(half) / 16; auto permute_row = [&](int row) { int iters = size_k / default_threads; int rest = size_k % default_threads; int offset = row * row_stride; half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); half* out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; for (int i = 0; i < iters; i++) { int cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; base_k += default_threads; } if (rest) { if (threadIdx.x < rest) { int cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; } } }; for (int i = 0; i < cur_block_rows; i++) { int cur_row = start_row + i; if (cur_row < size_m) { permute_row(cur_row); } } } template shared // fetch pipeline const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM // example: // 0 1 3 // 0 2 3 // 1 2 4 // While this kind of partitioning makes things somewhat more complicated, it // ensures good utilization of all SMs for many kinds of shape and GPU // configurations, while requiring as few slow global cross-threadblock // reductions as possible. using Dtype = ScalarType; using scalar_t2 = typename ScalarType::scalar_t2; using FragA = typename ScalarType::FragA; using FragB = typename ScalarType::FragB; using FragC = typename ScalarType::FragC; using FragS = typename ScalarType::FragS; constexpr int pack_factor = 32 / num_bits; // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions int parallel = 1; if (prob_m > 16 * thread_m_blocks) { parallel = prob_m / (16 * thread_m_blocks); prob_m = 16 * thread_m_blocks; } int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts // in the middle of group. iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); } } int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; int slice_iters; // number of threadblock tiles in the current slice int slice_count = 0; // total number of active threadblocks in the current slice int slice_idx; // index of threadblock in current slice; numbered bottom to // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers if (slice_col_par >= n_tiles) { A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; locks += (slice_col_par / n_tiles) * n_tiles; slice_col = slice_col_par % n_tiles; } // Compute all information about the current slice which is required for // synchronization. auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { A += 16 * thread_m_blocks * prob_k / 8; C += 16 * thread_m_blocks * prob_n / 8; locks += n_tiles; slice_col = 0; } }; init_slice(); // A sizes/strides // stride of the A matrix in global memory int a_gl_stride = prob_k / 8; // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory tile reads constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // number of shared write iterations for a tile constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides int b_gl_stride = 16 * prob_n / (pack_factor * 4); constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); constexpr int b_sh_wr_delta = threads * b_thread_vecs; constexpr int b_sh_rd_delta = threads * b_thread_vecs; constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; // constexpr int act_s_row_stride = 1; // int act_s_col_stride = act_s_row_stride * num_groups; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; int tb_n_warps = thread_n_blocks / 4; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); a_gl_rd += a_gl_rd_delta_o * slice_row; // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; int b_sh_wr = threadIdx.x * b_thread_vecs; int b_sh_rd = threadIdx.x * b_thread_vecs; // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; int slice_n_offset = act_s_col_tb_stride * slice_col; // No act_order int s_gl_rd; if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; // To ensure that writing and reading A tiles to/from shared memory, the // latter in fragment format, is fully bank conflict free, we need to use a // rather fancy XOR-based layout. The key here is that neither reads nor // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the // same shared memory banks. Further, it seems (based on NSight-Compute) that // each warp must also write a consecutive memory segment? auto transform_a = [&](int i) { int row = i / a_gl_rd_delta_o; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. const int4* B_ptr[b_sh_wr_iters]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_s = sh_g_idx + (stages * g_idx_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; FragS frag_s[2][4]; // No act-order FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; int sh_num_groups = -1; constexpr int sh_max_num_groups = 32; auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; if (sh_num_groups < sh_max_num_groups) { sh_num_groups = sh_max_num_groups; } if (sh_first_group_id + sh_num_groups > num_groups) { sh_num_groups = num_groups - sh_first_group_id; } int row_offset = first_group_id * s_gl_stride; if (is_async) { for (int i = 0; i < sh_num_groups; i++) { if (threadIdx.x < s_sh_stride) { cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); } } } else { for (int i = 0; i < sh_num_groups; i++) { if (threadIdx.x < s_sh_stride) { sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; } } } }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); } } } else { if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group if (pipe % (group_blocks / thread_k_blocks) == 0) { if (s_sh_wr_pred) { cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } s_gl_rd += s_gl_rd_delta; } } else { for (int i = 0; i < s_tb_groups; i++) { if (s_sh_wr_pred) { cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); } s_gl_rd += s_gl_rd_delta; } } } } } // Insert a fence even when we are winding down the pipeline to ensure that // waiting is also correct at this point. cp_async_fence(); }; // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering // and can only issue the next fetch when it is guaranteed that the previous // shared memory load is fully complete (as it may otherwise be // overwritten). cp_async_wait(); __syncthreads(); }; // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { if constexpr (!has_act_order) { is_same_group[pipe] = false; same_group_id[pipe] = 0; return; } int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; is_same_group[pipe] = group_id_1 == group_id_2; same_group_id[pipe] = group_id_1; }; auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; int cur_k = warp_row * 16; cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } return; } // Act-order case // Determine K of the "current" thread-block int cur_k = slice_k_start + tb_k * full_pipe; if (cur_k >= prob_k || cur_k >= slice_k_finish) { return; } // Reset (to current thread-block) since we read g_idx portion from the // shared memory cur_k = 0; // Progress to current iteration cur_k += k_iter_size * (k % b_sh_wr_iters); // Determine "position" inside the thread-block (based on warp and // thread-id) int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; cur_k += warp_row * 16; int th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; if (is_same_group[pipe]) { if (k % 2 == 0) { *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread #pragma unroll for (int i = 0; i < 4; i++) { int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; if constexpr (num_bits == 4) { int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; frag_b0 = dequant_4bit(b_quant); frag_b1 = dequant_4bit(b_quant_shift); } else { int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; frag_b0 = dequant_8bit(b_quant_0); frag_b1 = dequant_8bit(b_quant_1); } // Apply scale to frag_b0 if constexpr (has_act_order) { scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); } } // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { scale(frag_b1, frag_s[k % 2][j], 1); } } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; // Since we slice across the k dimension of a tile in order to increase the // number of warps while keeping the n dimension of a tile reasonable, we have // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { int red_idx = threadIdx.x / b_sh_stride_threads; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); float* c_wr = reinterpret_cast(&sh[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll for (int i = 0; i < 4 * 2; i++) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); } } }; // Since multiple threadblocks may process parts of the same column slice, we // finally have to globally reduce over the results. As the striped // partitioning minimizes the number of such reductions and our outputs are // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). constexpr int active_threads = 32 * thread_n_blocks / 4; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; int row = (threadIdx.x % 32) / 4; if (!first) { // Interestingly, doing direct global accesses here really seems to mess up // the compiler and lead to slowdowns, hence we also use async-copies even // though these fetches are not actually asynchronous. #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred( &sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast(&c)[j] = Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; } } } } }; // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. auto write_result = [&]() { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; c_sh_wr += 32 * (threadIdx.x / 32); int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { res = __hmul2(res, s[0]); } ((scalar_t2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } } __syncthreads(); #pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { if (c_gl_wr < c_gl_wr_end) { C[c_gl_wr] = sh[c_sh_rd]; c_gl_wr += c_gl_wr_delta; c_sh_rd += c_sh_rd_delta; } } }; // Start global fetch and register load pipelines. auto start_pipes = [&]() { #pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } fetch_to_shared(i, i, i < slice_iters); } zero_accums(); wait_for_stage(); init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); slice_k_start_shared_fetch += tb_k * (stages - 1); }; if (slice_iters) { start_pipes(); } // Main loop. while (slice_iters) { // We unroll over both the global fetch and the register load pipeline to // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); pipe++; wait_for_stage(); init_same_group(pipe % stages); } matmul(k); } slice_iters--; if (slice_iters == 0) { break; } } a_gl_rd += a_gl_rd_delta_o * stages; slice_k_start += tb_k * stages; slice_k_start_shared_fetch += tb_k * stages; if constexpr (has_act_order) { int first_group_id = g_idx[slice_k_start]; int last_g_idx = slice_k_start + stages * tb_k * 2; if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } int last_group_id = g_idx[last_g_idx]; if (last_group_id >= sh_first_group_id + sh_num_groups) { fetch_scales_to_shared(false, first_group_id, last_group_id); __syncthreads(); } } // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { if constexpr (num_bits == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } else { if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { if constexpr (num_bits == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } } // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } } } if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; slice_col++; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading if constexpr (has_act_order) { slice_k_start = tb_k * slice_row; slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } start_pipes(); } } } } #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ Marlin<<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ prob_k, locks); \ } typedef struct { int thread_k; int thread_n; int num_threads; } thread_config_t; typedef struct { int max_m_blocks; thread_config_t tb_cfg; } exec_config_t; thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {128, 128, 256}, {64, 128, 128}, {128, 64, 128}, }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {64, 256, 256}, {64, 128, 128}, {128, 64, 128}, }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; int tb_k = th_config.thread_k; // Get max scale groups per thread-block int tb_groups; if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { int tb_scales = tb_groups * tb_n * 2; return tb_scales * pipe_stages; } } bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; int b_size = (tb_k * tb_n / pack_factor) * 4; // Get A size int m_blocks = div_ceil(prob_m, 16); int tb_max_m = 16; while (true) { if (m_blocks >= max_m_blocks) { tb_max_m *= max_m_blocks; break; } max_m_blocks--; if (max_m_blocks == 0) { TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); } } int a_size = (tb_max_m * tb_k) * 2; float pipe_size = (a_size + b_size) * pipe_stages; TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { return false; } // Verify K/N are divisible by thread K/N if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { return false; } // Verify min for thread K/N if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { return false; } // num_threads must be at least 128 (= 4 warps) if (th_config.num_threads < 128) { return false; } // Determine cache for scales int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); // Check that pipeline fits into cache if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, scales_cache_size, max_shared_mem)) { return false; } return true; } exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } } else { for (auto th_config : large_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } } max_m_blocks--; // Process less M blocks per invocation to reduce cache // usage } return exec_config_t{0, {-1, -1, -1}}; } #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) template void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* g_idx, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, void* workspace, int num_bits, bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) { TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); int tot_m = prob_m; int tot_m_blocks = div_ceil(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; if (sms == -1) { cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } int max_shared_mem = 0; cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); // Set thread config exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config exec_cfg = exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; } else { // Auto config exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem); } TORCH_CHECK(exec_cfg.max_m_blocks > 0 && is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, max_shared_mem), "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, ", thread_k = ", exec_cfg.tb_cfg.thread_k, ", thread_n = ", exec_cfg.tb_cfg.thread_n, ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, ", group_size = ", group_size, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem); int num_threads = exec_cfg.tb_cfg.num_threads; thread_k = exec_cfg.tb_cfg.thread_k; thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; int blocks = sms; TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); int group_blocks = 0; if (has_act_order) { if (is_k_full) { TORCH_CHECK(group_size != -1); group_blocks = group_size / 16; TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); } else { TORCH_CHECK(group_size == 0); group_blocks = 0; } } else { if (group_size == -1) { group_blocks = -1; } else { group_blocks = group_size / 16; TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); } } const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; int* locks = (int*)workspace; if (has_act_order) { // Permute A columns int block_rows = div_ceil(prob_m, blocks); permute_cols_kernel<<>>( A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); A_ptr = a_tmp_ptr; } // If we have a full K, then we can run the non-act-order version of Marlin // (since the weight rows are reordered by increasing group ids, and by having // a full K, we have full original groups) if (is_k_full) { has_act_order = false; } // Main loop for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; if (thread_m_blocks > exec_cfg.max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); if (par > max_par) par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; } // Define kernel configurations #define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \ str(prob_n) + ", " + str(prob_k) + "]" + \ ", has_act_order = " + str(has_act_order) + \ ", num_groups = " + str(num_groups) + \ ", group_size = " + str(group_size) + \ ", thread_m_blocks = " + str(thread_m_blocks) + \ ", thread_n_blocks = " + str(thread_n_blocks) + \ ", thread_k_blocks = " + str(thread_k_blocks)); if (num_bits == 4 && num_threads == 256) { if (false) { } CALL_IF(4, 32, 2, 256) CALL_IF(4, 16, 4, 256) CALL_IF(4, 8, 8, 256) else { undefined_error } } else if (num_bits == 4 && num_threads == 128) { if (false) { } CALL_IF(4, 8, 4, 128) CALL_IF(4, 4, 8, 128) else { undefined_error } } else if (num_bits == 8 && num_threads == 256) { if (false) { } CALL_IF(8, 32, 2, 256) CALL_IF(8, 16, 4, 256) CALL_IF(8, 8, 8, 256) else { undefined_error } } else if (num_bits == 8 && num_threads == 128) { if (false) { } CALL_IF(8, 8, 4, 128) CALL_IF(8, 4, 8, 128) else { undefined_error } } else { undefined_error } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; } } } // namespace gptq_marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); int pack_factor = 32 / num_bits; // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m); TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k); // Verify B TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, " is not divisible by tile_size = ", gptq_marlin::tile_size); TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not divisible by tile_size = ", gptq_marlin::tile_size); int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); // Alloc buffers auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c = torch::empty({size_m, size_n}, options); torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; // thread_n: `n` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_n = -1; // sms: number of SMs to use for the kernel (can usually be left as auto -1) int sms = -1; // Verify g_idx and perm TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || (g_idx.size(0) == size_k && perm.size(0) == size_k), "Unexpected g_idx.size(0) = ", g_idx.size(0), " and perm.size(0) = ", perm.size(0), ", where size_k = ", size_k); // Detect groupsize and act_order int num_groups = -1; int group_size = -1; bool has_act_order = g_idx.size(0) != 0; int b_rank = b_scales.sizes().size(); TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), " is not size_n = ", size_n); num_groups = b_scales.size(0); if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); group_size = size_k / num_groups; } else { group_size = 0; } } else { if (num_groups > 1) { TORCH_CHECK( size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); group_size = size_k / num_groups; } else { group_size = -1; } } // Verify workspace size TORCH_CHECK( size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, gptq_marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, gptq_marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } return c; } #endif ================================================ FILE: archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh ================================================ // Adapted from // https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin // Copyrigth 2024 The vLLM team. // Copyright (c) 2024 by KVCache.AI, All Rights Reserved. #pragma once #include #include #include #include #include #include #include namespace gptq_marlin { // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; template struct Vec { T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__) // No support for async #else __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } template __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } #endif } // namespace gptq_marlin ================================================ FILE: archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh ================================================ // Adapted from // https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin // Copyrigth 2024 The vLLM team. // Copyright (c) 2024 by KVCache.AI, All Rights Reserved. #ifndef _data_types_cuh #define _data_types_cuh #include "gptq_marlin.cuh" #include #include #ifdef __HIP_PLATFORM_AMD__ typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; #endif namespace gptq_marlin { template class ScalarType {}; template <> class ScalarType { public: using scalar_t = half; using scalar_t2 = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); } static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } }; template <> class ScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } #endif }; } // namespace gptq_marlin #endif ================================================ FILE: archive/csrc/ktransformers_ext/cuda/gptq_marlin/ops.h ================================================ /** * @Description : * @Author : Azure * @Date : 2024-07-22 09:27:55 * @Version : 1.0.0 * @LastEditors : Azure * @LastEditTime : 2024-07-26 08:35:00 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once #include #include #include torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full); // torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, // int64_t size_k, int64_t size_n, // int64_t num_bits); ================================================ FILE: archive/csrc/ktransformers_ext/cuda/setup.py ================================================ from setuptools import setup, Extension from torch.utils import cpp_extension from torch.utils.cpp_extension import BuildExtension, CUDAExtension setup( name='KTransformersOps', ext_modules=[ CUDAExtension( 'KTransformersOps', [ 'custom_gguf/dequant.cu', 'binding.cpp', 'gptq_marlin/gptq_marlin.cu', # 'gptq_marlin_repack.cu', ], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', '--use_fast_math', '-Xcompiler', '-fPIC', ] }, ) ], cmdclass={'build_ext': BuildExtension} ) ================================================ FILE: archive/csrc/ktransformers_ext/cuda/test_dequant.py ================================================ import os import sys sys.path.insert(0,"/home/zbx/ktransformers") from ktransformers.util.custom_loader import GGUFLoader import torch gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf") gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/") torch.set_default_dtype(torch.bfloat16) tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda") print(tensor_1[0, -64:]) print(tensor_2[0, -64:]) ================================================ FILE: archive/csrc/ktransformers_ext/examples/test_attention.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : Jianwei Dong Date : 2024-08-28 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-28 10:32:05 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import os, sys import time sys.path.append(os.path.dirname(__file__) + "/../build") import cpuinfer_ext from flash_attn import flash_attn_with_kvcache import torch layer_num = 10 kv_head_num = 8 q_head_num = 32 head_dim = 128 block_len = 128 anchor_num = 1 cache_seqlen = 8192 cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device="cpu") seqlens_zero = torch.zeros((1,), dtype=torch.int32, device="cpu") anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC kv_type = cpuinfer_ext.kvcache.ggml_type.FP16 retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER layer_step: int = 1 token_step: int = 1 layer_offset: int = 0 max_thread_num: int = 2 max_batch_size: int = 1 max_block_num: int = 512 CPUInfer = cpuinfer_ext.CPUInfer(max_thread_num) validation_iter = 100 with torch.inference_mode(mode=True): config = cpuinfer_ext.kvcache.KVCacheConfig( layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num, anchor_type, kv_type, retrieval_type, layer_step, token_step, layer_offset, max_block_num, max_batch_size, max_thread_num, ) local_kvcache = cpuinfer_ext.kvcache.KVCache(config) kvcaches = [] block_table = ( torch.arange(max_block_num, dtype=torch.int32, device="cpu") .contiguous() .view(1, -1) ) for layer_idx in range(layer_num): k_cache = torch.randn( (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() v_cache = torch.randn( (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() CPUInfer.submit( local_kvcache.update_kvcache_fp16( k_cache.data_ptr(), v_cache.data_ptr(), layer_idx, block_table.data_ptr(), 1, max_block_num, seqlens_zero.data_ptr(), cache_seqlen, ) ) CPUInfer.sync() kvcaches.append((k_cache.to("cuda"), v_cache.to("cuda"))) # validation for i in range(validation_iter): k_cache = kvcaches[i % layer_num][0] v_cache = kvcaches[i % layer_num][1] input = torch.randn( (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() output = torch.empty( (1, 1, q_head_num, head_dim), dtype=torch.float16, device="cpu" ).contiguous() # attn_lse: (bsz, q_len, q_head_num) attn_lse = torch.empty( (1, 1, q_head_num), dtype=torch.float32, device="cpu" ).contiguous() input = input / 100 CPUInfer.submit( local_kvcache.attn( input.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), i % layer_num, 0, 1, 1, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), -1, -1, -1, ) ) CPUInfer.sync() # print("cpuinfer output", output) t_output = flash_attn_with_kvcache( q=input.to("cuda"), k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens.to("cuda"), ) # print("torch output", t_output) diff = torch.mean(torch.abs(output.to("cuda") - t_output)) / torch.mean( torch.abs(t_output) ) print("diff = ", diff) assert diff < 0.001 ================================================ FILE: archive/csrc/ktransformers_ext/examples/test_linear.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:36:59 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch input_size = 16384 output_size = 5120 stride = 32 group_max_len = 1024 proj_type = 1 # ggml_type::GGML_TYPE_F16 hidden_type = 1 # ggml_type::GGML_TYPE_F16 qlen = 30 layer_num = 10 CPUInfer = cpuinfer_ext.CPUInfer(48) validation_iter = 100 with torch.inference_mode(mode=True): linears = [] projs = [] for _ in range(layer_num): proj = torch.randn((output_size, input_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type) linear = cpuinfer_ext.linear.Linear(config) projs.append(proj) linears.append(linear) # validation for i in range(validation_iter): linear = linears[i % layer_num] input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous() output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous() input = input / 100 CPUInfer.submit( linear.forward( qlen, input.data_ptr(), output.data_ptr() ) ) CPUInfer.sync() # print('cpuinfer output', output) proj = projs[i%layer_num] t_output = torch.mm(input, proj.t()) # print('torch output', t_output) diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print('diff = ', diff) assert(diff < 0.001) ================================================ FILE: archive/csrc/ktransformers_ext/examples/test_mlp.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:37:28 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch hidden_size = 5120 intermediate_size = 3072 stride = 32 group_max_len = 1024 gate_type = 1 # ggml_type::GGML_TYPE_F16 up_type = 1 # ggml_type::GGML_TYPE_F16 down_type = 1 # ggml_type::GGML_TYPE_F16 hidden_type = 1 # ggml_type::GGML_TYPE_F16 qlen = 30 layer_num = 10 CPUInfer = cpuinfer_ext.CPUInfer(48) validation_iter = 100 def act_fn(x): return x / (1.0 + torch.exp(-x)) def mlp_torch(input, gate_proj, up_proj, down_proj): gate_buf = torch.mm(input, gate_proj.t()) up_buf = torch.mm(input, up_proj.t()) intermediate = act_fn(gate_buf) * up_buf ret = torch.mm(intermediate, down_proj.t()) return ret with torch.inference_mode(mode=True): mlps = [] gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type) mlp = cpuinfer_ext.mlp.MLP(config) gate_projs.append(gate_proj) up_projs.append(up_proj) down_projs.append(down_proj) mlps.append(mlp) # validation for i in range(validation_iter): mlp = mlps[i % layer_num] input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous() output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous() input = input / 100 CPUInfer.submit( mlp.forward( qlen, input.data_ptr(), output.data_ptr() ) ) CPUInfer.sync() # print('cpuinfer output', output) gate_proj = gate_projs[i%layer_num] up_proj = up_projs[i%layer_num] down_proj = down_projs[i%layer_num] t_output = mlp_torch(input, gate_proj, up_proj, down_proj) # print('torch output', t_output) diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print('diff = ', diff) assert(diff < 0.001) ================================================ FILE: archive/csrc/ktransformers_ext/examples/test_moe.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenht2022 Date : 2024-07-25 10:32:05 Version : 1.0.0 LastEditors : chenht2022 LastEditTime : 2024-08-06 10:38:05 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os, sys import time sys.path.append(os.path.dirname(__file__) + '/../build') import cpuinfer_ext import torch expert_num = 160 hidden_size = 5120 intermediate_size = 1536 stride = 32 group_min_len = 10 group_max_len = 1024 gate_type = 1 # ggml_type::GGML_TYPE_F16 up_type = 1 # ggml_type::GGML_TYPE_F16 down_type = 1 # ggml_type::GGML_TYPE_F16 hidden_type = 1 # ggml_type::GGML_TYPE_F16 n_routed_experts = 6 qlen = 30 layer_num = 10 CPUInfer = cpuinfer_ext.CPUInfer(48) validation_iter = 100 def act_fn(x): return x / (1.0 + torch.exp(-x)) def mlp_torch(input, gate_proj, up_proj, down_proj): gate_buf = torch.mm(input, gate_proj.t()) up_buf = torch.mm(input, up_proj.t()) intermediate = act_fn(gate_buf) * up_buf ret = torch.mm(intermediate, down_proj.t()) return ret def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) cnts.scatter_(1, expert_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = expert_ids.view(-1).argsort() sorted_tokens = input[idxs // expert_ids.shape[1]] outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs t_output = ( new_x.view(*expert_ids.shape, -1) .type(weights.dtype) .mul_(weights.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return t_output with torch.inference_mode(mode=True): moes = [] gate_projs = [] up_projs = [] down_projs = [] for _ in range(layer_num): gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous() config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type) moe = cpuinfer_ext.moe.MOE(config) gate_projs.append(gate_proj) up_projs.append(up_proj) down_projs.append(down_proj) moes.append(moe) # validation for i in range(validation_iter): expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous() weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous() input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous() output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous() input = input / 100 moe = moes[i % layer_num] CPUInfer.submit( moe.forward( qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr() ) ) CPUInfer.sync() # print('cpuinfer output', output) gate_proj = gate_projs[i%layer_num] up_proj = up_projs[i%layer_num] down_proj = down_projs[i%layer_num] t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj) # print('torch output', t_output) diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print('diff = ', diff) assert(diff < 0.001) ================================================ FILE: archive/csrc/ktransformers_ext/ext_bindings.cpp ================================================ /** * @Description : * @Author : chenht2022, Jianwei Dong * @Date : 2024-07-22 02:03:22 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ // Python bindings #include "cpu_backend/cpuinfer.h" #if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU) #include "device_launch_parameters.h" #endif #include "llamafile/flags.h" #include "operators/kvcache/kvcache.h" #include "operators/llamafile/linear.h" #include "operators/llamafile/mlp.h" #include "operators/llamafile/moe.h" #if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__) #include "operators/amx/moe.hpp" #endif #include "pybind11/functional.h" #include "pybind11/operators.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include #include #include namespace py = pybind11; using namespace pybind11::literals; // Binding functions for the KVCache class class KVCacheBindings { public: class AttnBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; const ggml_fp16_t *q_in; ggml_fp16_t *output; float *attn_lse; int layer_idx; int generate_token_idx; int q_len; int batch_size; int max_block_num; int *block_table; int *cache_seqlens; int pick_block_num; int init_block_num; int local_block_num; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &KVCache::attn, args_->kv_cache, args_->q_in, args_->output, args_->attn_lse, args_->layer_idx, args_->generate_token_idx, args_->q_len, args_->batch_size, args_->max_block_num, args_->block_table, args_->cache_seqlens, args_->pick_block_num, args_->init_block_num, args_->local_block_num); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t output, intptr_t attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, intptr_t block_table, intptr_t cache_seqlens, int pick_block_num, int init_block_num, int local_block_num) { Args *args = new Args{nullptr, &kv_cache, (const ggml_fp16_t *)q_in, (ggml_fp16_t *)output, (float *)attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, (int *)block_table, (int *)cache_seqlens, pick_block_num, init_block_num, local_block_num}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class GetAllKVCacheOneLayerBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; int layer_id; ggml_fp16_t *k_in; ggml_fp16_t *v_in; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::get_all_kvcache_one_layer, args_->kv_cache, args_->layer_id, args_->k_in, args_->v_in); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in, int layer_id) { Args *args = new Args{nullptr, &kv_cache, layer_id, (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class GetAndUpdateKVCacheFp16Bindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; ggml_fp16_t *k_in; ggml_fp16_t *v_in; int layer_id; int *block_table; int batch_size; int max_block_num; int *cache_seqlens; int q_len; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::get_and_update_kvcache_fp16, args_->kv_cache, args_->k_in, args_->v_in, args_->layer_id, args_->block_table, args_->batch_size, args_->max_block_num, args_->cache_seqlens, args_->q_len); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in, int layer_id, intptr_t block_table, int batch_size, int max_block_num, intptr_t cache_seqlens, int q_len) { Args *args = new Args{nullptr, &kv_cache, (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in, layer_id, (int *)block_table, batch_size, max_block_num, (int *)cache_seqlens, q_len}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class GetKVCacheFp16Bindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; ggml_fp16_t *k_in; ggml_fp16_t *v_in; int layer_id; int *block_table; int batch_size; int max_block_num; int *cache_seqlens; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &KVCache::get_kvcache_fp16, args_->kv_cache, args_->k_in, args_->v_in, args_->layer_id, args_->block_table, args_->batch_size, args_->max_block_num, args_->cache_seqlens); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in, int layer_id, intptr_t block_table, int batch_size, int max_block_num, intptr_t cache_seqlens) { Args *args = new Args{nullptr, &kv_cache, (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in, layer_id, (int *)block_table, batch_size, max_block_num, (int *)cache_seqlens}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class UpdateKVCacheFp16Bindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; ggml_fp16_t *k_in; ggml_fp16_t *v_in; int layer_id; int *block_table; int batch_size; int max_block_num; int *cache_seqlens; int q_len; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::update_kvcache_fp16, args_->kv_cache, args_->k_in, args_->v_in, args_->layer_id, args_->block_table, args_->batch_size, args_->max_block_num, args_->cache_seqlens, args_->q_len); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in, int layer_id, intptr_t block_table, int batch_size, int max_block_num, intptr_t cache_seqlens, int q_len) { Args *args = new Args{nullptr, &kv_cache, (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in, layer_id, (int *)block_table, batch_size, max_block_num, (int *)cache_seqlens, q_len}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class UpdateImportanceBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; const ggml_fp16_t *importance; int layer_id; int *block_table; int batch_size; int max_block_num; int *offset; int width; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &KVCache::update_importance, args_->kv_cache, args_->importance, args_->layer_id, args_->block_table, args_->batch_size, args_->max_block_num, args_->offset, args_->width); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t importance, int layer_id, intptr_t block_table, int batch_size, int max_block_num, intptr_t offset, int width) { Args *args = new Args{nullptr, &kv_cache, (const ggml_fp16_t *)importance, layer_id, (int *)block_table, batch_size, max_block_num, (int *)offset, width}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class AttnWithKVCacheBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; const ggml_fp16_t *q_in; const ggml_fp16_t *k_in; const ggml_fp16_t *v_in; ggml_fp16_t *output; float *attn_lse; int layer_idx; int generate_token_idx; int q_len; int batch_size; int max_block_num; int *block_table; int *cache_seqlens; int topk; int local; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &KVCache::attn_with_kvcache, args_->kv_cache, args_->q_in, args_->k_in, args_->v_in, args_->output, args_->attn_lse, args_->layer_idx, args_->generate_token_idx, args_->q_len, args_->batch_size, args_->max_block_num, args_->block_table, args_->cache_seqlens, args_->topk, args_->local); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t k_in, intptr_t v_in, intptr_t output, intptr_t attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, intptr_t block_table, intptr_t cache_seqlens, int topk, int local) { Args *args = new Args{nullptr, &kv_cache, (const ggml_fp16_t *)q_in, (const ggml_fp16_t *)k_in, (const ggml_fp16_t *)v_in, (ggml_fp16_t *)output, (float *)attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, (int *)block_table, (int *)cache_seqlens, topk, local}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class ClearImportanceAllLayersBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; int *block_table; int *cache_seqlens; int batch_size; int max_block_num; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::clear_importance_all_layers, args_->kv_cache, args_->block_table, args_->cache_seqlens, args_->batch_size, args_->max_block_num); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t block_table, intptr_t cache_seqlens, int batch_size, int max_block_num) { Args *args = new Args{nullptr, &kv_cache, (int *)block_table, (int *)cache_seqlens, batch_size, max_block_num}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class CalcAnchorAllLayersBindinds { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; int *block_table; int *cache_seqlens; int batch_size; int max_block_num; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::calc_anchor_all_layers, args_->kv_cache, args_->block_table, args_->cache_seqlens, args_->batch_size, args_->max_block_num); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t block_table, intptr_t cache_seqlens, int batch_size, int max_block_num) { Args *args = new Args{nullptr, &kv_cache, (int *)block_table, (int *)cache_seqlens, batch_size, max_block_num}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class LoadKVCacheBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; std::string tensor_file_path; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::load_kvcache, args_->kv_cache, args_->tensor_file_path); } static std::pair cpuinfer_interface(KVCache &kv_cache, std::string tensor_file_path) { Args *args = new Args{nullptr, &kv_cache, (std::string)tensor_file_path}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class DumpKVCacheBindings { public: struct Args { CPUInfer *cpuinfer; KVCache *kv_cache; int *block_table; int cache_total_len; std::string tensor_file_path; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&KVCache::dump_kvcache, args_->kv_cache, args_->block_table, args_->cache_total_len, args_->tensor_file_path); } static std::pair cpuinfer_interface(KVCache &kv_cache, intptr_t block_table, int cache_total_len, std::string tensor_file_path) { Args *args = new Args{nullptr, &kv_cache, (int *)block_table, cache_total_len, (std::string)tensor_file_path}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; }; class LinearBindings { public: class WarmUpBindinds { public: struct Args { CPUInfer *cpuinfer; Linear *linear; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear); } static std::pair cpuinfer_interface(Linear &linear) { Args *args = new Args{nullptr, &linear}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class ForwardBindings { public: struct Args { CPUInfer *cpuinfer; Linear *linear; int qlen; const void *input; void *output; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output); } static std::pair cpuinfer_interface(Linear &linear, int qlen, intptr_t input, intptr_t output) { Args *args = new Args{nullptr, &linear, qlen, (const void *)input, (void *)output}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; }; class MLPBindings { public: class WarmUpBindinds { public: struct Args { CPUInfer *cpuinfer; MLP *mlp; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp); } static std::pair cpuinfer_interface(MLP &mlp) { Args *args = new Args{nullptr, &mlp}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class ForwardBindings { public: struct Args { CPUInfer *cpuinfer; MLP *mlp; int qlen; const void *input; void *output; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output); } static std::pair cpuinfer_interface(MLP &mlp, int qlen, intptr_t input, intptr_t output) { Args *args = new Args{nullptr, &mlp, qlen, (const void *)input, (void *)output}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; }; class MOEBindings { public: class WarmUpBindinds { public: struct Args { CPUInfer *cpuinfer; MOE *moe; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe); } static std::pair cpuinfer_interface(MOE &moe) { Args *args = new Args{nullptr, &moe}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class ForwardBindings { public: struct Args { CPUInfer *cpuinfer; MOE *moe; int qlen; int k; const uint64_t *expert_ids; const float *weights; const void *input; void *output; int *batch_size_tensor; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor); } static std::pair cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) { Args *args = new Args{nullptr, &moe, qlen, k, (const uint64_t *)expert_ids, (const float *)weights, (const void *)input, (void *)output, (int *)batch_size_tensor}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; }; #if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__) template class AMX_MOEBindings { public: class WarmUpBindings { public: struct Args { CPUInfer *cpuinfer; AMX_MOE *moe; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&AMX_MOE::warm_up, args_->moe); } static std::pair cpuinfer_interface(AMX_MOE &moe) { Args *args = new Args{nullptr, &moe}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class LoadWeightsBindings { public: struct Args { CPUInfer *cpuinfer; AMX_MOE *moe; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue(&AMX_MOE::load_weights, args_->moe); } static std::pair cpuinfer_interface(AMX_MOE &moe) { Args *args = new Args{nullptr, &moe}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; class ForwardBindings { public: struct Args { CPUInfer *cpuinfer; AMX_MOE *moe; int qlen; int k; const uint64_t *expert_ids; const float *weights; const void *input; void *output; int *batch_size_tensor; }; static void inner(void *args) { Args *args_ = (Args *)args; args_->cpuinfer->enqueue( &AMX_MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor); } static std::pair cpuinfer_interface(AMX_MOE &moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) { Args *args = new Args{nullptr, &moe, qlen, k, (const uint64_t *)expert_ids, (const float *)weights, (const void *)input, (void *)output, (int *)batch_size_tensor}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; }; #endif PYBIND11_MODULE(cpuinfer_ext, m) { py::class_(m, "CPUInfer") .def(py::init()) .def("submit", &CPUInfer::submit) .def("submit_with_cuda_stream", &CPUInfer::submit_with_cuda_stream) .def("sync", &CPUInfer::sync) .def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream); auto linear_module = m.def_submodule("linear"); py::class_(linear_module, "LinearConfig") .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj, int proj_type, int hidden_type) { return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void *)proj, (ggml_type)proj_type, (ggml_type)hidden_type); })); py::class_(linear_module, "Linear") .def(py::init()) .def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface) .def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface); auto mlp_module = m.def_submodule("mlp"); py::class_(mlp_module, "MLPConfig") .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) { return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void *)gate_proj, (void *)up_proj, (void *)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type); })); py::class_(mlp_module, "MLP") .def(py::init()) .def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface) .def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface); auto moe_module = m.def_submodule("moe"); py::class_(moe_module, "MOEConfig") .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) { return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, group_max_len, use_silu, (void *)gate_proj, (void *)up_proj, (void *)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type); })); py::class_(moe_module, "MOE") .def(py::init()) .def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface) .def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface); #if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__) py::class_(moe_module, "AMX_MOEConfig") .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj) { return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, max_len, use_silu, (void *)gate_proj, (void *)up_proj, (void *)down_proj); })); py::class_>(moe_module, "AMXBF16_MOE") .def(py::init()) .def("warm_up", &AMX_MOEBindings::WarmUpBindings::cpuinfer_interface) .def("load_weights", &AMX_MOEBindings::LoadWeightsBindings::cpuinfer_interface) .def("forward", &AMX_MOEBindings::ForwardBindings::cpuinfer_interface); py::class_>(moe_module, "AMXInt8_MOE") .def(py::init()) .def("warm_up", &AMX_MOEBindings::WarmUpBindings::cpuinfer_interface) .def("load_weights", &AMX_MOEBindings::LoadWeightsBindings::cpuinfer_interface) .def("forward", &AMX_MOEBindings::ForwardBindings::cpuinfer_interface); #endif auto kvcache_module = m.def_submodule("kvcache"); py::enum_(kvcache_module, "AnchorType") .value("FIXED", AnchorType::FIXED_ANCHOR) .value("DYNAMIC", AnchorType::DYNAMIC) .value("QUEST", AnchorType::QUEST) .value("BLOCK_MAX", AnchorType::BLOCK_MAX) .value("BLOCK_MEAN", AnchorType::BLOCK_MEAN); py::enum_(kvcache_module, "ggml_type") .value("FP16", ggml_type::GGML_TYPE_F16) .value("FP32", ggml_type::GGML_TYPE_F32) .value("Q4_0", ggml_type::GGML_TYPE_Q4_0) .value("Q8_0", ggml_type::GGML_TYPE_Q8_0); py::enum_(kvcache_module, "RetrievalType") .value("LAYER", RetrievalType::LAYER) .value("KVHEAD", RetrievalType::KVHEAD) .value("QHEAD", RetrievalType::QHEAD); py::class_(kvcache_module, "KVCacheConfig") .def(py::init()) .def_readwrite("layer_num", &KVCacheConfig::layer_num) .def_readwrite("kv_head_num", &KVCacheConfig::kv_head_num) .def_readwrite("q_head_num", &KVCacheConfig::q_head_num) .def_readwrite("head_dim", &KVCacheConfig::head_dim) .def_readwrite("block_len", &KVCacheConfig::block_len) .def_readwrite("anchor_num", &KVCacheConfig::anchor_num) .def_readwrite("anchor_type", &KVCacheConfig::anchor_type) .def_readwrite("kv_type", &KVCacheConfig::kv_type) .def_readwrite("retrieval_type", &KVCacheConfig::retrieval_type) .def_readwrite("layer_step", &KVCacheConfig::layer_step) .def_readwrite("token_step", &KVCacheConfig::token_step) .def_readwrite("layer_offset", &KVCacheConfig::layer_offset) .def_readwrite("max_block_num", &KVCacheConfig::max_block_num) .def_readwrite("max_batch_size", &KVCacheConfig::max_batch_size) .def_readwrite("max_thread_num", &KVCacheConfig::max_thread_num); py::class_(kvcache_module, "KVCache") .def(py::init()) .def("get_cache_total_len", &KVCache::get_cache_total_len) .def("update_cache_total_len", [](KVCache &kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); }) .def("attn", &KVCacheBindings::AttnBindings::cpuinfer_interface) .def( "get_all_kvcache_one_layer", &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface) .def("get_and_update_kvcache_fp16", &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings:: cpuinfer_interface) .def("get_kvcache_fp16", &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface) .def("update_kvcache_fp16", &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface) .def("update_importance", &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface) .def("attn_with_kvcache", &KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface) .def("clear_importance_all_layers", &KVCacheBindings::ClearImportanceAllLayersBindings:: cpuinfer_interface) .def("calc_anchor_all_layers", &KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/amx/la/amx.hpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2025-04-25 18:28:12 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2025-04-25 18:28:12 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include "utils.hpp" #include #if (defined(_WIN32) || defined(_WIN64)) #define RESTRICT __restrict #else #define RESTRICT __restrict__ #endif #if (defined(_WIN32) || defined(_WIN64)) #define ALWAYS_INLINE __forceinline #elif __has_attribute(always_inline) || defined(__GNUC__) #define ALWAYS_INLINE __attribute__((__always_inline__)) inline #else #define ALWAYS_INLINE inline #endif namespace amx { #define ARCH_GET_XCOMP_PERM 0x1022 #define ARCH_REQ_XCOMP_PERM 0x1023 #define XFEATURE_XTILECFG 17 #define XFEATURE_XTILEDATA 18 const int TMMCount = 8; const int MaxTileHeight = 16; const int MaxTileWidth = 64; const int AMX_BLK_SIZE = 32; #define TMM0 0 #define TMM1 1 #define TMM2 2 #define TMM3 3 #define TMM4 4 #define TMM5 5 #define TMM6 6 #define TMM7 7 inline bool enable_amx() { static thread_local bool initialized = false; if (initialized) { return true; } initialized = true; if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { printf("\n Fail to do XFEATURE_XTILEDATA \n\n"); return false; } else { // printf("\n TILE DATA USE SET - OK \n\n"); return true; } return true; } struct alignas(64) TileConfig { uint8_t palette; uint8_t start_row; std::array __0 = {}; std::array colsb; std::array __1 = {}; std::array rows; std::array __2 = {}; TileConfig() { palette = 1; start_row = 0; for (int i = 0; i < 8; i++) { set_row_col(i, 0, 0); } } void set_row_col(int i, uint8_t row, uint16_t col) { colsb[i] = col; rows[i] = row; } void set_config() { _tile_loadconfig(this); } static void load_data(int to, void *from, size_t stride) { switch (to) { case 0: _tile_loadd(0, from, stride); break; case 1: _tile_loadd(1, from, stride); break; case 2: _tile_loadd(2, from, stride); break; case 3: _tile_loadd(3, from, stride); break; case 4: _tile_loadd(4, from, stride); break; case 5: _tile_loadd(5, from, stride); break; case 6: _tile_loadd(6, from, stride); break; case 7: _tile_loadd(7, from, stride); break; default: throw std::runtime_error("no such tile"); } } static void store_data(int from, void *to, size_t stride) { switch (from) { case 0: _tile_stored(0, to, stride); break; case 1: _tile_stored(1, to, stride); break; case 2: _tile_stored(2, to, stride); break; case 3: _tile_stored(3, to, stride); break; case 4: _tile_stored(4, to, stride); break; case 5: _tile_stored(5, to, stride); break; case 6: _tile_stored(6, to, stride); break; case 7: _tile_stored(7, to, stride); break; default: throw std::runtime_error("no such tile"); } } }; static_assert(sizeof(TileConfig) == 64); inline void debug_tile(int t) { printf("Tile %d\n", t); uint8_t data[16][64] = {}; TileConfig::store_data(t, data, 64); for (int i = 0; i < 16; i++) { for (int j = 0; j < 64; j++) { printf("%3d ", data[i][j]); } printf("\n"); } printf("\n"); } inline void debug_tiles(int to = 8) { for (int i = 0; i < to; i++) { debug_tile(i); } } inline void debug_m512(__m512 x) { float data[16]; _mm512_storeu_ps(data, x); for (int i = 0; i < 16; i++) { printf("%f ", data[i]); } printf("\n"); } // transpose utils inline void transpose_16x16_32bit(__m512i *v) { __m512i v1[16]; v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); } /* Transpose 16x16 32-bit elements Note that v must be 64 byte aligned */ inline void transpose_16x16_32bit(__m512i *v, size_t stride) { assert(reinterpret_cast(v) % 64 == 0 && "v must be 64 aligned"); auto stride_v = [=](int i) { return offset_pointer(v, i * stride); }; __m512i v1[16]; v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1)); v1[1] = _mm512_unpackhi_epi32(*stride_v(0), *stride_v(1)); v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3)); v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3)); v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5)); v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5)); v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7)); v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7)); v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9)); v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9)); v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11)); v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11)); v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13)); v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13)); v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15)); v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15)); *stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]); *stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]); *stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]); *stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]); *stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]); *stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]); *stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]); *stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]); *stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]); *stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]); *stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]); *stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]); *stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]); *stride_v(13) = _mm512_unpackhi_epi64(v1[12], v1[14]); *stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]); *stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]); v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88); v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88); v1[2] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88); v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88); v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd); v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd); v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd); v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd); v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88); v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88); v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88); v1[11] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0x88); v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd); v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd); v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd); v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd); *stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); *stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); *stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); *stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); *stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); *stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); *stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); *stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); *stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); *stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); *stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); *stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); *stride_v(12) = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); *stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); *stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); *stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); } struct GemmKernel224BF { using dt = ggml_bf16_t; using output_t = float; static const int TILE_M = 16; static const int TILE_K = 32; static const int TILE_N = 16; static const int VNNI_BLK = 2; static const int M_STEP = TILE_M * 2; static const int N_STEP = TILE_N * 2; static const int K_STEP = TILE_K; static inline const int N_BLOCK = 256; static inline const int K_BLOCK = 1792; static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } static std::pair split_range_n(int n, int ith, int nth) { int n_start = N_BLOCK * ith; int n_end = std::min(n, N_BLOCK * (ith + 1)); return {n_start, n_end}; } static void config() { enable_amx(); TileConfig tile_config; // size is 16 x 32 for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt)); // size is 16 x 32 for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt)); // size is 16 x 16 for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t)); tile_config.set_config(); } static void load_a(dt *a, size_t lda) { _tile_loadd(0, a, lda); _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda); } static void load_b(dt *b, size_t ldb) { _tile_loadd(2, b, ldb); _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb); } static void clean_c() { _tile_zero(4); _tile_zero(5); _tile_zero(6); _tile_zero(7); } static void load_c(output_t *c, size_t ldc) { _tile_loadd(4, c, ldc); _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc); _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); } static void store_c(output_t *c, size_t ldc) { _tile_stored(4, c, ldc); _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc); _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); } static void run_tile() { _tile_dpbf16ps(4, 0, 2); _tile_dpbf16ps(5, 0, 3); _tile_dpbf16ps(6, 1, 2); _tile_dpbf16ps(7, 1, 3); } struct BufferA { ggml_bf16_t *a; int max_m, k; static size_t required_size(int max_m, int k) { return max_m * k * sizeof(ggml_bf16_t); } BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) { assert(reinterpret_cast(ptr) % 64 == 0); assert(max_m % M_STEP == 0); assert(k % K_STEP == 0); a = reinterpret_cast(ptr); } void from_mat(int m, ggml_bf16_t *src, int ith, int nth) { assert(m <= max_m); assert(ith == 0 && nth == 1); int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { int k_block_size = std::min(K_BLOCK, k - k_block_begin); for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { for (int i = 0; i < M_STEP && m_begin + i < m; i++) { __m512i *s = (__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin); __m512i *d = (__m512i *)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP); avx512_copy_32xbf16(s, d); } } } } } ggml_bf16_t *get_submat(int m, int k, int m_begin, int k_begin) { int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int k_block_begin = k_begin / K_BLOCK * K_BLOCK; k_begin -= k_block_begin; int k_block_size = std::min(K_BLOCK, k - k_block_begin); return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP; } }; struct BufferB { ggml_bf16_t *b; int n, k; static size_t required_size(int n, int k) { return n * k * sizeof(ggml_bf16_t); } BufferB(int n, int k, void *ptr) : n(n), k(k) { assert(reinterpret_cast(ptr) % 64 == 0); assert(n % N_STEP == 0); assert(k % K_STEP == 0); b = reinterpret_cast(ptr); } void from_mat(ggml_bf16_t *src, int ith, int nth) { auto [n_start, n_end] = split_range_n(n, ith, nth); int n_block_begin = n_start; int n_block_size = n_end - n_block_begin; for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { int k_block_size = std::min(K_BLOCK, k - k_block_begin); for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { for (int i = 0; i < N_STEP; i++) { __m512i *s = (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin); __m512i *d = (__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP + i * K_STEP); avx512_copy_32xbf16(s, d); } transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP)); transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP)); } } } } ggml_bf16_t *get_submat(int n, int k, int n_begin, int k_begin) { int n_block_begin = n_begin / N_BLOCK * N_BLOCK; n_begin -= n_block_begin; int n_block_size = std::min(N_BLOCK, n - n_block_begin); int k_block_begin = k_begin / K_BLOCK * K_BLOCK; k_begin -= k_block_begin; int k_block_size = std::min(K_BLOCK, k - k_block_begin); return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP; } }; struct BufferC { float *c; int max_m, n; static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); } BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) { assert(reinterpret_cast(ptr) % 64 == 0); assert(max_m % M_STEP == 0); assert(n % N_STEP == 0); c = reinterpret_cast(ptr); } void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) { assert(m <= max_m); auto [n_start, n_end] = split_range_n(n, ith, nth); int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int n_block_begin = n_start; int n_block_size = n_end - n_block_begin; for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { for (int i = 0; i < M_STEP && m_begin + i < m; i++) { __m512 *x0 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16); avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin)); } } } } float *get_submat(int m, int n, int m_begin, int n_begin) { int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int n_block_begin = n_begin / N_BLOCK * N_BLOCK; int n_block_size = std::min(N_BLOCK, n - n_block_begin); n_begin -= n_block_begin; return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP; } }; }; struct GemmKernel224Int8 { using dt = int8_t; using output_t = int32_t; static const int TILE_M = 16; static const int TILE_K = 64; static const int TILE_N = 16; static const int VNNI_BLK = 4; static const int M_STEP = TILE_M * 2; static const int N_STEP = TILE_N * 2; static const int K_STEP = TILE_K; static inline const int N_BLOCK = 256; static inline const int K_BLOCK = 3584; static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } static std::pair split_range_n(int n, int ith, int nth) { int n_start = N_BLOCK * ith; int n_end = std::min(n, N_BLOCK * (ith + 1)); return {n_start, n_end}; } static void config() { enable_amx(); TileConfig tile_config; // size is 16 x 64 for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt)); // size is 16 x 64 for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt)); // size is 16 x 16 for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t)); tile_config.set_config(); } static void load_a(dt *a, size_t lda) { _tile_loadd(0, a, lda); _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda); } static void load_b(dt *b, size_t ldb) { _tile_loadd(2, b, ldb); _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb); } static void clean_c() { _tile_zero(4); _tile_zero(5); _tile_zero(6); _tile_zero(7); } static void load_c(output_t *c, size_t ldc) { _tile_loadd(4, c, ldc); _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc); _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); } static void store_c(output_t *c, size_t ldc) { _tile_stored(4, c, ldc); _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc); _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); } static void run_tile() { _tile_dpbssd(4, 0, 2); _tile_dpbssd(5, 0, 3); _tile_dpbssd(6, 1, 2); _tile_dpbssd(7, 1, 3); } struct BufferA { int8_t *a; float *d; int max_m, k; static size_t required_size(int max_m, int k) { return max_m * k * sizeof(int8_t) + max_m * sizeof(float); } BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) { assert(reinterpret_cast(ptr) % 64 == 0); assert(max_m % M_STEP == 0); assert(k % K_STEP == 0); a = reinterpret_cast(ptr); d = reinterpret_cast(a + max_m * k); } void from_mat(int m, ggml_bf16_t *src, int ith, int nth) { assert(m <= max_m); assert(ith == 0 && nth == 1); for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { for (int i = 0; i < M_STEP && m_begin + i < m; i++) { float amax = 0.0f; for (int j = 0; j < k; j += 32) { __m512 f0, f1; avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1); amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0))); amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1))); } d[m_begin + i] = amax / ((1 << 7) - 1); } } int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { int k_block_size = std::min(K_BLOCK, k - k_block_begin); for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { for (int i = 0; i < M_STEP && m_begin + i < m; i++) { __m512 id = _mm512_set1_ps(d[m_begin + i] ? 1.0f / d[m_begin + i] : 0.0f); int8_t *dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP; __m512 f0, f1, f2, f3; avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1); avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3); __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id)); __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id)); __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id)); __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id)); __m128i s0 = _mm512_cvtsepi32_epi8(i0); __m128i s1 = _mm512_cvtsepi32_epi8(i1); __m128i s2 = _mm512_cvtsepi32_epi8(i2); __m128i s3 = _mm512_cvtsepi32_epi8(i3); _mm_storeu_si128((__m128i *)dst, s0); _mm_storeu_si128((__m128i *)(dst + 16), s1); _mm_storeu_si128((__m128i *)(dst + 32), s2); _mm_storeu_si128((__m128i *)(dst + 48), s3); } } } } } int8_t *get_submat(int m, int k, int m_begin, int k_begin) { int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int k_block_begin = k_begin / K_BLOCK * K_BLOCK; k_begin -= k_block_begin; int k_block_size = std::min(K_BLOCK, k - k_block_begin); return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP; } float *get_scale(int m, int m_begin) { return d + m_begin; } }; struct BufferB { int8_t *b; float *d; int n, k; static size_t required_size(int n, int k) { return n * k * sizeof(int8_t) + n * sizeof(float); } BufferB(int n, int k, void *ptr) : n(n), k(k) { assert(reinterpret_cast(ptr) % 64 == 0); assert(n % N_STEP == 0); assert(k % K_STEP == 0); b = reinterpret_cast(ptr); d = reinterpret_cast(b + n * k); } void from_mat(ggml_bf16_t *src, int ith, int nth) { auto [n_start, n_end] = split_range_n(n, ith, nth); int n_block_begin = n_start; int n_block_size = n_end - n_block_begin; for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { for (int i = 0; i < N_STEP; i++) { float amax = 0.0f; for (int j = 0; j < k; j += 32) { __m512 f0, f1; avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1); amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0))); amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1))); } d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1); } } for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { int k_block_size = std::min(K_BLOCK, k - k_block_begin); for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { for (int i = 0; i < N_STEP; i++) { __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 1.0f / d[n_block_begin + n_begin + i] : 0.0f); int8_t *dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP + i * K_STEP; __m512 f0, f1, f2, f3; avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin), &f0, &f1); avx512_32xbf16_to_32xfp32( (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3); __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id)); __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id)); __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id)); __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id)); __m128i s0 = _mm512_cvtsepi32_epi8(i0); __m128i s1 = _mm512_cvtsepi32_epi8(i1); __m128i s2 = _mm512_cvtsepi32_epi8(i2); __m128i s3 = _mm512_cvtsepi32_epi8(i3); _mm_storeu_si128((__m128i *)dst, s0); _mm_storeu_si128((__m128i *)(dst + 16), s1); _mm_storeu_si128((__m128i *)(dst + 32), s2); _mm_storeu_si128((__m128i *)(dst + 48), s3); } transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP)); transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP)); } } } } int8_t *get_submat(int n, int k, int n_begin, int k_begin) { int n_block_begin = n_begin / N_BLOCK * N_BLOCK; n_begin -= n_block_begin; int n_block_size = std::min(N_BLOCK, n - n_block_begin); int k_block_begin = k_begin / K_BLOCK * K_BLOCK; k_begin -= k_block_begin; int k_block_size = std::min(K_BLOCK, k - k_block_begin); return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP; } float *get_scale(int n, int n_begin) { return d + n_begin; } }; struct BufferC { float *c; int max_m, n; static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); } BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) { assert(reinterpret_cast(ptr) % 64 == 0); assert(max_m % M_STEP == 0); assert(n % N_STEP == 0); c = reinterpret_cast(ptr); } void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) { assert(m <= max_m); auto [n_start, n_end] = split_range_n(n, ith, nth); int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int n_block_begin = n_start; int n_block_size = n_end - n_block_begin; for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { for (int i = 0; i < M_STEP && m_begin + i < m; i++) { __m512 *x0 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16); avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin)); } } } } float *get_submat(int m, int n, int m_begin, int n_begin) { int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; int n_block_begin = n_begin / N_BLOCK * N_BLOCK; int n_block_size = std::min(N_BLOCK, n - n_block_begin); n_begin -= n_block_begin; return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP; } }; }; inline void mat_mul(int m, int n, int k, std::shared_ptr ba, std::shared_ptr bb, std::shared_ptr bc, int ith, int nth, bool use_amx) { using K = GemmKernel224BF; assert(n % K::N_STEP == 0); assert(k % K::K_STEP == 0); auto [n_start, n_end] = K::split_range_n(n, ith, nth); for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) { for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { float *c = bc->get_submat(m, n, m_begin, n_begin); if (!use_amx) { __m512 *c512 = (__m512 *)c; if (k_block_begin == 0) { for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { c512[m_i * 2] = _mm512_setzero_ps(); c512[m_i * 2 + 1] = _mm512_setzero_ps(); } } for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); __m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]); } } } } } else { if (k_block_begin == 0) { K::clean_c(); } else { K::load_c(c, K::N_STEP * sizeof(float)); } for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t)); K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t)); K::run_tile(); } K::store_c(c, K::N_STEP * sizeof(float)); } } } } } inline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) { __m256i a_lo = _mm512_extracti64x4_epi64(a, 0); __m256i a_hi = _mm512_extracti64x4_epi64(a, 1); __m256i b_lo = _mm512_extracti64x4_epi64(b, 0); __m256i b_hi = _mm512_extracti64x4_epi64(b, 1); b_lo = _mm256_sign_epi8(b_lo, a_lo); b_hi = _mm256_sign_epi8(b_hi, a_hi); b = _mm512_inserti64x4(b, b_lo, 0); b = _mm512_inserti64x4(b, b_hi, 1); a = _mm512_abs_epi8(a); return _mm512_dpbusd_epi32(src, a, b); } inline void mat_mul(int m, int n, int k, std::shared_ptr ba, std::shared_ptr bb, std::shared_ptr bc, int ith, int nth, bool use_amx) { using K = GemmKernel224Int8; assert(n % K::N_STEP == 0); assert(k % K::K_STEP == 0); auto [n_start, n_end] = K::split_range_n(n, ith, nth); for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) { for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { float *c = bc->get_submat(m, n, m_begin, n_begin); if (!use_amx) { __m512i *c512 = (__m512i *)c; if (k_block_begin == 0) { for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { c512[m_i * 2] = _mm512_setzero_si512(); c512[m_i * 2 + 1] = _mm512_setzero_si512(); } } for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { static_assert(K::K_STEP * sizeof(int8_t) == sizeof(__m512i)); static_assert(K::N_STEP / K::TILE_N == 2, "Must be lke this"); int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin); __m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) { for (int k_i = 0; k_i < 16; k_i++) { __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]); for (int n_i = 0; n_i < 2; n_i++) { c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]); } } } } } else { if (k_block_begin == 0) { K::clean_c(); } else { K::load_c((int32_t *)c, K::N_STEP * sizeof(int32_t)); } for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) { K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t)); K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t)); K::run_tile(); } K::store_c((int32_t *)c, K::N_STEP * sizeof(int32_t)); } if (k_block_begin + K::K_BLOCK >= k) { int to = m - m_begin; if (m - m_begin > K::M_STEP) { to = K::M_STEP; } for (int i = 0; i < to; i++) { __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i)); __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin)); __m512i now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP)); __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now)); _mm512_store_ps((__m512 *)(c + i * K::N_STEP), result); bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N); now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP + K::TILE_N)); result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now)); _mm512_store_ps((__m512 *)(c + i * K::N_STEP + K::TILE_N), result); } } } } } } } // namespace amx ================================================ FILE: archive/csrc/ktransformers_ext/operators/amx/la/utils.hpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2025-04-25 18:28:12 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2025-04-25 18:28:12 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #pragma once #include template T* offset_pointer(T* ptr, std::size_t byte_offset) { return reinterpret_cast(reinterpret_cast(ptr) + byte_offset); } template const T* offset_pointer(const T* ptr, std::size_t byte_offset) { return reinterpret_cast(reinterpret_cast(ptr) + byte_offset); } template T* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) { return offset_pointer(t, row * ld) + col; } template T* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) { return offset_pointer(t, col * ld) + row; } static inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) { _mm512_storeu_si512(dst, _mm512_loadu_si512(src)); } static inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) { _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0))); } static inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) { _mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16))); _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16))); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/amx/moe.hpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2025-04-25 18:28:12 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2025-04-25 18:28:12 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_AMX_MOE_H #define CPUINFER_OPERATOR_AMX_MOE_H #include #include #include #include #include #include "../../cpu_backend/backend.h" #include "../../cpu_backend/shared_mem_buffer.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" #include "la/amx.hpp" #ifdef USE_NUMA #include #include void *numa_alloc_aligned(size_t size, int node, size_t alignment) { void *ptr = numa_alloc_onnode(size, node); assert(reinterpret_cast(ptr) % 64 == 0); return ptr; } #endif static inline __m512 exp_avx512(__m512 x) { const __m512 log2e = _mm512_set1_ps(1.44269504089f); const __m512 c1 = _mm512_set1_ps(0.69314718056f); __m512 y = _mm512_mul_ps(x, log2e); __m512i int_part = _mm512_cvtps_epi32(y); __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part)); const __m512 poly_1 = _mm512_set1_ps(0.9999999995f); const __m512 poly_2 = _mm512_set1_ps(0.6931471805f); const __m512 poly_3 = _mm512_set1_ps(0.2402265069f); const __m512 poly_4 = _mm512_set1_ps(0.0555041087f); const __m512 poly_5 = _mm512_set1_ps(0.0096181291f); const __m512 poly_6 = _mm512_set1_ps(0.0013333558f); __m512 frac_exp = _mm512_fmadd_ps( frac_part, poly_6, _mm512_fmadd_ps(frac_part, poly_5, _mm512_fmadd_ps(frac_part, poly_4, _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1))))); __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part)); return _mm512_mul_ps(two_pow_i, frac_exp); } static inline __m512 act_fn(__m512 gate_val, __m512 up_val) { __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val); __m512 exp_neg_gate = exp_avx512(neg_gate_val); __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate); __m512 act_val = _mm512_div_ps(gate_val, denom); return _mm512_mul_ps(act_val, up_val); } static inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) { __m512 zero_vec = _mm512_setzero_ps(); __m512 act_val = _mm512_max_ps(zero_vec, gate_val); return _mm512_mul_ps(act_val, up_val); } struct AMX_MOEConfig { int expert_num; int routed_expert_num; int hidden_size; int intermediate_size; int max_len; bool use_silu; void *gate_proj; void *up_proj; void *down_proj; AMX_MOEConfig() {} AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu, void *gate_proj, void *up_proj, void *down_proj) : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj) {} }; template class AMX_MOE { private: AMX_MOEConfig config_; void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void *up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] ggml_bf16_t *m_local_input_; // [routed_expert_num * max_len * hidden_size] ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size] ggml_bf16_t *m_local_up_output_; // [routed_expert_num * max_len * intermediate_size] ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size] std::vector> m_local_pos_; // [max_len, routed_expert_num] std::vector m_local_num_; // [expert_num] std::vector m_expert_id_map_; // [expert_num] std::vector m_local_input_ptr_; // [expert_num] std::vector m_local_gate_output_ptr_; // [expert_num] std::vector m_local_up_output_ptr_; // [expert_num] std::vector m_local_down_output_ptr_; // [expert_num] std::vector> gate_up_ba_; std::vector> gate_bc_; std::vector> up_bc_; std::vector> down_ba_; std::vector> down_bc_; #ifdef USE_NUMA std::vector>> gate_bb_numa_; std::vector>> up_bb_numa_; std::vector>> down_bb_numa_; #else std::vector> gate_bb_; std::vector> up_bb_; std::vector> down_bb_; #endif public: AMX_MOE(AMX_MOEConfig config) { config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; std::vector> m_mem_requests; m_mem_requests.push_back({(void **)&m_local_input_, sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size}); m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size}); m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size}); m_mem_requests.push_back({(void **)&m_local_down_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size}); std::vector gate_up_ba_ptr(config_.expert_num); std::vector gate_bc_ptr(config_.expert_num); std::vector up_bc_ptr(config_.expert_num); std::vector down_ba_ptr(config_.expert_num); std::vector down_bc_ptr(config_.expert_num); for (int i = 0; i < config_.expert_num; i++) { m_mem_requests.push_back( {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)}); m_mem_requests.push_back( {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)}); m_mem_requests.push_back( {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)}); m_mem_requests.push_back( {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)}); m_mem_requests.push_back( {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)}); } shared_mem_buffer.alloc(this, m_mem_requests); m_local_pos_.resize(config_.max_len); for (int i = 0; i < config_.max_len; i++) { m_local_pos_[i].resize(config_.routed_expert_num); } m_expert_id_map_.resize(config_.expert_num); m_local_num_.resize(config_.expert_num); m_local_input_ptr_.resize(config_.expert_num); m_local_gate_output_ptr_.resize(config_.expert_num); m_local_up_output_ptr_.resize(config_.expert_num); m_local_down_output_ptr_.resize(config_.expert_num); for (uint64_t i = 0; i < config_.expert_num; i++) { gate_up_ba_.push_back( std::make_shared(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i])); gate_bc_.push_back( std::make_shared(config_.max_len, config_.intermediate_size, gate_bc_ptr[i])); up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, up_bc_ptr[i])); down_ba_.push_back( std::make_shared(config_.max_len, config_.intermediate_size, down_ba_ptr[i])); down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, down_bc_ptr[i])); #ifdef USE_NUMA int numa_nodes = numa_num_configured_nodes(); gate_bb_numa_.resize(numa_nodes); up_bb_numa_.resize(numa_nodes); down_bb_numa_.resize(numa_nodes); for (int j = 0; j < numa_nodes; j++) { void *gate_bb_ptr = numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64); gate_bb_numa_[j].push_back( std::make_shared(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); void *up_bb_ptr = numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64); up_bb_numa_[j].push_back( std::make_shared(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); void *down_bb_ptr = numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64); down_bb_numa_[j].push_back( std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); } #else void *gate_bb_ptr = std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); gate_bb_.push_back( std::make_shared(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); void *up_bb_ptr = std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); up_bb_.push_back( std::make_shared(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); void *down_bb_ptr = std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); down_bb_.push_back( std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); #endif } } ~AMX_MOE() { shared_mem_buffer.dealloc(this); } void load_weights(Backend *backend) { int nth = T::recommended_nth(config_.intermediate_size); backend->do_work_stealing_job( nth * config_.expert_num, nullptr, [&](int task_id) { uint64_t expert_idx = task_id / nth; int ith = task_id % nth; #ifdef USE_NUMA int numa_nodes = numa_num_configured_nodes(); for (int j = 0; j < numa_nodes; j++) { gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth); up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth); } #else gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth); up_bb_[expert_idx]->from_mat( (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth); #endif }, nullptr); nth = T::recommended_nth(config_.hidden_size); backend->do_work_stealing_job( nth * config_.expert_num, nullptr, [&](int task_id) { uint64_t expert_idx = task_id / nth; int ith = task_id % nth; #ifdef USE_NUMA int numa_nodes = numa_num_configured_nodes(); for (int j = 0; j < numa_nodes; j++) { down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size, ith, nth); } #else down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj + expert_idx * config_.hidden_size * config_.intermediate_size, ith, nth); #endif }, nullptr); } void warm_up(Backend *backend) {} void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output, int *batch_size_tensor, Backend *backend) { qlen = batch_size_tensor[0]; bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num); int activated_expert = 0; for (int i = 0; i < config_.expert_num; i++) { m_local_num_[i] = 0; } for (int i = 0; i < qlen; i++) { for (int j = 0; j < k; j++) { m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; } } for (int i = 0; i < config_.expert_num; i++) { if (m_local_num_[i] > 0) { m_expert_id_map_[activated_expert] = i; activated_expert++; } } uint64_t offset = 0; for (int i = 0; i < config_.expert_num; i++) { m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size; m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; offset += m_local_num_[i]; } backend->do_work_stealing_job( qlen, nullptr, [&](int i) { for (int j = 0; j < k; j++) { memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); } }, nullptr); backend->do_work_stealing_job( activated_expert, nullptr, [&](int task_id) { int expert_idx = m_expert_id_map_[task_id]; gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); }, nullptr); int nth = T::recommended_nth(config_.intermediate_size); backend->do_work_stealing_job( nth * activated_expert, [&](int _) { T::config(); }, [&](int task_id) { int expert_idx = m_expert_id_map_[task_id / nth]; int ith = task_id % nth; #ifdef USE_NUMA amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx], ith, nth, use_amx); amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith, nth, use_amx); #else amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx); amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx); #endif gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); if (config_.use_silu) { for (int i = 0; i < m_local_num_[expert_idx]; i++) { ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; for (int j = n_start; j < n_end; j += 32) { __m512 gate_val0, gate_val1, up_val0, up_val1; avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); __m512 result0 = act_fn(gate_val0, up_val0); __m512 result1 = act_fn(gate_val1, up_val1); avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); } } } else { for (int i = 0; i < m_local_num_[expert_idx]; i++) { ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; for (int j = n_start; j < n_end; j += 32) { __m512 gate_val0, gate_val1, up_val0, up_val1; avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1); avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1); __m512 result0 = relu_act_fn(gate_val0, up_val0); __m512 result1 = relu_act_fn(gate_val1, up_val1); avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j)); } } } }, nullptr); backend->do_work_stealing_job( activated_expert, nullptr, [&](int task_id) { int expert_idx = m_expert_id_map_[task_id]; down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); }, nullptr); nth = T::recommended_nth(config_.hidden_size); backend->do_work_stealing_job( nth * activated_expert, [&](int _) { T::config(); }, [&](int task_id) { int expert_idx = m_expert_id_map_[task_id / nth]; int ith = task_id % nth; #ifdef USE_NUMA amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx); #else amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx); #endif down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); }, nullptr); backend->do_work_stealing_job( qlen, nullptr, [&](int i) { for (int e = 0; e < config_.hidden_size; e += 32) { __m512 x0 = _mm512_setzero_ps(); __m512 x1 = _mm512_setzero_ps(); for (int j = 0; j < k; j++) { __m512 weight = _mm512_set1_ps(weights[i * k + j]); __m512 down_output0, down_output1; avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size + e), &down_output0, &down_output1); x0 = _mm512_fmadd_ps(down_output0, weight, x0); x1 = _mm512_fmadd_ps(down_output1, weight, x1); } avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e)); } }, nullptr); } }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/operators/kvcache/kvcache.h ================================================ /** * @Description : * @Author : Jianwei Dong * @Date : 2024-08-26 22:47:06 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_KVCACHE_H #define CPUINFER_OPERATOR_KVCACHE_H #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "../../cpu_backend/backend.h" #include "llama.cpp/ggml-common.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" #define CHUNK_SIZE 32 /** * @brief Converts a ggml_type enum value to its corresponding string * representation. * * This function provides a human-readable string representation for a given * ggml_type enum value. The string can be used for logging, debugging, or * displaying information in a user interface. * * @param type The ggml_type enum value to convert. * @return A string representation of the enum value. */ std::string ggml_type_to_string(ggml_type type); /** * @enum AnchorType * @brief Defines the types of anchors used in attention mechanisms. * * This enum specifies different types of anchors that can be used in attention * mechanisms, such as fixed anchors, dynamic anchors, or special anchors like * QUEST, BLOCK_MEAN, or BLOCK_MAX. */ enum AnchorType { FIXED_ANCHOR, /**< A fixed anchor that does not change. */ DYNAMIC, /**< A dynamic anchor that can change over time. */ QUEST, /**< A special anchor type used for QUEST (Query and Embedding Space Transformation). */ BLOCK_MEAN, /**< An anchor based on the mean of a block of data. */ BLOCK_MAX /**< An anchor based on the maximum value within a block of data. */ }; /** * @brief Converts an AnchorType enum value to its corresponding string * representation. * * This function provides a human-readable string representation for a given * AnchorType enum value. The string can be used for logging, debugging, or * displaying information in a user interface. * * @param anchor_type The AnchorType enum value to convert. * @return A string representation of the enum value. */ std::string AnchorTypeToString(AnchorType anchor_type); /** * @enum RetrievalType * @brief Defines the types of retrieval strategies in attention mechanisms. * * This enum specifies different retrieval strategies that can be used in * attention mechanisms, such as layer-level retrieval, key-value head-level * retrieval, or query head-level retrieval. */ enum RetrievalType { LAYER, /**< Retrieval at the layer level. */ KVHEAD, /**< Retrieval at the key-value head level. */ QHEAD /**< Retrieval at the query head level. */ }; /** * @brief Converts a RetrievalType enum value to its corresponding string * representation. * * This function provides a human-readable string representation for a given * RetrievalType enum value. The string can be used for logging, debugging, or * displaying information in a user interface. * * @param retrieval_type The RetrievalType enum value to convert. * @return A string representation of the enum value. */ std::string RetrievalTypeToString(RetrievalType retrieval_type); /** * @struct KVCacheConfig * @brief Configuration structure for Key-Value (KV) Cache. * * This structure holds configuration parameters for setting up and managing * a Key-Value (KV) Cache used in various attention mechanisms. It includes * parameters such as the number of layers, the number of heads, the dimension * of each head, block length, anchor information, and memory-related settings. */ struct KVCacheConfig { int layer_num; /**< Number of layers in the model. */ int kv_head_num; /**< Number of heads in the KV Cache. */ int q_head_num; /**< Number of heads in the query. */ int head_dim; /**< Dimension of each head. */ int block_len; /**< Length of each block in the cache. */ int anchor_num; /**< Number of anchors used in attention. */ ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */ // Controls the pre-allocated memory size int max_block_num; /**< Maximum number of blocks that can be allocated. */ int max_batch_size; /**< Maximum batch size that can be processed. */ int max_thread_num; /**< Maximum number of threads that can be used. */ AnchorType anchor_type; /**< Type of anchors used in the attention mechanism. */ RetrievalType retrieval_type; /**< Type of retrieval strategy used in the cache. */ int layer_step; /**< Step size between layers. */ int token_step; /**< Step size between tokens. */ int layer_offset; /**< Offset value for layers. */ /** * @brief Default constructor for KVCacheConfig. * * Initializes the configuration with default values. This constructor * does not initialize any member variables explicitly. */ KVCacheConfig() = default; /** * @brief Parameterized constructor for KVCacheConfig. * * This constructor initializes the configuration with specific values * for all member variables. * * @param layer_num The number of layers in the model. * @param kv_head_num The number of heads in the KV Cache. * @param q_head_num The number of heads in the query. * @param head_dim The dimension of each head. * @param block_len The length of each block in the cache. * @param anchor_num The number of anchors used in attention. * @param anchor_type The type of anchors used in the attention mechanism. * @param kv_type The data type of the KV Cache (e.g., fp16, q8_0). * @param retrieval_type The type of retrieval strategy used in the cache. * @param layer_step The step size between layers. * @param token_step The step size between tokens. * @param layer_offset The offset value for layers. * @param max_block_num The maximum number of blocks that can be allocated. * @param max_batch_size The maximum batch size that can be processed. * @param max_thread_num The maximum number of threads that can be used. */ KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim, int block_len, int anchor_num, AnchorType anchor_type, ggml_type kv_type, RetrievalType retrieval_type, int layer_step, int token_step, int layer_offset, int max_block_num, int max_batch_size, int max_thread_num); }; /** * @class KVCache * @brief Manages the Key-Value (KV) Cache used in attention mechanisms. * * The KVCache class provides functionality for managing the Key-Value Cache, * including resizing the cache, retrieving configuration parameters, and * updating internal states. This class is typically used in transformer models * to store and manage past key and value states for efficient attention * computations. */ class KVCache { public: /** * @brief Constructs a KVCache object with the given configuration. * * Initializes the KVCache with the specified configuration parameters, * such as the number of layers, heads, head dimensions, and other * relevant settings. * * @param config The configuration object containing initialization * parameters. */ KVCache(KVCacheConfig config); /** * @brief Resizes the number of threads used by the cache. * * This function adjusts the number of threads that the cache can utilize. * It allows dynamic reconfiguration of the parallel processing capabilities * based on the current workload or system resources. * * @param thread_num The new number of threads to use. */ void ThreadResize(int thread_num); /** * @brief Resizes the batch size managed by the cache. * * This function adjusts the batch size that the cache can handle. It * is useful when the input batch size changes dynamically, allowing * the cache to be reconfigured accordingly. * * @param batch_size The new batch size. */ void BatchResize(int batch_size); /** * @brief Resizes the number of blocks managed by the cache. * * This function adjusts the number of blocks that the cache can manage. * It allows dynamic reconfiguration of the block structure based on the * current sequence length or other factors. * * @param block_num The new number of blocks. */ void BlockResize(int block_num); /** * @brief Gets the number of layers in the cache. * * @return The number of layers configured in the cache. */ int get_layer_num() { return config_.layer_num; } /** * @brief Gets the number of KV heads in the cache. * * @return The number of KV heads configured in the cache. */ int get_kv_head_num() { return config_.kv_head_num; } /** * @brief Gets the number of query heads in the cache. * * @return The number of query heads configured in the cache. */ int get_q_head_num() { return config_.q_head_num; } /** * @brief Gets the dimension of each head in the cache. * * @return The dimension of each head. */ int get_head_dim() { return config_.head_dim; } /** * @brief Gets the length of each block in the cache. * * @return The length of each block. */ int get_block_len() { return config_.block_len; } /** * @brief Gets the number of blocks for a specific layer. * * @param layer_id The ID of the layer for which to retrieve the block * number. * @return The number of blocks in the specified layer. */ int get_block_num(int layer_id) { return past_block_num_[layer_id]; } /** * @brief Gets the number of anchors in the cache. * * @return The number of anchors configured in the cache. */ int get_anchor_num() { return config_.anchor_num; } /** * @brief Gets the total length of the cache. * * @return The total length of the cache. */ int get_cache_total_len() { return cache_total_len_; } /** * @brief Gets the total number of blocks in the cache. * * This function computes and returns the total number of blocks in the * cache based on the total cache length and the block length configuration. * * @return The total number of blocks in the cache. */ int get_cache_total_block_num() { return (cache_total_len_ + config_.block_len - 1) / config_.block_len; } /** * @brief Updates the total length of the cache. * * This function sets a new total length for the cache, allowing dynamic * adjustment of the cache size during runtime. * * @param cache_total_len The new total length of the cache. */ void update_cache_total_len(int cache_total_len) { cache_total_len_ = cache_total_len; } void attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int pick_block_num, int init_block_num, int local_block_num, Backend *backend); void update_kvcache_one_block_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, int layer_id, int block_idx, Backend *backend); void get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int block_idx, Backend *backend); void update_importance_one_block(const ggml_fp16_t *importance, int layer_id, int block_idx, Backend *backend); void get_importance_one_block(ggml_fp16_t *importance, int layer_id, int block_idx, Backend *backend); void get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx, Backend *backend); void update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id, int block_idx, Backend *backend); void calc_anchor_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend); void load_kvcache(std::string tensor_file_path, Backend *backend); void dump_kvcache(int *block_table, int cache_total_len, std::string tensor_file_path, Backend *backend); void get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, int q_len, Backend *backend); void get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, Backend *backend); void update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, int q_len, Backend *backend); void update_importance(const ggml_fp16_t *importance, int layer_id, int *block_table, int batch_size, int max_block_num, int *offset, int width, Backend *backend); void attn_with_kvcache(const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int topk, int local, Backend *backend); void clear_importance_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend); void clear_kvcache_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend); void get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen); void get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int *block_table_origin, int *cache_seqlens_origin, int max_block_num_origin, int topk, int local, Backend *backend); void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in, ggml_fp16_t *v_in, Backend *backend); private: // Persistent data KVCacheConfig config_; int n_gqa_; // q_head_num / kv_head_num int cache_total_len_; // Number of tokens in cache std::vector past_block_num_; // [layer_num] std::vector>>> k_cache_q4; // [layer_num, kv_head_num, past_block_num, block_len * // (head_dim / QK_4)] std::vector>>> v_cache_q4; // [layer_num, kv_head_num, past_block_num, head_dim * // (block_len / QK_4)] std::vector>>> k_cache_q8; // [layer_num, kv_head_num, past_block_num, block_len * // (head_dim / QK_8)] std::vector>>> v_cache_q8; // [layer_num, kv_head_num, past_block_num, head_dim * // (block_len / QK_8)] std::vector>>> k_cache_fp16_; // [layer_num, kv_head_num, past_block_num, block_len * // head_dim] std::vector>>> v_cache_fp16_; // [layer_num, kv_head_num, past_block_num, head_dim * // block_len] std::vector>>> importance_; // [layer_num, past_block_num, block_len, // attention_head_num] std::vector anchor_; // [layer_num * past_block_num * anchor_num * // attention_head_num * head_dim] // Runtime data int64_t layer_id_; int64_t block_idx_; int *block_table_; uint64_t block_num_; int max_block_num_after_retrieval_; // Rotary positional embeddings std::vector> sin_; // [seq_len, head_dim] std::vector> cos_; // [seq_len, head_dim] // update/get int seq_len_; uint16_t *k_scales_; // q4_0 uint8_t *k_in_; // q4_0 uint16_t *v_scales_; // q4_0 uint8_t *v_in_; // q4_0 uint16_t *k_data_; // fp16 uint16_t *v_data_; // fp16 uint16_t *importance_data_; // fp16 uint16_t *anchor_data_; // fp16 // sparsity = (sigma(block lse / lse)) std::vector>> block_lse_; // [batch_size, max_block_num, q_head_num] std::vector> attn_sparsity_; // [batch_size, q_head_num] // attn std::vector> avg_q; // [batch_size, q_head_num * head_dim] std::vector> avg_q_fp16; // [batch_size, q_head_num * head_dim] std::vector< std::priority_queue, std::vector>, std::greater<>>> top_similar_block_; std::vector> block_similar_; std::vector>> block_similar_kv_head_; std::vector>> block_similar_q_head_; std::vector cache_seqlens_; // [batch_size] std::vector selected_blocks_num_history_; // [layer_num // layer_step] std::vector>> selected_blocks_history_; // [layer_num // layer_step, batch_size, max_block_num] std::vector>>> selected_blocks_history_kvhead_; // [layer_num // layer_step, // batch_size, max_block_num, // kv_head_num] std::vector> block_table_before_retrieval_; // [batch_size, max_block_num] std::vector> block_table_after_retrieval_; // [batch_size, pick_block_num] std::vector>> block_table_before_retrieval_qhead_; // [batch_size, max_block_num, // q_head_num] std::vector>> block_table_after_retrieval_qhead_; // [batch_size, pick_block_num, // q_head_num] std::vector>> block_table_before_retrieval_kvhead_; // [batch_size, max_block_num, // kv_head_num] std::vector>> block_table_after_retrieval_kvhead_; // [batch_size, pick_block_num, // kv_head_num] std::vector>> mutex_; // [batch_size, kv_head_num] std::vector>> q_q8_0_; // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0] std::vector>> q_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim] std::vector>> output_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim] std::vector>> attn_lse_; // [batch_size, kv_head_num, n_gqa] std::vector> thread_cur_head_idx_; // [thread_num] std::vector> thread_local_output_q8_0_; // [thread_num, n_gqa * head_dim / QK8_0] std::vector> thread_local_attn_score_; // [thread_num, n_gqa * block_len] std::vector> thread_local_output_fp32_; // [thread_num, n_gqa * head_dim] std::vector> thread_local_attn_lse_; // [thread_num, n_gqa] std::vector> thread_local_cur_output_fp32_; // [thread_num, n_gqa * head_dim] std::vector> thread_local_cur_attn_lse_; // [thread_num, n_gqa] std::vector> thread_local_attn_mask_; // [thread_num, block_len // 8] std::vector> thread_local_draft_; // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa * // head_dim + 2 * block_len * head_dim] // tmp space std::vector q_fp32; // [n_gqa * head_dim] void quantize_q_(const uint16_t *q_in_data, int batch_size); void attn_initialize_layer_(int batch_size, int layer_idx, int *block_table, int &max_block_num, int *cache_seqlens); void attn_initialize_kvhead_(int batch_size, int layer_idx, int *block_table, int &max_block_num, int *cache_seqlens); void retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num, int local_block_num, int pick_block_num, int q_len, int generate_token_idx, int batch_size, int layer_idx, int *cache_seqlens, int &max_block_num, Backend *backend); void retrieval_kvcache_kvhead_(const uint16_t *q_in_data, int init_block_num, int local_block_num, int pick_block_num, int q_len, int generate_token_idx, int batch_size, int layer_idx, int *cache_seqlens, int &max_block_num, Backend *backend); void calculate_block_similarity_layer_( const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len, int max_block_num, int *cache_seqlens, int init_block_num, int local_block_num, int pick_block_num, Backend *backend); void calculate_block_similarity_kvhead_( const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len, int max_block_num, int *cache_seqlens, int init_block_num, int local_block_num, int pick_block_num, Backend *backend); void select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num, int pick_block_num); void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num, int pick_block_num); void calculate_sparsity_layer_(const uint16_t *q_in_data, float *attn_sparsity, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, Backend *backend); void calculate_sparsity_kvhead_(const uint16_t *q_in_data, float *attn_sparsity, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, Backend *backend); void attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output, float *attn_lse, int batch_size, Backend *backend); void attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output, float *attn_lse, int batch_size, Backend *backend); /** * @brief Computes attention with KV cache for one block. * * This function performs attention computation for one block using KV * cache. The function supports different data types for Q, K, and V caches, * and provides options for quantization. The function does not perform any * dynamic memory allocation internally, so all necessary buffers must be * pre-allocated externally. * * @param head_dim The dimension of the head. * @param bsz The batch size. * @param q_type The data type of Q (GGML data type). Only supports fp16 and * q8_0. * @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is * always applied along the head_dim dimension. The size must be * bsz * head_dim/32 * qtype_size. If head_dim % 32 != 0, an error * will be raised. * @param past_kv_len The length of the past KV cache. * @param past_kv_offset The offset in the past KV cache. * @param is_full_attn Boolean flag indicating whether to use full attention * (true for full 1 mask). * @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If * is_full_attn = false, a bit matrix is passed to * represent the mask. * @param k_type The data type of K cache (GGML data type). Only supports * fp16, q4_0, and q8_0. * @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for * per_channel. Other values will raise an error. * @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If * quant_type == 0, head_dim % 32 must be 0. If quant_type == * 1, seq_len % 32 must be 0. * @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it * means no anchor is present. * @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor, * head_dim]. The k_anchor_type must be fp16. * @param k_cache_anchor_pos Pointer to the K cache anchor positions. Each * token is associated with the nearest previous anchor position. * @param v_type The data type of V cache (GGML data type). * @param v_quant_type Quantization type for V cache. * @param v_cache Pointer to the V cache tensor [head_dim, seq_len]. * @param num_v_anchor The number of V anchors. * @param v_cache_anchors Pointer to the V cache anchors. * @param v_cache_anchor_pos Pointer to the V cache anchor positions. * @param attn_score Pre-allocated buffer for attention scores [bsz, * past_kv_len]. * @param output Output tensor [bsz, head_dim] with the same type as q_type. * @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the * attention scores. * @param draft Pre-allocated temporary buffer. The buffer size should be * enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * * past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes. * @param rotary_angle Pointer to the rotary angle tensor. * @param rotary_cos Pointer to the cosine values for rotary embedding. * @param rotary_sin Pointer to the sine values for rotary embedding. */ void attn_with_kvcache_one_block_( int head_dim, int bsz, ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0 // [bsz, head_dim] // Quantization is always on the head_dim dimension (per_token). If // head_dim % 32 != 0, an error will be raised. The size must be bsz * // head_dim/32 * qtype_size. const void *q, int past_kv_len, int past_kv_offset, bool is_full_attn, // true indicates a full 1 mask // If is_full_attn = false, a bit matrix representing the mask is // passed. [bsz, past_kv_len] const uint8_t *attn_mask, ggml_type k_type, // GGML data type of `K Cache`, only supports fp16, // q4_0, q8_0 int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an // error // [seq_len, head_dim] // If quant_type == 0, head_dim % 32 must be 0. // If quant_type == 1, seq_len % 32 must be 0. const void *k_cache, // k_anchor_type must be fp16 int num_k_anchor, // num_k_anchor == 0 indicates no anchor // [num_k_anchor, head_dim] const void *k_cache_anchors, // Each token is associated with the nearest previous position's anchor, // with the same distance. const int *k_cache_anchor_pos, // v_cache similar to k_cache ggml_type v_type, int v_quant_type, // [head_dim, seq_len] const void *v_cache, int num_v_anchor, const void *v_cache_anchors, const int *v_cache_anchor_pos, // Pre-allocated buffer for intermediate calculations [bsz, // past_kv_len]. No malloc is performed inside this function. float *attn_score, // Output: [bsz, head_dim], with the same type as q_type void *output, // [bsz] float *lse, // Pre-allocated temporary buffer with sufficient size: // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len * // head_dim + past_kv_len * head_dim / 32) bytes. void *draft, // Apply rotary embedding online const int *rotary_angle, const void *rotary_cos, const void *rotary_sin // rotary_cos=None, // rotary_sin=None, // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, // cache_batch_idx: Optional[torch.Tensor] = None, // rotary_interleaved=True, // // Not supported for now // window_size=(-1, -1), # -1 means infinite context window // alibi_slopes=None, ); }; /** * @brief Scales a float32 vector by a given scalar value. * * This function multiplies each element of the input vector `y` by a scalar * `v`. It uses platform-specific optimizations if available, such as Apple's * Accelerate framework or SIMD instructions. If no specific optimization is * available, the function falls back to a simple scalar multiplication loop. * * @param n The number of elements in the vector `y`. * @param y The input vector to be scaled. The result will be stored in the same * vector. * @param v The scalar value by which to scale the vector. */ void ggml_vec_scale_f32(const int n, float *y, const float v); #endif ================================================ FILE: archive/csrc/ktransformers_ext/operators/kvcache/kvcache_attn.cpp ================================================ /** * @Description : * @Author : Jianwei Dong * @Date : 2024-08-26 22:47:06 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "kvcache.h" #include void KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output, float *attn_lse, int batch_size, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); seq_len_ = config_.block_len; backend->do_work_stealing_job( batch_size * config_.kv_head_num * max_block_num_after_retrieval_, [&](int thread_id) { thread_cur_head_idx_[thread_id].first = -1; thread_cur_head_idx_[thread_id].second = -1; }, [&](int task_id) { int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_); int head_id = (task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_; int block_id = task_id % max_block_num_after_retrieval_; int thread_id = Backend::thread_local_id; // If the block is out of the sequence length, skip it. if (cache_seqlens_[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table_after_retrieval_kvhead_[batch_id][block_id] [head_id]; if (cache_seqlens_[batch_id] / config_.block_len == block_id) { int seq_len = cache_seqlens_[batch_id] % config_.block_len; if (seq_len == 0) return; // Prepare the attention mask for the last block. int full_blocks = seq_len / 8; int remaining_bits = seq_len % 8; // Fill full blocks with 1s for (int i = 0; i < full_blocks; ++i) { thread_local_attn_mask_[thread_id][i] = 0xFF; } // Fill the remaining bits in the next block if (remaining_bits > 0 && full_blocks < seq_len_ / 8) { thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1; } else { thread_local_attn_mask_[thread_id][full_blocks] = 0; } for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) { thread_local_attn_mask_[thread_id][i] = 0; } if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } else { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (batch_id == cur_batch_idx && head_id == cur_head_id) { for (int i = 0; i < n_gqa_; i++) { float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] + std::log( 1.0 + std::exp(thread_local_attn_lse_[thread_id][i] - thread_local_cur_attn_lse_[thread_id][i])); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j] += thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse; } } else { if (cur_batch_idx != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp( thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } thread_cur_head_idx_[thread_id].first = batch_id; thread_cur_head_idx_[thread_id].second = head_id; for (int i = 0; i < n_gqa_; i++) { thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j] = thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } } } }, // Merge the results of the remaining blocks. [&](int thread_id) { int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (cur_head_id != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { float new_attn_lse; if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } }); // move the results to output and attn_lse uint16_t *output_data = reinterpret_cast(output); float *attn_lse_data = attn_lse; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] = GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]); } for (int j = 0; j < n_gqa_; j++) { attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j]; } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of computing attention: %f s\n", layer_idx, // diff.count()); } void KVCache::attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output, float *attn_lse, int batch_size, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); seq_len_ = config_.block_len; backend->do_work_stealing_job( batch_size * config_.kv_head_num * max_block_num_after_retrieval_, [&](int thread_id) { thread_cur_head_idx_[thread_id].first = -1; thread_cur_head_idx_[thread_id].second = -1; }, [&](int task_id) { int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_); int head_id = (task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_; int block_id = task_id % max_block_num_after_retrieval_; int thread_id = Backend::thread_local_id; // If the block is out of the sequence length, skip it. if (cache_seqlens_[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table_after_retrieval_[batch_id][block_id]; if (cache_seqlens_[batch_id] / config_.block_len == block_id) { int seq_len = cache_seqlens_[batch_id] % config_.block_len; if (seq_len == 0) return; // Prepare the attention mask for the last block. int full_blocks = seq_len / 8; int remaining_bits = seq_len % 8; // Fill full blocks with 1s for (int i = 0; i < full_blocks; ++i) { thread_local_attn_mask_[thread_id][i] = 0xFF; } // Fill the remaining bits in the next block if (remaining_bits > 0 && full_blocks < seq_len_ / 8) { thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1; } else { thread_local_attn_mask_[thread_id][full_blocks] = 0; } for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) { thread_local_attn_mask_[thread_id][i] = 0; } if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } else { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (batch_id == cur_batch_idx && head_id == cur_head_id) { for (int i = 0; i < n_gqa_; i++) { float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] + std::log( 1.0 + std::exp(thread_local_attn_lse_[thread_id][i] - thread_local_cur_attn_lse_[thread_id][i])); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j] += thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse; } } else { if (cur_batch_idx != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp( thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } thread_cur_head_idx_[thread_id].first = batch_id; thread_cur_head_idx_[thread_id].second = head_id; for (int i = 0; i < n_gqa_; i++) { thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j] = thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } } } }, // Merge the results of the remaining blocks. [&](int thread_id) { int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (cur_head_id != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { float new_attn_lse; if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } }); // move the results to output and attn_lse uint16_t *output_data = reinterpret_cast(output); float *attn_lse_data = attn_lse; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] = GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]); } for (int j = 0; j < n_gqa_; j++) { attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j]; } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of computing attention: %f s\n", layer_id_, // diff.count()); } void KVCache::attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int pick_block_num, int init_block_num, int local_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_idx; batch_size = batch_size * q_len; const uint16_t *q_in_data = const_cast(q_in); quantize_q_(q_in_data, batch_size); if (config_.retrieval_type == RetrievalType::LAYER) { attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens); retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens, max_block_num, backend); attention_layer_(q_in_data, output, attn_lse, batch_size, backend); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens); retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens, max_block_num, backend); attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend); } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of computing attention: %f s\n", layer_idx, // diff.count()); } void KVCache::attn_with_kvcache( const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int topk, int local, Backend *backend) { // printf("attn_with_kvcache start\n"); assert(q_len == 1); // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_idx; update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size, max_block_num, cache_seqlens, q_len, backend); // printf("update finished.\n"); // cache_seqlens memory is modified. for (int i = 0; i < batch_size; i++) { cache_seqlens[i] += q_len; } int init_block_num = 1; if (config_.block_len <= 32) { init_block_num = 64 / config_.block_len; } attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table, cache_seqlens, topk, init_block_num, local, backend); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of computing attention with kvcache: %f s\n", // layer_idx, diff.count()); } void KVCache::quantize_q_(const uint16_t *q_in_data, int batch_size) { // Timer start auto start = std::chrono::high_resolution_clock::now(); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { // quantize q for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32( q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j]); } } } else { // quantize q for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { q_fp32[j] = GGML_FP16_TO_FP32( q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j]); } quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(), n_gqa_ * config_.head_dim); } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); // printf("time of quantizing q: %f s\n", // std::chrono::duration(end - start).count()); } void KVCache::attn_initialize_layer_(int batch_size, int layer_idx, int *block_table, int &max_block_num, int *cache_seqlens) { // Timer start auto start = std::chrono::high_resolution_clock::now(); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { // initialize output_fp32_ and attn_lse_ for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { output_fp32_[batch_idx][i][j] = 0; } for (int j = 0; j < n_gqa_; j++) { attn_lse_[batch_idx][i][j] = 0; } } // clear top_similar_block_ while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop(); } // get block_table_before_retrieval_ and cache_seqlens_ if (block_table == nullptr) { max_block_num = past_block_num_[layer_idx]; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { if (cache_total_len_ != 0) cache_seqlens_[batch_idx] = cache_total_len_; else cache_seqlens_[batch_idx] = max_block_num * config_.block_len; for (int i = 0; i < max_block_num; i++) { block_table_before_retrieval_[batch_idx][i] = i; block_similar_[batch_idx][i] = 0; } } } else { for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { cache_seqlens_[batch_idx] = cache_seqlens[batch_idx]; for (int i = 0; i < max_block_num; i++) { block_table_before_retrieval_[batch_idx][i] = block_table[batch_idx * max_block_num + i]; block_similar_[batch_idx][i] = 0; } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); // printf("layer %d time of initializing attention: %f s\n", layer_idx, // std::chrono::duration(end - start).count()); } void KVCache::calculate_block_similarity_layer_( const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len, int max_block_num, int *cache_seqlens, int init_block_num, int local_block_num, int pick_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); if (batch_size == 1 && config_.anchor_num == 1) { // TODO: improve batch_size > 1 for (int batch_id = 0; batch_id < batch_size; batch_id++) { if (q_len == 1) { for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) { avg_q[batch_id][j] = GGML_FP16_TO_FP32( q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j]); avg_q_fp16[batch_id][j] = q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j]; } } else { for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) { avg_q[batch_id][j] = 0; } for (int i = 0; i < q_len; i++) { for (int j = 0; j < config_.head_dim; j++) { avg_q[batch_id][j] += GGML_FP16_TO_FP32( q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + i * config_.q_head_num * config_.head_dim + j]); } } for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) { avg_q[batch_id][j] /= q_len; avg_q_fp16[batch_id][j] = GGML_FP32_TO_FP16(avg_q[batch_id][j]); } } int seq_len = cache_seqlens_[batch_id]; int block_num = (seq_len / config_.block_len) - local_block_num - init_block_num; if (block_num <= 0) { continue; } bool is_seq = true; for (int i = init_block_num + 1; i < (seq_len / config_.block_len) - local_block_num; i++) { if (block_table_before_retrieval_[batch_id][i] != block_table_before_retrieval_[batch_id][i - 1] + 1) { is_seq = false; break; } } if (is_seq) { int nth = backend->get_thread_num(); backend->do_work_stealing_job( nth, nullptr, [&](int task_id) { int ith = task_id; bool ok = llamafile_sgemm( block_num, 1, config_.q_head_num * config_.head_dim, anchor_.data() + (layer_idx * config_.max_block_num + block_table_before_retrieval_ [batch_id][init_block_num]) * config_.anchor_num * config_.q_head_num * config_.head_dim, config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(), config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + init_block_num, block_num, ith, nth, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (!ok) { printf("llamafile_sgemm failed\n"); } }, nullptr); } else { backend->do_work_stealing_job( block_num, nullptr, [&](int task_id) { int block_id = task_id + init_block_num; int block_idx = block_table_before_retrieval_[batch_id][block_id]; bool ok = llamafile_sgemm( 1, 1, config_.q_head_num * config_.head_dim, anchor_.data() + (layer_idx * config_.max_block_num + block_table_before_retrieval_[batch_id] [block_idx]) * config_.anchor_num * config_.q_head_num * config_.head_dim, config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(), config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + block_id, 1, 0, 1, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (!ok) { printf("llamafile_sgemm failed\n"); } }, nullptr); } } } else { backend->do_work_stealing_job( batch_size * max_block_num, nullptr, [&](int task_id) { int batch_id = task_id / max_block_num; int block_id = task_id % max_block_num; int seq_len = cache_seqlens_[batch_id]; if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) { return; } int block_idx = block_table_before_retrieval_[batch_id][block_id]; float sim = 0; for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int i = 0; i < config_.head_dim; i++) { float q_i = 0, qa_i = std::numeric_limits::lowest(); for (int q_id = 0; q_id < q_len; q_id++) { q_i += GGML_FP16_TO_FP32( q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + q_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]); } q_i /= q_len; for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) { qa_i = std::max( qa_i, GGML_FP16_TO_FP32( anchor_[(long long)layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) * q_i); } sim += qa_i; } } block_similar_[batch_id][block_id] = sim; }, nullptr); } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of calculating similarity: %f s\n", layer_idx, // diff.count()); } void KVCache::select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num, int pick_block_num) { // Timer start auto start = std::chrono::high_resolution_clock::now(); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) { block_table_after_retrieval_[batch_idx].swap( block_table_before_retrieval_[batch_idx]); selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0; continue; } for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num; block_id++) { top_similar_block_[batch_idx].push(std::make_pair( block_similar_[batch_idx][block_id], block_table_before_retrieval_[batch_idx][block_id])); if (top_similar_block_[batch_idx].size() > pick_block_num) { top_similar_block_[batch_idx].pop(); } } int i = 0; for (; i < init_block_num; i++) { block_table_after_retrieval_[batch_idx][i] = block_table_before_retrieval_[batch_idx][i]; } while (!top_similar_block_[batch_idx].empty()) { block_table_after_retrieval_[batch_idx][i] = top_similar_block_[batch_idx].top().second; top_similar_block_[batch_idx].pop(); i++; } for (; i < init_block_num + pick_block_num + local_block_num; i++) { block_table_after_retrieval_[batch_idx][i] = block_table_before_retrieval_[batch_idx] [(cache_seqlens_[batch_idx] / config_.block_len) - local_block_num + i - init_block_num - pick_block_num]; } if (cache_seqlens_[batch_idx] % config_.block_len != 0) { block_table_after_retrieval_[batch_idx][i] = block_table_before_retrieval_[batch_idx][( cache_seqlens_[batch_idx] / config_.block_len)]; cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len; i++; } else { cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len; } for (int j = 0; j < i; j++) { selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][j] = block_table_after_retrieval_[batch_idx][j]; } selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = i; } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of selecting blocks: %f s\n", layer_idx, // diff.count()); } // retrieval kvcache, get the init_block_num block at beginning, top // pick_block_num similar and last local_block_num blocks. Each task // calculates the simlarity of a certain block with the query, then push // the block into the priority queue. Finally, the required blocks are // pushed into the block_table_after_retrieval_. void KVCache::retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num, int local_block_num, int pick_block_num, int q_len, int generate_token_idx, int batch_size, int layer_idx, int *cache_seqlens, int &max_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); max_block_num_after_retrieval_ = 0; if (pick_block_num != -1 && (generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) { if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) { max_block_num_after_retrieval_ = max_block_num; block_table_after_retrieval_.swap(block_table_before_retrieval_); } else { max_block_num_after_retrieval_ = selected_blocks_num_history_ [(layer_idx - config_.layer_offset) / config_.layer_step]; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < max_block_num_after_retrieval_; i++) { block_table_after_retrieval_[batch_idx][i] = selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx] [i]; } if (cache_seqlens[batch_idx] % config_.block_len == 1) { selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1; int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step]; int last_block_idx = block_table_before_retrieval_[batch_idx] [cache_seqlens[batch_idx] / config_.block_len]; selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx] [x - 1] = last_block_idx; block_table_after_retrieval_[batch_idx][x - 1] = last_block_idx; } cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] * config_.block_len - config_.block_len; } } } else if (pick_block_num != -1) { max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1); calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens, init_block_num, local_block_num, pick_block_num, backend); select_block_layer_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num); } else { selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0; max_block_num_after_retrieval_ = max_block_num; block_table_after_retrieval_.swap(block_table_before_retrieval_); } // Timer end auto end = std::chrono::high_resolution_clock::now(); // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx, // std::chrono::duration(end - start).count()); } void KVCache::calculate_sparsity_layer_(const uint16_t *q_in_data, float *attn_sparsity, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, Backend *backend ) { // Timer start auto start = std::chrono::high_resolution_clock::now(); seq_len_ = config_.block_len; backend->do_work_stealing_job( batch_size * config_.kv_head_num * max_block_num, [&](int thread_id) { thread_cur_head_idx_[thread_id].first = -1; thread_cur_head_idx_[thread_id].second = -1; }, [&](int task_id) { int batch_id = task_id / (config_.kv_head_num * max_block_num); int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num; int block_id = task_id % max_block_num; int thread_id = Backend::thread_local_id; // If the block is out of the sequence length, skip it. if (cache_seqlens[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table[batch_id * max_block_num + block_id]; if (cache_seqlens_[batch_id] / config_.block_len == block_id) { int seq_len = cache_seqlens_[batch_id] % config_.block_len; if (seq_len == 0) return; // Prepare the attention mask for the last block. int full_blocks = seq_len / 8; int remaining_bits = seq_len % 8; // Fill full blocks with 1s for (int i = 0; i < full_blocks; ++i) { thread_local_attn_mask_[thread_id][i] = 0xFF; } // Fill the remaining bits in the next block if (remaining_bits > 0 && full_blocks < seq_len_ / 8) { thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1; } else { thread_local_attn_mask_[thread_id][full_blocks] = 0; } for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) { thread_local_attn_mask_[thread_id][i] = 0; } if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } else { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } for (int i = 0; i < n_gqa_; i++) { block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i]; } int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (batch_id == cur_batch_idx && head_id == cur_head_id) { for (int i = 0; i < n_gqa_; i++) { float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] + std::log( 1.0 + std::exp(thread_local_attn_lse_[thread_id][i] - thread_local_cur_attn_lse_[thread_id][i])); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j] += thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse; } } else { if (cur_batch_idx != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp( thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } thread_cur_head_idx_[thread_id].first = batch_id; thread_cur_head_idx_[thread_id].second = head_id; for (int i = 0; i < n_gqa_; i++) { thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j] = thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } } } }, // Merge the results of the remaining blocks. [&](int thread_id) { int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (cur_head_id != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { float new_attn_lse; if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } }); for (int i = 0; i < batch_size; i++) { for (int j = 0; j < max_block_num_after_retrieval_; j++) { int block_idx = block_table_after_retrieval_[i][j]; for (int k = 0; k < config_.q_head_num; k++) { attn_sparsity[i * config_.q_head_num + k] += std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]); } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of calculating sparsity: %f s\n", layer_id_, // diff.count()); } void KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx, int *block_table, int &max_block_num, int *cache_seqlens) { // Timer start auto start = std::chrono::high_resolution_clock::now(); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { // initialize output_fp32_ and attn_lse_ for (int i = 0; i < config_.kv_head_num; i++) { for (int j = 0; j < n_gqa_ * config_.head_dim; j++) { output_fp32_[batch_idx][i][j] = 0; } for (int j = 0; j < n_gqa_; j++) { attn_lse_[batch_idx][i][j] = 0; } } // clear top_similar_block_ while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop(); } for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { cache_seqlens_[batch_idx] = cache_seqlens[batch_idx]; for (int i = 0; i < max_block_num; i++) { for (int j = 0; j < config_.kv_head_num; j++) { block_table_before_retrieval_kvhead_[batch_idx][i][j] = block_table[batch_idx * max_block_num + i]; block_similar_kv_head_[batch_idx][i][j] = 0; } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); // printf("layer %d time of initializing attn: %f s\n", layer_idx, // std::chrono::duration(end - start).count()); } void KVCache::retrieval_kvcache_kvhead_(const uint16_t *q_in_data, int init_block_num, int local_block_num, int pick_block_num, int q_len, int generate_token_idx, int batch_size, int layer_idx, int *cache_seqlens, int &max_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); max_block_num_after_retrieval_ = 0; if (pick_block_num != -1 && (generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) { if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) { max_block_num_after_retrieval_ = max_block_num; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < max_block_num; i++) { for (int j = 0; j < config_.kv_head_num; j++) { block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i] [j]; } } } } else { max_block_num_after_retrieval_ = selected_blocks_num_history_ [(layer_idx - config_.layer_offset) / config_.layer_step]; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < max_block_num_after_retrieval_; i++) { for (int j = 0; j < config_.kv_head_num; j++) { block_table_after_retrieval_kvhead_[batch_idx][i][j] = selected_blocks_history_kvhead_ [(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][i][j]; } } if (cache_seqlens[batch_idx] % config_.block_len == 1) { selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1; int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step]; for (int i = 0; i < config_.kv_head_num; i++) { int last_block_idx = block_table_before_retrieval_kvhead_ [batch_idx][cache_seqlens[batch_idx] / config_.block_len][i]; selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step] [batch_idx][x - 1][i] = last_block_idx; block_table_after_retrieval_kvhead_[batch_idx][x - 1] [i] = last_block_idx; } } cache_seqlens_[batch_idx] = std::min( cache_seqlens_[batch_idx], (cache_seqlens_[batch_idx] % config_.block_len) + (init_block_num + pick_block_num + local_block_num) * config_.block_len); } } } else if (pick_block_num != -1) { max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1); calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens, init_block_num, local_block_num, pick_block_num, backend); select_block_kvhead_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num); } else { selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0; max_block_num_after_retrieval_ = max_block_num; for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int i = 0; i < max_block_num; i++) { for (int j = 0; j < config_.kv_head_num; j++) { block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j]; } } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); // printf("layer %d time of retrieval kvcache: %f s\n", layer_idx, // std::chrono::duration(end - start).count()); } void KVCache::calculate_sparsity_kvhead_(const uint16_t *q_in_data, float *attn_sparsity, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); seq_len_ = config_.block_len; backend->do_work_stealing_job( batch_size * config_.kv_head_num * max_block_num, [&](int thread_id) { thread_cur_head_idx_[thread_id].first = -1; thread_cur_head_idx_[thread_id].second = -1; }, [&](int task_id) { int batch_id = task_id / (config_.kv_head_num * max_block_num); int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num; int block_id = task_id % max_block_num; int thread_id = Backend::thread_local_id; // If the block is out of the sequence length, skip it. if (cache_seqlens[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table[batch_id * max_block_num + block_id]; if (cache_seqlens_[batch_id] / config_.block_len == block_id) { int seq_len = cache_seqlens_[batch_id] % config_.block_len; if (seq_len == 0) return; // Prepare the attention mask for the last block. int full_blocks = seq_len / 8; int remaining_bits = seq_len % 8; // Fill full blocks with 1s for (int i = 0; i < full_blocks; ++i) { thread_local_attn_mask_[thread_id][i] = 0xFF; } // Fill the remaining bits in the next block if (remaining_bits > 0 && full_blocks < seq_len_ / 8) { thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1; } else { thread_local_attn_mask_[thread_id][full_blocks] = 0; } for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) { thread_local_attn_mask_[thread_id][i] = 0; } if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } else { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16, (void *)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim + head_id * n_gqa_ * config_.head_dim], seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { attn_with_kvcache_one_block_( config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(), thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data()); dequantize_row_q8_0( thread_local_output_q8_0_[thread_id].data(), thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim); } } for (int i = 0; i < n_gqa_; i++) { block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i]; } int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (batch_id == cur_batch_idx && head_id == cur_head_id) { for (int i = 0; i < n_gqa_; i++) { float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] + std::log( 1.0 + std::exp(thread_local_attn_lse_[thread_id][i] - thread_local_cur_attn_lse_[thread_id][i])); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j] += thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse; } } else { if (cur_batch_idx != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp( thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } thread_cur_head_idx_[thread_id].first = batch_id; thread_cur_head_idx_[thread_id].second = head_id; for (int i = 0; i < n_gqa_; i++) { thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { thread_local_cur_output_fp32_ [thread_id][i * config_.head_dim + j] = thread_local_output_fp32_[thread_id] [i * config_.head_dim + j]; } } } }, // Merge the results of the remaining blocks. [&](int thread_id) { int cur_batch_idx = thread_cur_head_idx_[thread_id].first; int cur_head_id = thread_cur_head_idx_[thread_id].second; if (cur_head_id != -1) { mutex_[cur_batch_idx][cur_head_id]->lock(); for (int i = 0; i < n_gqa_; i++) { float new_attn_lse; if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) { attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i]; for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] = thread_local_cur_output_fp32_ [thread_id] [i * config_.head_dim + j]; } continue; } new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] + std::log( 1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] - attn_lse_[cur_batch_idx][cur_head_id][i])); ggml_vec_scale_f32( config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim, std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse)); ggml_vec_scale_f32( config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim, std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse)); for (int j = 0; j < config_.head_dim; j++) { output_fp32_[cur_batch_idx][cur_head_id] [i * config_.head_dim + j] += thread_local_cur_output_fp32_[thread_id] [i * config_.head_dim + j]; } attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse; } mutex_[cur_batch_idx][cur_head_id]->unlock(); } }); for (int i = 0; i < batch_size; i++) { for (int j = 0; j < max_block_num_after_retrieval_; j++) { for (int k = 0; k < config_.q_head_num; k++) { int block_idx = block_table_after_retrieval_kvhead_[i][j][k / n_gqa_]; attn_sparsity[i * config_.q_head_num + k] += std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]); } } } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of calculating sparsity: %f s\n", layer_id_, // diff.count()); } void KVCache::calculate_block_similarity_kvhead_( const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len, int max_block_num, int *cache_seqlens, int init_block_num, int local_block_num, int pick_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); backend->do_work_stealing_job( batch_size * max_block_num, nullptr, [&](int task_id) { int batch_id = task_id / max_block_num; int block_id = task_id % max_block_num; int seq_len = cache_seqlens_[batch_id]; if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) { return; } int block_idx = block_table_before_retrieval_kvhead_[batch_id][block_id][0]; for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int i = 0; i < config_.head_dim; i++) { float q_i = 0, qa_i = std::numeric_limits::lowest(); for (int q_id = 0; q_id < q_len; q_id++) { q_i += GGML_FP16_TO_FP32( q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + q_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]); } q_i /= q_len; for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) { qa_i = std::max( qa_i, GGML_FP16_TO_FP32( anchor_[layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) * q_i); } block_similar_kv_head_[batch_id][block_id] [head_id / n_gqa_] += qa_i; } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of calculating similarity: %f s\n", layer_idx, // diff.count()); } void KVCache::select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num, int pick_block_num) { // Timer start auto start = std::chrono::high_resolution_clock::now(); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { int cache_len_after_retrieval = 0; if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) { selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0; for (int i = 0; i < max_block_num; i++) { for (int j = 0; j < config_.kv_head_num; j++) { block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j]; } } continue; } for (int head_id = 0; head_id < config_.kv_head_num; head_id++) { for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num; block_id++) { top_similar_block_[batch_idx].push(std::make_pair( block_similar_kv_head_[batch_idx][block_id][head_id], block_table_before_retrieval_kvhead_[batch_idx][block_id] [head_id])); if (top_similar_block_[batch_idx].size() > pick_block_num) { top_similar_block_[batch_idx].pop(); } } int i = 0; for (; i < init_block_num; i++) { block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = block_table_before_retrieval_kvhead_[batch_idx][i][head_id]; } while (!top_similar_block_[batch_idx].empty()) { block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = top_similar_block_[batch_idx].top().second; top_similar_block_[batch_idx].pop(); i++; } for (; i < init_block_num + pick_block_num + local_block_num; i++) { block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = block_table_before_retrieval_kvhead_ [batch_idx] [(cache_seqlens_[batch_idx] / config_.block_len) - local_block_num + i - init_block_num - pick_block_num] [head_id]; } if (cache_seqlens_[batch_idx] % config_.block_len != 0) { block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = block_table_before_retrieval_kvhead_[batch_idx][( cache_seqlens_[batch_idx] / config_.block_len)] [head_id]; cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len; i++; } else { cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len; } for (int j = 0; j < i; j++) { selected_blocks_history_kvhead_ [(layer_idx - config_.layer_offset) / config_.layer_step] [batch_idx][j][head_id] = block_table_after_retrieval_kvhead_[batch_idx][j] [head_id]; } } cache_seqlens_[batch_idx] = cache_len_after_retrieval; selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = (cache_len_after_retrieval + config_.block_len - 1) / config_.block_len; } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; // printf("layer %d time of selecting block: %f s\n", layer_idx, // diff.count()) } void KVCache::get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity, int layer_idx, int generate_token_idx, int q_len, int batch_size, int max_block_num, int *block_table, int *cache_seqlens, int *block_table_origin, int *cache_seqlens_origin, int max_block_num_origin, int topk, int local, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_idx; int thread_num = backend->get_thread_num(); batch_size = 1; const uint16_t *q_in_data = const_cast(q_in); quantize_q_(q_in_data, batch_size); if (config_.retrieval_type == RetrievalType::LAYER) { attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens); retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens, max_block_num, backend); calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin, cache_seqlens_origin, backend); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens); retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens, max_block_num, backend); calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin, cache_seqlens_origin, backend); } } void KVCache::attn_with_kvcache_one_block_( int head_dim, int bsz, ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0 // [bsz, head_dim] // Quantization is always on the head_dim dimension (per_token). If // head_dim % 32 != 0, an error will be raised. The size must be bsz * // head_dim/32 * qtype_size. const void *q, int past_kv_len, int past_kv_offset, bool is_full_attn, // true indicates a full 1 mask // If is_full_attn = false, a bit matrix representing the mask is // passed. [bsz, past_kv_len] const uint8_t *attn_mask, ggml_type k_type, // GGML data type of `K Cache`, only supports fp16, // q4_0, q8_0 int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an // error // [seq_len, head_dim] // If quant_type == 0, head_dim % 32 must be 0. // If quant_type == 1, seq_len % 32 must be 0. const void *k_cache, // k_anchor_type must be fp16 int num_k_anchor, // num_k_anchor == 0 indicates no anchor // [num_k_anchor, head_dim] const void *k_cache_anchors, // Each token is associated with the nearest previous position's anchor, // with the same distance. const int *k_cache_anchor_pos, // v_cache similar to k_cache ggml_type v_type, int v_quant_type, // [head_dim, seq_len] const void *v_cache, int num_v_anchor, const void *v_cache_anchors, const int *v_cache_anchor_pos, // Pre-allocated buffer for intermediate calculations [bsz, // past_kv_len]. No malloc is performed inside this function. float *attn_score, // Output: [bsz, head_dim], with the same type as q_type void *output, // [bsz] float *lse, // Pre-allocated temporary buffer with sufficient size: // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len * // head_dim + past_kv_len * head_dim / 32) bytes. void *draft, // Apply rotary embedding online const int *rotary_angle, const void *rotary_cos, const void *rotary_sin // rotary_cos=None, // rotary_sin=None, // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, // cache_batch_idx: Optional[torch.Tensor] = None, // rotary_interleaved=True, // // Not supported for now // window_size=(-1, -1), # -1 means infinite context window // alibi_slopes=None, ) { assert(head_dim % 32 == 0); assert(k_quant_type == 0); assert(v_quant_type == 1); assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0); if (q_type == GGML_TYPE_F16) { assert(k_type == GGML_TYPE_F16); assert(v_type == GGML_TYPE_F16); // attn = q * k + q * k_anchor // TODO: anchor assert(num_k_anchor == 0); if (rotary_angle != nullptr) { ggml_fp16_t *k_cache_with_rope_fp16 = (reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 + sizeof(float) * bsz * head_dim); // dequant k_cache and apply rope // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i) // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i) // k(i)cos(i) -> k_rope(i) // k(i)sin(i+l) -> k_rope(i+l) // k(i)cos(i) -> k_rope(i) // -k(i)sin(i-l) -> k_rope(i-l) std::vector block_fp32(32); for (int k = 0; k < past_kv_len; k++) { int angle = rotary_angle[k]; for (int l = 0; l < head_dim / 32; l++) { for (int m = 0; m < 32; m++) { float x = GGML_FP16_TO_FP32(( (ggml_fp16_t *)k_cache)[k * head_dim + l * 32 + m]); float sin_val = GGML_FP16_TO_FP32( ((ggml_fp16_t *) rotary_sin)[angle * head_dim + l * 32 + m]); float cos_val = GGML_FP16_TO_FP32( ((ggml_fp16_t *) rotary_cos)[angle * head_dim + l * 32 + m]); if (l * 32 + m < head_dim / 2) { k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(x * cos_val); k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] = GGML_FP32_TO_FP16(-x * sin_val); } else { k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( k_cache_with_rope_fp16[k * head_dim + l * 32 + m]) + x * sin_val); k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2]) - x * cos_val); } } } } llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache_with_rope_fp16, head_dim, (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT); } else { bool ok = llamafile_sgemm( past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache, head_dim, (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (!ok) { printf("llamafile_sgemm failed\n"); } } // attn = attn * scale float scale_factor = 1.0 / std::sqrt(float(head_dim)); ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor); // attn = attn & mask if (!is_full_attn) { for (int i = 0; i < bsz; i++) { for (int j = 0; j < past_kv_len; j++) { int index = i * past_kv_len + j; if (!(attn_mask[j / 8] & (1 << (j % 8)))) { attn_score[index] = std::numeric_limits::lowest(); } } } } // attn = softmax(attn) for (int i = 0; i < bsz; i++) { float sum_exp = 0; for (int j = 0; j < past_kv_len; j++) { attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]); sum_exp += attn_score[i * past_kv_len + j]; } for (int j = 0; j < past_kv_len; j++) { attn_score[i * past_kv_len + j] /= sum_exp; } if (lse != nullptr) { lse[i] = std::log(sum_exp); } } // output = attn * v + attn * v_anchor // std::vector sum(bsz * head_dim); float *sum = reinterpret_cast(reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0); // float* attn_score_fp16(bsz, past_kv_len) ggml_fp16_t *attn_score_fp16 = (reinterpret_cast( reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 + sizeof(float) * bsz * head_dim)); for (int i = 0; i < bsz * past_kv_len; i++) { attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]); } // TODO: anchor assert(num_v_anchor == 0); bool ok = llamafile_sgemm( head_dim, bsz, past_kv_len, (ggml_fp16_t *)v_cache, past_kv_len, (ggml_fp16_t *)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (!ok) { printf("llamafile_sgemm failed\n"); } // copy to output for (int i = 0; i < bsz; i++) { for (int j = 0; j < head_dim; j++) { ((float *)output)[i * head_dim + j] = sum[i * head_dim + j]; } } } else { assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0); assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0); // attn = q * k + q * k_anchor // TODO: anchor assert(num_k_anchor == 0); if (rotary_angle != nullptr) { ggml_fp16_t *k_cache_with_rope_fp16 = (reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 + sizeof(float) * bsz * head_dim); block_q4_0 *k_cache_with_rope_q4 = (reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 + sizeof(float) * bsz * head_dim) + sizeof(ggml_fp16_t) * bsz * head_dim; // dequant k_cache and apply rope // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i) // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i) // k(i)cos(i) -> k_rope(i) // k(i)sin(i+l) -> k_rope(i+l) // k(i)cos(i) -> k_rope(i) // -k(i)sin(i-l) -> k_rope(i-l) std::vector block_fp32(32); for (int k = 0; k < past_kv_len; k++) { int angle = rotary_angle[k]; for (int l = 0; l < head_dim / 32; l++) { block_q4_0 block = ((block_q4_0 *)k_cache)[k * head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { float sin_val = GGML_FP16_TO_FP32( ((ggml_fp16_t *) rotary_sin)[angle * head_dim + l * 32 + m]); float cos_val = GGML_FP16_TO_FP32( ((ggml_fp16_t *) rotary_cos)[angle * head_dim + l * 32 + m]); if (l * 32 + m < head_dim / 2) { k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m] * cos_val); k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] = GGML_FP32_TO_FP16(-block_fp32[m] * sin_val); } else { k_cache_with_rope_fp16[k * head_dim + l * 32 + m] += GGML_FP32_TO_FP16(block_fp32[m] * sin_val); k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] -= GGML_FP32_TO_FP16(block_fp32[m] * cos_val); } } } } // quantize k_cache_with_rope_fp16 for (int k = 0; k < past_kv_len; k++) { for (int l = 0; l < head_dim / 32; l++) { for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_cache_with_rope_fp16[k * head_dim + l * 32 + m]); } quantize_row_q4_0( block_fp32.data(), &k_cache_with_rope_q4[k * head_dim / 32 + l], 32); } } llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0 *)k_cache_with_rope_q4, head_dim / 32, (block_q8_0 *)q, head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT); } else { llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0 *)k_cache, head_dim / 32, (block_q8_0 *)q, head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT); } // attn = attn * scale float scale_factor = 1.0 / std::sqrt(float(head_dim)); ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor); // attn = attn & mask if (!is_full_attn) { for (int i = 0; i < bsz; i++) { for (int j = 0; j < past_kv_len; j++) { int index = i * past_kv_len + j; if (!(attn_mask[j / 8] & (1 << (j % 8)))) { attn_score[index] = std::numeric_limits::lowest(); } } } } // attn = softmax(attn) for (int i = 0; i < bsz; i++) { float sum_exp = 0; for (int j = 0; j < past_kv_len; j++) { attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]); sum_exp += attn_score[i * past_kv_len + j]; } for (int j = 0; j < past_kv_len; j++) { attn_score[i * past_kv_len + j] /= sum_exp; } if (lse != nullptr) { lse[i] = std::log(sum_exp); } } // output = attn * v + attn * v_anchor // std::vector attn_q8_0(bsz * past_kv_len / QK8_0); block_q8_0 *attn_q8_0 = reinterpret_cast(draft); quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len); // std::vector sum(bsz * head_dim); float *sum = reinterpret_cast(reinterpret_cast(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0); // TODO: anchor assert(num_v_anchor == 0); llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0 *)v_cache, past_kv_len / 32, attn_q8_0, past_kv_len / 32, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT); quantize_row_q8_0(sum, (block_q8_0 *)output, bsz * head_dim); } } ================================================ FILE: archive/csrc/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp ================================================ /** * @Description : * @Author : Jianwei Dong * @Date : 2024-08-26 22:47:06 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "kvcache.h" #include void KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); std::ifstream ifs_tensor(tensor_file_path, std::ios::binary); if (!ifs_tensor) { throw std::runtime_error("Failed to open tensor file"); } ifs_tensor.read(reinterpret_cast(&cache_total_len_), sizeof(cache_total_len_)); int past_block_num = (cache_total_len_ + config_.block_len - 1) / config_.block_len; printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len_, past_block_num); for (int i = 0; i < config_.layer_num; ++i) { past_block_num_[i] = past_block_num; } ifs_tensor.read(reinterpret_cast(anchor_.data()), anchor_.size() * sizeof(ggml_fp16_t)); for (int i = 0; i < config_.layer_num; ++i) { for (int j = 0; j < config_.kv_head_num; ++j) { for (int k = 0; k < past_block_num_[i]; ++k) { if (config_.kv_type == GGML_TYPE_F16) { ifs_tensor.read( reinterpret_cast(k_cache_fp16_[i][j][k].data()), k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t)); ifs_tensor.read( reinterpret_cast(v_cache_fp16_[i][j][k].data()), v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t)); } else if (config_.kv_type == GGML_TYPE_Q4_0) { ifs_tensor.read( reinterpret_cast(k_cache_q4[i][j][k].data()), k_cache_q4[i][j][k].size() * sizeof(block_q4_0)); ifs_tensor.read( reinterpret_cast(v_cache_q4[i][j][k].data()), v_cache_q4[i][j][k].size() * sizeof(block_q4_0)); } } } for (int k = 0; k < past_block_num_[i]; ++k) { for (int l = 0; l < config_.block_len; l++) { ifs_tensor.read( reinterpret_cast(importance_[i][k][l].data()), importance_[i][k][l].size() * sizeof(ggml_fp16_t)); } } } ifs_tensor.close(); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; printf("time of load: %f s\n", diff.count()); } void KVCache::dump_kvcache(int *block_table, int cache_total_len, std::string tensor_file_path, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); std::ofstream ofs(tensor_file_path, std::ios::binary); printf("dump_kvcache: %s\n", tensor_file_path.c_str()); if (!ofs.is_open()) { std::cerr << "Cannot open file " << tensor_file_path << std::endl; return; } ofs.write(reinterpret_cast(&cache_total_len), sizeof(cache_total_len)); int past_block_num = (cache_total_len + config_.block_len - 1) / config_.block_len; printf("cache_total_len: %d, past_block_num: %d\n", cache_total_len, past_block_num); ofs.write(reinterpret_cast(anchor_.data()), anchor_.size() * sizeof(ggml_fp16_t)); for (int i = 0; i < config_.layer_num; ++i) { for (int j = 0; j < config_.kv_head_num; ++j) { for (int k = 0; k < past_block_num; ++k) { int block_idx = block_table[k]; if (config_.kv_type == GGML_TYPE_F16) { ofs.write(reinterpret_cast( k_cache_fp16_[i][j][block_idx].data()), k_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t)); ofs.write(reinterpret_cast( v_cache_fp16_[i][j][block_idx].data()), v_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t)); } else if (config_.kv_type == GGML_TYPE_Q4_0) { ofs.write(reinterpret_cast( k_cache_q4[i][j][block_idx].data()), k_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0)); ofs.write(reinterpret_cast( v_cache_q4[i][j][block_idx].data()), v_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0)); } } } for (int k = 0; k < past_block_num; ++k) { int block_idx = block_table[k]; for (int l = 0; l < config_.block_len; l++) { ofs.write(reinterpret_cast( importance_[i][block_idx][l].data()), importance_[i][block_idx][l].size() * sizeof(ggml_fp16_t)); } } } ofs.close(); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = end - start; printf("time of dump: %f s\n", diff.count()); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp ================================================ /** * @Description : * @Author : Jianwei Dong * @Date : 2024-08-26 22:47:06 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "kvcache.h" #include void KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; block_idx = block_idx; seq_len_ = config_.block_len; anchor_data_ = const_cast(anchor); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of reading anchor: %f s\n", layer_id, block_idx, duration.count()); } void KVCache::update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; block_idx = block_idx; seq_len_ = config_.block_len; anchor_data_ = const_cast(anchor); // Each task updates the anchor of a certain position // backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) { // int k = task_id % config_.anchor_num; // int head_id = task_id / config_.anchor_num; // memcpy(anchor_[layer_id_][head_id][block_idx].data() + // k * config_.head_dim, // anchor_data_ + k * config_.head_dim, // sizeof(uint16_t) * config_.head_dim); // }); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of writting anchor: %f s\n", layer_id, block_idx, duration.count()); } void KVCache::update_importance_one_block(const ggml_fp16_t *importance, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; block_idx = block_idx; seq_len_ = config_.block_len; importance_data_ = const_cast(importance); // Each task updates the importance of a certain position backend->do_work_stealing_job( config_.block_len, nullptr, [&](int task_id) { int k = task_id; memcpy(importance_[layer_id_][block_idx].data() + k, importance_data_ + k, sizeof(uint16_t)); }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of writting importance: %f s\n", layer_id, block_idx, duration.count()); } void KVCache::get_importance_one_block(ggml_fp16_t *importance, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; block_idx = block_idx; seq_len_ = config_.block_len; importance_data_ = const_cast(importance); // Each task updates the importance of a certain position backend->do_work_stealing_job( config_.block_len, nullptr, [&](int task_id) { int k = task_id; memcpy(importance_data_ + k, importance_[layer_id_][block_idx].data() + k, sizeof(uint16_t)); }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of reading importance: %f s\n", layer_id, block_idx, duration.count()); } void KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; block_idx = block_idx; seq_len_ = config_.block_len; k_data_ = const_cast(k_in); v_data_ = const_cast(v_in); int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1); importance_[layer_id_].resize(new_block_num); for (int i = 0; i < config_.kv_head_num; i++) { k_cache_q4[layer_id][i].resize(new_block_num); v_cache_q4[layer_id][i].resize(new_block_num); // anchor_[layer_id][i].resize(new_block_num); } for (int i = 0; i < new_block_num; i++) { importance_[layer_id][i].resize(config_.block_len); } // Each task updates the k cache or v cache of a certain header backend->do_work_stealing_job( config_.kv_head_num * 2, nullptr, [&](int task_id) { std::vector block_fp32(32); int head_id = task_id / 2; if (task_id & 1) { // fill k_cache_ k_cache_q4[layer_id_][head_id][block_idx].resize( config_.block_len * config_.head_dim / 32); for (int k = 0; k < config_.block_len; k++) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k) * config_.head_dim + l * 32 + m]); } quantize_row_q4_0(block_fp32.data(), &block, 32); k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l] = block; } } } else { // fill v_cache_ v_cache_q4[layer_id_][head_id][block_idx].resize( config_.head_dim * config_.block_len / 32); for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( v_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k * 32 + m) * config_.head_dim + l]); } quantize_row_q4_0(block_fp32.data(), &block, 32); v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k] = block; } } } }, nullptr); past_block_num_[layer_id] = new_block_num; // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of writting KV Cache: %f s\n", layer_id, block_idx, duration.count()); // printf("get_one_block_fp16 duration: %ld\n", duration); } void KVCache::get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int block_idx, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; seq_len_ = config_.block_len; k_data_ = reinterpret_cast(k_in); v_data_ = reinterpret_cast(v_in); // printf("layer_id: %d, block_idx: %d\n", layer_id, block_idx); // Each task gets the k cache or v cache of a certain header backend->do_work_stealing_job( config_.kv_head_num * 2, nullptr, [&](int task_id) { std::vector block_fp32(32); int head_id = task_id / 2; if (task_id & 1) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k) * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } else { // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { v_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k * 32 + m) * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("layer %d block %d time of reading KV Cache: %f s\n", layer_id, block_idx, duration.count()); // printf("get_one_block_fp16 duration: %ld\n", duration); } // k_in: (batch_size, seq_len, head_num, head_dim) // v_in: (batch_size, seq_len, head_num, head_dim) void KVCache::get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, int q_len, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; k_data_ = const_cast(k_in); v_data_ = const_cast(v_in); // Each task updates the k cache and v cache of a certain header backend->do_work_stealing_job( config_.kv_head_num * max_block_num * batch_size, nullptr, [&](int task_id) { // printf("block_idx: %d, task_id: %d\n", block_idx, task_id); std::vector block_fp32(32); int batch_id = task_id / (config_.kv_head_num * max_block_num); int block_id = (task_id / config_.kv_head_num) % max_block_num; int head_id = task_id % config_.kv_head_num; int block_idx = block_table[batch_id * max_block_num + block_id]; int seq_len = cache_seqlens[batch_id]; int block_l = block_id * config_.block_len; int block_r = block_id * config_.block_len + config_.block_len; if (block_l < seq_len) { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim; l++) { k_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] = k_cache_fp16_[layer_id_][head_id][block_idx] [k * config_.head_dim + l]; v_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] = v_cache_fp16_[layer_id_][head_id][block_idx] [l * config_.block_len + k]; } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); } } } // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len) break; v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block = k_cache_q8[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); } } } // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len) break; v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } } if (block_r > seq_len && block_l < seq_len + q_len) { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len + q_len || block_id * config_.block_len + k < seq_len) continue; for (int l = 0; l < config_.head_dim; l++) { k_cache_fp16_[layer_id_][head_id][block_idx] [k * config_.head_dim + l] = k_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l]; v_cache_fp16_[layer_id_][head_id][block_idx] [l * config_.block_len + k] = v_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l]; } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { // fill k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len + q_len || block_id * config_.block_len + k < seq_len) continue; for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m]); } quantize_row_q4_0(block_fp32.data(), &block, 32); k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l] = block; } } // fill v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block; for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len + q_len) { block_fp32[m] = 0; continue; } block_fp32[m] = GGML_FP16_TO_FP32( v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]); } quantize_row_q4_0(block_fp32.data(), &block, 32); v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k] = block; } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { // fill k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len + q_len || block_id * config_.block_len + k < seq_len) continue; for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m]); } quantize_row_q8_0(block_fp32.data(), &block, 32); k_cache_q8[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l] = block; } } // fill v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q8_0 block; for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len + q_len) { block_fp32[m] = 0; continue; } block_fp32[m] = GGML_FP16_TO_FP32( v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]); } quantize_row_q8_0(block_fp32.data(), &block, 32); v_cache_q8[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k] = block; } } } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("layer %d time of reading and updating KV Cache: %f s\n", // layer_id, // duration.count()); } void KVCache::update_importance(const ggml_fp16_t *importance, int layer_id, int *block_table, int batch_size, int max_block_num, int *offset, int width, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; importance_data_ = const_cast(importance); // Each task updates the importance of a certain position backend->do_work_stealing_job( max_block_num * batch_size, nullptr, [&](int task_id) { int block_id = task_id % max_block_num; int batch_id = task_id / max_block_num; int block_idx = block_table[batch_id * max_block_num + block_id]; if (block_id > (offset[batch_id] + width) / config_.block_len) { return; } for (int k = 0; k < config_.block_len; k++) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { importance_[layer_id_][block_idx][k][head_id] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( importance_data_[batch_id * max_block_num * config_.block_len * config_.q_head_num + (block_id * config_.block_len + k) * config_.q_head_num + head_id]) + GGML_FP16_TO_FP32( importance_[layer_id_][block_idx][k][head_id])); } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("layer %d time of updating importance: %f s\n", layer_id, // duration.count()); } void KVCache::get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; k_data_ = const_cast(k_in); v_data_ = const_cast(v_in); // Each task updates the k cache and v cache of a certain header backend->do_work_stealing_job( config_.kv_head_num * max_block_num * batch_size, nullptr, [&](int task_id) { // printf("block_idx: %d, task_id: %d\n", block_idx, task_id); std::vector block_fp32(32); int batch_id = task_id / (config_.kv_head_num * max_block_num); int block_id = (task_id / config_.kv_head_num) % max_block_num; int head_id = task_id % config_.kv_head_num; int block_idx = block_table[batch_id * max_block_num + block_id]; int seq_len = cache_seqlens[batch_id]; int block_l = block_id * config_.block_len; int block_r = block_id * config_.block_len + config_.block_len; if (block_l < seq_len) { if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim; l++) { k_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] = k_cache_fp16_[layer_id_][head_id][block_idx] [k * config_.head_dim + l]; v_data_ [batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] = v_cache_fp16_[layer_id_][head_id][block_idx] [l * config_.block_len + k]; } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); } } } // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len) break; v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_id * config_.block_len + k >= seq_len) break; for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block = k_cache_q8[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); } } } // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { if (block_id * config_.block_len + k * 32 + m >= seq_len) break; v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) + block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) + (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); } } } } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; } void KVCache::update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in, int layer_id, int *block_table, int batch_size, int max_block_num, int *cache_seqlens, int q_len, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; k_data_ = const_cast(k_in); v_data_ = const_cast(v_in); // Each task updates the k cache and v cache of a certain header backend->do_work_stealing_job( batch_size * config_.kv_head_num * q_len, nullptr, [&](int task_id) { int batch_id = task_id / (config_.kv_head_num * q_len); int head_id = task_id / q_len % config_.kv_head_num; int seq_len = cache_seqlens[batch_id] + task_id % q_len; int q_offset = task_id % q_len; int block_id = seq_len / config_.block_len; int block_idx = block_table[batch_id * max_block_num + block_id]; int pos_in_block = seq_len % config_.block_len; if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int l = 0; l < config_.head_dim; l++) { k_cache_fp16_[layer_id_][head_id][block_idx] [pos_in_block * config_.head_dim + l] = k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]; v_cache_fp16_[layer_id_][head_id][block_idx] [l * config_.block_len + pos_in_block] = v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]; } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { std::vector block_fp32(32); // fill k_cache_ for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m]); } quantize_row_q4_0(block_fp32.data(), &block, 32); k_cache_q4[layer_id_][head_id][block_idx] [pos_in_block * config_.head_dim / 32 + l] = block; } // fill v_cache_ for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + pos_in_block / 32]; dequantize_row_q4_0(&block, block_fp32.data(), 32); block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32( v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l]); quantize_row_q4_0(block_fp32.data(), &block, 32); v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + pos_in_block / 32] = block; } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { std::vector block_fp32(32); // fill k_cache_ for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block; for (int m = 0; m < 32; m++) { block_fp32[m] = GGML_FP16_TO_FP32( k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m]); } quantize_row_q8_0(block_fp32.data(), &block, 32); k_cache_q8[layer_id_][head_id][block_idx] [pos_in_block * config_.head_dim / 32 + l] = block; } // fill v_cache_ for (int l = 0; l < config_.head_dim; l++) { block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + pos_in_block / 32]; dequantize_row_q8_0(&block, block_fp32.data(), 32); block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32( v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l]); quantize_row_q8_0(block_fp32.data(), &block, 32); v_cache_q8[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + pos_in_block / 32] = block; } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("layer %d time of reading KV Cache: %f s\n", layer_id, // duration.count()); } void KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in, ggml_fp16_t *v_in, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); layer_id_ = layer_id; seq_len_ = config_.block_len; block_num_ = get_cache_total_block_num(); k_data_ = reinterpret_cast(k_in); v_data_ = reinterpret_cast(v_in); // Each task gets the k cache or v cache of a certain header backend->do_work_stealing_job( config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr, [&](int task_id) { std::vector block_fp32(32); int head_id = task_id / 2 / past_block_num_[layer_id]; int block_idx = task_id / 2 % past_block_num_[layer_id]; if (block_idx >= block_num_) return; int max_offset = 0; if (task_id & 1) { // get k_cache_ for (int k = 0; k < config_.block_len; k++) { if (block_idx * seq_len_ + k >= cache_total_len_) break; for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx] [k * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { k_data_[(head_id * cache_total_len_ + block_idx * config_.block_len + k) * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]); max_offset = std::max( max_offset, (int)(head_id * cache_total_len_ + block_idx * config_.block_len + k) * config_.head_dim + l * 32 + m); } } } } else { // get v_cache_ for (int k = 0; k < config_.block_len / 32; k++) { for (int l = 0; l < config_.head_dim; l++) { block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx] [l * config_.block_len / 32 + k]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { if (block_idx * seq_len_ + k * 32 + m >= cache_total_len_) break; v_data_[(head_id * cache_total_len_ + block_idx * config_.block_len + k * 32 + m) * config_.head_dim + l] = GGML_FP32_TO_FP16(block_fp32[m]); max_offset = std::max(max_offset, (int)((head_id * cache_total_len_ + block_idx * config_.block_len + k * 32 + m) * config_.head_dim + l)); } } } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("layer %d block num %d time of reading all KV Cache: %f s\n", // layer_id, block_num_, duration.count()); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/kvcache/kvcache_utils.cpp ================================================ /** * @Description : * @Author : Jianwei Dong * @Date : 2024-08-26 22:47:06 * @Version : 1.0.0 * @LastEditors : Jianwei Dong * @LastEditTime : 2024-08-26 22:47:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "kvcache.h" #include std::string ggml_type_to_string(ggml_type type) { switch (type) { case GGML_TYPE_F32: return "GGML_TYPE_F32"; case GGML_TYPE_F16: return "GGML_TYPE_F16"; case GGML_TYPE_Q4_0: return "GGML_TYPE_Q4_0"; case GGML_TYPE_Q8_0: return "GGML_TYPE_Q8_0"; } return "UNDIFINED"; } std::string AnchorTypeToString(AnchorType type) { switch (type) { case AnchorType::DYNAMIC: return "DYNAMIC"; case AnchorType::BLOCK_MEAN: return "BLOCK_MEAN"; case AnchorType::BLOCK_MAX: return "BLOCK_MAX"; case AnchorType::FIXED_ANCHOR: return "FIXED_ANCHOR"; case AnchorType::QUEST: return "QUEST"; } return "UNDIFINED"; } std::string RetrievalTypeToString(RetrievalType type) { switch (type) { case RetrievalType::LAYER: return "SHARED"; case RetrievalType::KVHEAD: return "SEPARATE"; case RetrievalType::QHEAD: return "INDIVIDUAL"; } return "UNDIFINED"; } KVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim, int block_len, int anchor_num, AnchorType anchor_type, ggml_type kv_type, RetrievalType retrieval_type, int layer_step, int token_step, int layer_offset, int max_block_num, int max_batch_size, int max_thread_num) : layer_num(layer_num), kv_head_num(kv_head_num), q_head_num(q_head_num), head_dim(head_dim), block_len(block_len), anchor_num(anchor_num), anchor_type(anchor_type), kv_type(kv_type), retrieval_type(retrieval_type), layer_step(layer_step), token_step(token_step), layer_offset(layer_offset), max_block_num(max_block_num), max_batch_size(max_batch_size), max_thread_num(max_thread_num) { printf( "layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, " "block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, " "retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d," "max_block_num: %d, max_batch_size: %d, max_thread_num: %d\n", layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num, AnchorTypeToString(anchor_type).c_str(), ggml_type_to_string(kv_type).c_str(), RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step, layer_offset, max_block_num, max_batch_size, max_thread_num); assert(q_head_num % kv_head_num == 0); } KVCache::KVCache(KVCacheConfig config) { this->config_ = config; n_gqa_ = config_.q_head_num / config_.kv_head_num; if (config_.kv_type == ggml_type::GGML_TYPE_F16) { // TODO: Elegant implement k_cache_fp16_.resize(config_.layer_num); v_cache_fp16_.resize(config_.layer_num); selected_blocks_num_history_.resize(config_.layer_num / config_.layer_step); if (config_.retrieval_type == RetrievalType::LAYER) { selected_blocks_history_.resize(config_.layer_num / config_.layer_step); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { selected_blocks_history_kvhead_.resize(config_.layer_num / config_.layer_step); } else if (config_.retrieval_type == RetrievalType::QHEAD) { } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { k_cache_q4.resize(config.layer_num); v_cache_q4.resize(config.layer_num); } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { k_cache_q8.resize(config.layer_num); v_cache_q8.resize(config.layer_num); } else { assert(false); } anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num * config.q_head_num * config.head_dim); importance_.resize(config.layer_num); past_block_num_.resize(config.layer_num); for (int i = 0; i < config.layer_num; i++) { past_block_num_[i] = 0; } ThreadResize(config.max_thread_num); BatchResize(config.max_batch_size); BlockResize(config.max_block_num); q_fp32.resize(n_gqa_ * config.head_dim); } void KVCache::ThreadResize(int thread_num) { thread_local_output_q8_0_.resize(thread_num); thread_local_attn_score_.resize(thread_num); thread_local_output_fp32_.resize(thread_num); thread_local_attn_lse_.resize(thread_num); thread_local_cur_output_fp32_.resize(thread_num); thread_local_cur_attn_lse_.resize(thread_num); thread_local_draft_.resize(thread_num); thread_cur_head_idx_.resize(thread_num); thread_local_attn_mask_.resize(thread_num); for (int i = 0; i < thread_num; i++) { thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0); thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len); thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim); thread_local_attn_lse_[i].resize(n_gqa_); thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim); thread_local_cur_attn_lse_[i].resize(n_gqa_); thread_local_draft_[i].resize( 2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim + 2 * config_.block_len * config_.head_dim + config_.block_len * config_.head_dim / QK4_0); thread_local_attn_mask_[i].resize(config_.block_len / 8); } } void KVCache::BatchResize(int batch_size) { mutex_.resize(batch_size); q_q8_0_.resize(batch_size); q_fp32_.resize(batch_size); output_fp32_.resize(batch_size); attn_lse_.resize(batch_size); block_lse_.resize(batch_size); attn_sparsity_.resize(batch_size); if (config_.retrieval_type == RetrievalType::LAYER) { block_table_before_retrieval_.resize(batch_size); block_table_after_retrieval_.resize(batch_size); for (int i = 0; i < config_.layer_num / config_.layer_step; i++) { selected_blocks_history_[i].resize(batch_size); } } else if (config_.retrieval_type == RetrievalType::KVHEAD) { block_table_before_retrieval_kvhead_.resize(batch_size); block_table_after_retrieval_kvhead_.resize(batch_size); for (int i = 0; i < config_.layer_num / config_.layer_step; i++) { selected_blocks_history_kvhead_[i].resize(batch_size); } } else if (config_.retrieval_type == RetrievalType::QHEAD) { block_table_before_retrieval_qhead_.resize(batch_size); block_table_after_retrieval_qhead_.resize(batch_size); } cache_seqlens_.resize(batch_size); if (config_.retrieval_type == RetrievalType::LAYER) { block_similar_.resize(batch_size); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { block_similar_kv_head_.resize(batch_size); } else if (config_.retrieval_type == RetrievalType::QHEAD) { block_similar_q_head_.resize(batch_size); } for (int i = 0; i < batch_size; i++) { top_similar_block_.resize(batch_size); mutex_[i].resize(config_.kv_head_num); q_q8_0_[i].resize(config_.kv_head_num); q_fp32_[i].resize(config_.kv_head_num); output_fp32_[i].resize(config_.kv_head_num); attn_lse_[i].resize(config_.kv_head_num); for (int j = 0; j < config_.kv_head_num; j++) { if (!mutex_[i][j]) { mutex_[i][j] = std::make_unique(); } q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0); q_fp32_[i][j].resize(n_gqa_ * config_.head_dim); output_fp32_[i][j].resize(n_gqa_ * config_.head_dim); attn_lse_[i][j].resize(n_gqa_); } } avg_q.resize(batch_size); avg_q_fp16.resize(batch_size); for (int i = 0; i < batch_size; i++) { attn_sparsity_[i].resize(config_.q_head_num); avg_q[i].resize(config_.q_head_num * config_.head_dim); avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim); } } void KVCache::BlockResize(int max_block_num) { sin_.resize(max_block_num * config_.block_len); cos_.resize(max_block_num * config_.block_len); for (int i = 0; i < max_block_num * config_.block_len; i++) { sin_[i].resize(config_.head_dim); cos_[i].resize(config_.head_dim); } for (int i = 0; i < config_.layer_num / config_.layer_step; i++) { for (int j = 0; j < config_.max_batch_size; j++) { if (config_.retrieval_type == RetrievalType::LAYER) { selected_blocks_history_[i][j].resize(max_block_num); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { selected_blocks_history_kvhead_[i][j].resize(max_block_num); for (int k = 0; k < config_.max_block_num; k++) { selected_blocks_history_kvhead_[i][j][k].resize( config_.kv_head_num); } } else if (config_.retrieval_type == RetrievalType::QHEAD) { } } } for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) { importance_[layer_id].resize(max_block_num); if (config_.kv_type == ggml_type::GGML_TYPE_F16) { // TODO: Elegant implement k_cache_fp16_[layer_id].resize(config_.kv_head_num); v_cache_fp16_[layer_id].resize(config_.kv_head_num); for (int i = 0; i < config_.kv_head_num; i++) { k_cache_fp16_[layer_id][i].resize(max_block_num); v_cache_fp16_[layer_id][i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { k_cache_fp16_[layer_id][i][j].resize(config_.block_len * config_.head_dim); v_cache_fp16_[layer_id][i][j].resize(config_.block_len * config_.head_dim); } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { k_cache_q4[layer_id].resize(config_.kv_head_num); v_cache_q4[layer_id].resize(config_.kv_head_num); for (int i = 0; i < config_.kv_head_num; i++) { k_cache_q4[layer_id][i].resize(max_block_num); v_cache_q4[layer_id][i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { k_cache_q4[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32); v_cache_q4[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32); } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { k_cache_q8[layer_id].resize(config_.kv_head_num); v_cache_q8[layer_id].resize(config_.kv_head_num); for (int i = 0; i < config_.kv_head_num; i++) { k_cache_q8[layer_id][i].resize(max_block_num); v_cache_q8[layer_id][i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { k_cache_q8[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32); v_cache_q8[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32); } } } else { assert(false); } for (int i = 0; i < config_.max_batch_size; i++) { if (config_.retrieval_type == RetrievalType::LAYER) { block_similar_[i].resize(max_block_num); block_table_before_retrieval_[i].resize(max_block_num); block_table_after_retrieval_[i].resize(max_block_num); } else if (config_.retrieval_type == RetrievalType::KVHEAD) { block_similar_kv_head_[i].resize(max_block_num); block_table_before_retrieval_kvhead_[i].resize(max_block_num); block_table_after_retrieval_kvhead_[i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { block_similar_kv_head_[i][j].resize(config_.kv_head_num); block_table_before_retrieval_kvhead_[i][j].resize( config_.kv_head_num); block_table_after_retrieval_kvhead_[i][j].resize( config_.kv_head_num); } } else if (config_.retrieval_type == RetrievalType::QHEAD) { block_similar_q_head_[i].resize(max_block_num); block_table_before_retrieval_qhead_[i].resize(max_block_num); block_table_after_retrieval_qhead_[i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { block_similar_q_head_[i][j].resize(config_.q_head_num); block_table_before_retrieval_qhead_[i][j].resize( config_.q_head_num); block_table_after_retrieval_qhead_[i][j].resize( config_.q_head_num); } } block_lse_[i].resize(max_block_num); for (int j = 0; j < max_block_num; j++) { block_lse_[i][j].resize(config_.q_head_num); } } for (int i = 0; i < max_block_num; i++) { importance_[layer_id][i].resize(config_.block_len); for (int j = 0; j < config_.block_len; j++) { importance_[layer_id][i][j].resize(config_.q_head_num); } } } } void KVCache::calc_anchor_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); // Each task updates the importance of a certain block seq_len_ = config_.block_len; backend->do_work_stealing_job( config_.layer_num * batch_size * max_block_num, nullptr, [&](int task_id) { int layer_id = task_id / (batch_size * max_block_num); int batch_id = (task_id / max_block_num) % batch_size; int block_id = task_id % max_block_num; // If the block is out of the sequence length, skip it. In // particular, the last block of the sequence that is shorter than // the block length should be skipped. if (cache_seqlens[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table[batch_id * max_block_num + block_id]; std::vector block_fp32(32); if (config_.anchor_type == AnchorType::DYNAMIC) { // clear anchor_ for (int anchor_id = 0; anchor_id < 1; anchor_id++) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0; } } } // find top anchor_num importances and their corresponding // positions in the importance_ tensor // TODO: Move top_importances to the class member to avoid // repeated memory allocation std::priority_queue< std::pair>, std::vector>>, std::greater<>> top_importances; for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int k = 0; k < seq_len_; k++) { top_importances.push(std::make_pair( GGML_FP16_TO_FP32( importance_[layer_id][block_idx][k][head_id]), std::make_pair(block_idx, k))); // TODO: change to config_ item if (top_importances.size() > config_.anchor_num) { top_importances.pop(); } } // fill anchor_ for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0; } for (int k = 0; k < config_.anchor_num; k++) { int top_indice = top_importances.top().second.second; int top_block_idx = top_importances.top().second.first; if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) + GGML_FP16_TO_FP32( k_cache_fp16_[layer_id] [head_id / n_gqa_] [top_block_idx] [top_indice * config_.head_dim + l])); } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4 [layer_id][head_id / n_gqa_][top_block_idx] [top_indice * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16( block_fp32[m] / 4 + GGML_FP16_TO_FP32( anchor_[layer_id * config_ .max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m])); } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block = k_cache_q8 [layer_id][head_id / n_gqa_][top_block_idx] [top_indice * config_.head_dim / 32 + l]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16( block_fp32[m] / 4 + GGML_FP16_TO_FP32( anchor_[layer_id * config_ .max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m])); } } } top_importances.pop(); } } } else if (config_.anchor_type == AnchorType::BLOCK_MEAN) { // clear anchor_ for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0; } } } // fill anchor_ if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int k = 0; k < config_.block_len; k++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) + GGML_FP16_TO_FP32( k_cache_fp16_[layer_id] [head_id / n_gqa_] [block_idx] [k * config_.head_dim + l]) / config_.block_len); } } } } } else if (config_.anchor_type == AnchorType::BLOCK_MAX) { // clear anchor_ for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0; } } } // fill anchor_ if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int k = 0; k < config_.block_len; k++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(std::max( GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]), GGML_FP16_TO_FP32( k_cache_fp16_ [layer_id][head_id / n_gqa_] [block_idx] [k * config_.head_dim + l]))); } } } } } else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) { // clear anchor_ for (int anchor_id = 0; anchor_id < 1; anchor_id++) { for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0; } } } // fill anchor_ if (config_.kv_type == ggml_type::GGML_TYPE_F16) { int stride = config_.block_len / config_.anchor_num; for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int k = 0, tot = 0; k < config_.block_len, tot < config_.anchor_num; k += stride, tot++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16( GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) + GGML_FP16_TO_FP32( k_cache_fp16_[layer_id] [head_id / n_gqa_] [block_idx] [k * config_.head_dim + l]) / config_.anchor_num); } } } } } else if (config_.anchor_type == AnchorType::QUEST) { // clear anchor_ for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16( std::numeric_limits::max()); anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16( std::numeric_limits::min()); } } // fill anchor_ if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int indice = 0; indice < seq_len_; indice++) { for (int head_id = 0; head_id < config_.kv_head_num; head_id++) { for (int l = 0; l < config_.head_dim; l++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(std::max( GGML_FP16_TO_FP32( k_cache_fp16_ [layer_id][head_id][block_idx] [indice * config_.head_dim + l]), GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]))); anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = GGML_FP32_TO_FP16(std::min( GGML_FP16_TO_FP32( k_cache_fp16_ [layer_id][head_id][block_idx] [indice * config_.head_dim + l]), GGML_FP16_TO_FP32( anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]))); } } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { for (int indice = 0; indice < seq_len_; indice++) { for (int head_id = 0; head_id < config_.kv_head_num; head_id++) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q4_0 block = k_cache_q4[layer_id][head_id][block_idx] [indice * config_.head_dim / 32 + l]; dequantize_row_q4_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { for (int gqa_idx = 0; gqa_idx < n_gqa_; gqa_idx++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(std::max( block_fp32[m], GGML_FP16_TO_FP32( anchor_ [layer_id * config_ .max_block_num * config_ .anchor_num * config_ .q_head_num * config_.head_dim + block_idx * config_ .anchor_num * config_ .q_head_num * config_.head_dim + 0 * config_ .q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m]))); anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(std::min( block_fp32[m], GGML_FP16_TO_FP32( anchor_ [layer_id * config_ .max_block_num * config_ .anchor_num * config_ .q_head_num * config_.head_dim + block_idx * config_ .anchor_num * config_ .q_head_num * config_.head_dim + 1 * config_ .q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m]))); } } } } } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { for (int indice = 0; indice < seq_len_; indice++) { for (int head_id = 0; head_id < config_.kv_head_num; head_id++) { for (int l = 0; l < config_.head_dim / 32; l++) { block_q8_0 block = k_cache_q8[layer_id][head_id][block_idx] [indice * config_.head_dim / 32 + l]; dequantize_row_q8_0(&block, block_fp32.data(), 32); for (int m = 0; m < 32; m++) { for (int gqa_idx = 0; gqa_idx < n_gqa_; gqa_idx++) { anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(std::max( block_fp32[m], GGML_FP16_TO_FP32( anchor_ [layer_id * config_ .max_block_num * config_ .anchor_num * config_ .q_head_num * config_.head_dim + block_idx * config_ .anchor_num * config_ .q_head_num * config_.head_dim + 0 * config_ .q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m]))); anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim + block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim + 1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] = GGML_FP32_TO_FP16(std::min( block_fp32[m], GGML_FP16_TO_FP32( anchor_ [layer_id * config_ .max_block_num * config_ .anchor_num * config_ .q_head_num * config_.head_dim + block_idx * config_ .anchor_num * config_ .q_head_num * config_.head_dim + 1 * config_ .q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m]))); } } } } } } } else { assert(false); } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("time of calc_anchor_all_layers: %f s\n", duration.count()); } void KVCache::clear_importance_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); // Each task updates the importance of a certain block seq_len_ = config_.block_len; backend->do_work_stealing_job( config_.layer_num * batch_size * max_block_num, nullptr, [&](int task_id) { int layer_id = task_id / (batch_size * max_block_num); int batch_id = (task_id / max_block_num) % batch_size; int block_id = task_id % max_block_num; // If the block is out of the sequence length, skip it. In // particular, the last block of the sequence that is shorter than // the block length should be skipped. if (cache_seqlens[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table[batch_id * max_block_num + block_id]; if (config_.anchor_type == AnchorType::DYNAMIC) { // clear anchor_ for (int head_id = 0; head_id < config_.q_head_num; head_id++) { for (int l = 0; l < config_.block_len; l++) { importance_[layer_id][block_idx][l][head_id] = 0; } } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("time of clear_importance_all_layerssssss: %f s\n", // duration.count()); } void KVCache::clear_kvcache_all_layers(int *block_table, int *cache_seqlens, int batch_size, int max_block_num, Backend *backend) { // Timer start auto start = std::chrono::high_resolution_clock::now(); // Each task updates the importance of a certain block seq_len_ = config_.block_len; backend->do_work_stealing_job( config_.layer_num * batch_size * max_block_num * config_.kv_head_num, nullptr, [&](int task_id) { int layer_id = task_id / (batch_size * max_block_num * config_.kv_head_num); int batch_id = (task_id / (max_block_num * config_.kv_head_num)) % batch_size; int block_id = task_id / config_.kv_head_num % max_block_num; int head_id = task_id % config_.kv_head_num; // If the block is out of the sequence length, skip it. In // particular, the last block of the sequence that is shorter than // the block length should be skipped. if (cache_seqlens[batch_id] / config_.block_len < block_id) { return; } int block_idx = block_table[batch_id * max_block_num + block_id]; if (config_.kv_type == ggml_type::GGML_TYPE_F16) { for (int l = 0; l < config_.block_len * config_.head_dim; l++) { k_cache_fp16_[layer_id][head_id][block_idx][l] = 0; v_cache_fp16_[layer_id][head_id][block_idx][l] = 0; } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) { for (int l = 0; l < config_.block_len * config_.head_dim / 32; l++) { k_cache_q4[layer_id][head_id][block_idx][l].d = 0; v_cache_q4[layer_id][head_id][block_idx][l].d = 0; } } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) { for (int l = 0; l < config_.block_len * config_.head_dim / 32; l++) { k_cache_q8[layer_id][head_id][block_idx][l].d = 0; v_cache_q8[layer_id][head_id][block_idx][l].d = 0; } } }, nullptr); // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; // printf("time of clear_kvcache_all_layers: %f s\n", duration.count()); } void KVCache::get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen) { // Timer start auto start = std::chrono::high_resolution_clock::now(); const uint16_t *sin_data = const_cast(sin); const uint16_t *cos_data = const_cast(cos); for (int i = 0; i < seqlen; i++) { for (int j = 0; j < config_.head_dim; j++) { sin_[i][j] = sin_data[i * config_.head_dim + j]; cos_[i][j] = cos_data[i * config_.head_dim + j]; } } // Timer end auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; printf("time of get_sincos: %f s\n", duration.count()); } void ggml_vec_scale_f32(const int n, float *y, const float v) { #if defined(GGML_USE_ACCELERATE) vDSP_vsmul(y, 1, &v, y, 1, n); #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F32_STEP - 1)); GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); GGML_F32_VEC ay[GGML_F32_ARR]; for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR); ay[j] = GGML_F32_VEC_MUL(ay[j], vx); GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]); } } // leftovers for (int i = np; i < n; ++i) { y[i] *= v; } #else // scalar for (int i = 0; i < n; ++i) { y[i] *= v; } #endif } ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/conversion.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-12 10:07:58 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:34:55 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_CONVERSION_H #define CPUINFER_CONVERSION_H #include #include "llama.cpp/ggml.h" inline void to_float(const void* input, float* output, int size, ggml_type type) { if (type == ggml_type::GGML_TYPE_F32) { memcpy(output, input, size * sizeof(float)); } else { ggml_internal_get_type_traits(type).to_float(input, output, size); } } inline void from_float(const float* input, void* output, int size, ggml_type type) { if (type == ggml_type::GGML_TYPE_F32) { memcpy(output, input, size * sizeof(float)); } else { ggml_internal_get_type_traits(type).from_float(input, output, size); } } #endif ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/linear.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-12 10:07:58 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-15 07:45:18 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "linear.h" Linear::Linear(LinearConfig config) { config_ = config; proj_ = config_.proj; std::vector> mem_requests; mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size}); mem_requests.push_back({(void**)&proj_input_, config_.group_max_len * config_.input_size * ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type)}); mem_requests.push_back({(void**)&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size}); shared_mem_buffer.alloc(this, mem_requests); } Linear::~Linear() { shared_mem_buffer.dealloc(this); } void Linear::warm_up(Backend *backend) { std::vector input_fp32(config_.input_size); std::vector input(config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); std::vector output(config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); for (int i = 0; i < config_.input_size; i++) { input_fp32[i] = 0; } from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type); forward_many(1, input.data(), output.data(), backend); } void Linear::forward_many(int qlen, const void* input, void* output, Backend* backend) { const void* proj_input_ptr; if (config_.hidden_type == ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) { proj_input_ptr = input; } else { to_float(input, input_fp32_, qlen * config_.input_size, config_.hidden_type); from_float(input_fp32_, proj_input_, qlen * config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type); proj_input_ptr = proj_input_; } int nth = config_.output_size / config_.stride; backend->do_work_stealing_job(nth, nullptr, [&](int task_id) { int ith = task_id; void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type); float* proj_output_ptr = proj_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { for (int i = 0; i < qlen; i++) { float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride; void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } } }, nullptr); if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) { from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type); } } void Linear::forward(int qlen, const void* input, void* output, Backend* backend) { if (qlen <= 0) { return; } int forward_len = std::min(qlen, config_.group_max_len); forward_many(forward_len, input, output, backend); forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/linear.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-12 10:07:58 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:35:00 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_LINEAR_H #define CPUINFER_OPERATOR_LINEAR_H #include #include #include #include #include #include "../../cpu_backend/backend.h" #include "../../cpu_backend/shared_mem_buffer.h" #include "conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" struct LinearConfig { int input_size; int output_size; int stride; int group_max_len; void* proj; ggml_type proj_type; ggml_type hidden_type; LinearConfig() {} LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type, ggml_type hidden_type) : input_size(input_size), output_size(output_size), stride(stride), group_max_len(group_max_len), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {} }; class Linear { public: Linear(LinearConfig); ~Linear(); void warm_up(Backend* backend); void forward_many(int qlen, const void* input, void* output, Backend* backend); void forward(int qlen, const void* input, void* output, Backend* backend); private: LinearConfig config_; void* proj_; // [output_size * input_size ( /32 if quantized)] float* input_fp32_; // [group_max_len * input_size] uint8_t* proj_input_; // [group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)] float* proj_output_; // [group_max_len * output_size] }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/mlp.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-16 10:43:18 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-15 07:44:38 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "mlp.h" MLP::MLP(MLPConfig config) { config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; std::vector> mem_requests; mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size}); mem_requests.push_back({(void**)&gate_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)}); mem_requests.push_back({(void**)&up_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)}); mem_requests.push_back({(void**)&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size}); mem_requests.push_back({(void**)&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size}); mem_requests.push_back({(void**)&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size}); mem_requests.push_back({(void**)&down_input_, config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)}); mem_requests.push_back({(void**)&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size}); shared_mem_buffer.alloc(this, mem_requests); } MLP::~MLP() { shared_mem_buffer.dealloc(this); } void MLP::warm_up(Backend *backend) { std::vector input_fp32(config_.hidden_size); std::vector input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); std::vector output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); for (int i = 0; i < config_.hidden_size; i++) { input_fp32[i] = 0; } from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type); forward_many(1, input.data(), output.data(), backend); } static float act_fn(float x) { return x / (1.0f + expf(-x)); } void MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) { const void* gate_input_ptr; const void* up_input_ptr; if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { gate_input_ptr = up_input_ptr = input; } else { to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type); if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = up_input_ptr = gate_input_; } else { if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) { from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = gate_input_; } else { gate_input_ptr = input; } if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(input_fp32_, up_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type); up_input_ptr = up_input_; } else { up_input_ptr = input; } } } int nth = config_.intermediate_size / config_.stride; backend->do_work_stealing_job(nth, nullptr, [&](int task_id) { int ith = task_id; void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); float* gate_output_ptr = gate_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); float* up_output_ptr = up_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < qlen; i++) { for (int j = ith * config_.stride; j < (ith + 1) * config_.stride; j++) { intermediate_fp32_[i * config_.intermediate_size + j] = act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j]; } if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) { float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride; void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } } }, nullptr); if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) { from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } nth = config_.hidden_size / config_.stride; backend->do_work_stealing_job(nth, nullptr, [&](int task_id) { int ith = task_id; void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); float* down_output_ptr = down_output_ + ith * config_.stride; llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { for (int i = 0; i < qlen; i++) { float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride; void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } } }, nullptr); if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) { from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type); } } void MLP::forward(int qlen, const void* input, void* output, Backend* backend) { if (qlen <= 0) { return; } int forward_len = std::min(qlen, config_.group_max_len); forward_many(forward_len, input, output, backend); forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/mlp.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-12 10:07:58 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:35:06 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_MLP_H #define CPUINFER_OPERATOR_MLP_H #include #include #include #include #include #include "../../cpu_backend/backend.h" #include "../../cpu_backend/shared_mem_buffer.h" #include "conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" struct MLPConfig { int hidden_size; int intermediate_size; int stride; int group_max_len; void* gate_proj; void* up_proj; void* down_proj; ggml_type gate_type; ggml_type up_type; ggml_type down_type; ggml_type hidden_type; MLPConfig() {} MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type) : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {} }; class MLP { public: MLP(MLPConfig); ~MLP(); void warm_up(Backend* backend); void forward_many(int qlen, const void* input, void* output, Backend* backend); void forward(int qlen, const void* input, void* output, Backend* backend); private: MLPConfig config_; void* gate_proj_; // [intermediate_size * hidden_size ( /32 if quantized)] void* up_proj_; // [intermediate_size * hidden_size ( /32 if quantized)] void* down_proj_; // [hidden_size * intermediate_size ( /32 if quantized)] float* input_fp32_; // [group_max_len * hidden_size] uint8_t* gate_input_; // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* up_input_; // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] float* gate_output_; // [group_max_len * intermediate_size] float* up_output_; // [group_max_len * intermediate_size] float* intermediate_fp32_; // [group_max_len * intermediate_size] uint8_t* down_input_; // [group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)] float* down_output_; // [group_max_len * hidden_size] }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/moe.cpp ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-22 02:03:22 * @Version : 1.0.0 * @LastEditors : kkk1nak0 * @LastEditTime : 2024-08-15 07:43:41 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #include "moe.h" #include #include #include #ifdef USE_NUMA #include #include #endif MOE::MOE(MOEConfig config) { config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; #ifdef USE_NUMA int numa_nodes = numa_num_configured_nodes(); gate_proj_numa_.resize(numa_nodes); up_proj_numa_.resize(numa_nodes); down_proj_numa_.resize(numa_nodes); size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size; for (int i = 0; i < numa_nodes; i++) { gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i); up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i); down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i); if (!gate_proj_numa_[i]) { std::cout << "Memory allocation failed for gate_proj_numa_ on node " << i << std::endl; } if (!up_proj_numa_[i]) { std::cout << "Memory allocation failed for up_proj_numa_ on node " << i << std::endl; } if (!down_proj_numa_[i]) { std::cout << "Memory allocation failed for down_proj_numa_ on node " << i << std::endl; } memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type)); memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type)); memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type)); } #endif std::vector> s_mem_requests; s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size}); s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)}); s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)}); s_gate_output_.resize(config_.routed_expert_num); s_up_output_.resize(config_.routed_expert_num); s_intermediate_fp32_.resize(config_.routed_expert_num); s_down_input_.resize(config_.routed_expert_num); s_down_output_.resize(config_.routed_expert_num); for (int i = 0; i < config_.routed_expert_num; i++) { s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size}); s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size}); s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size}); s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)}); s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size}); } s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size}); shared_mem_buffer.alloc(this, s_mem_requests); std::vector> m_mem_requests; m_input_fp32_.resize(config_.group_max_len); m_gate_input_.resize(config_.group_max_len); m_up_input_.resize(config_.group_max_len); for (int i = 0; i < config_.group_max_len; i++) { m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size}); m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)}); m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)}); } m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)}); m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)}); m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size}); m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size}); m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size}); m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)}); m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size}); m_output_fp32_.resize(config_.group_max_len); for (int i = 0; i < config_.group_max_len; i++) { m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size}); } shared_mem_buffer.alloc(this, m_mem_requests); m_local_pos_.resize(config_.group_max_len); for (int i = 0; i < config_.group_max_len; i++) { m_local_pos_[i].resize(config_.routed_expert_num); } m_local_num_.resize(config_.expert_num); m_local_gate_input_ptr_.resize(config_.expert_num); m_local_up_input_ptr_.resize(config_.expert_num); m_local_gate_output_ptr_.resize(config_.expert_num); m_local_up_output_ptr_.resize(config_.expert_num); m_local_intermediate_fp32_ptr_.resize(config_.expert_num); m_local_down_input_ptr_.resize(config_.expert_num); m_local_down_output_ptr_.resize(config_.expert_num); } MOE::~MOE() { shared_mem_buffer.dealloc(this); #ifdef USE_NUMA int numa_nodes = numa_num_configured_nodes(); for (int i = 0; i < numa_nodes; i++) { numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type)); numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type)); numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type)); } #endif } void MOE::warm_up(Backend* backend) { std::vector input_fp32(config_.hidden_size); std::vector input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); std::vector output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type)); for (int i = 0; i < config_.hidden_size; i++) { input_fp32[i] = 0; } from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type); for (int i = 0; i < config_.expert_num; i++) { uint64_t expert_ids = i; float weights = 0; forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend); } } static float act_fn(float x) { return x / (1.0f + expf(-x)); } static float act_fn_relu(float x) { if(x > 0.0){ return x; } else { return 0.0; } } void MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { const void* gate_input_ptr; const void* up_input_ptr; if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { gate_input_ptr = up_input_ptr = input; } else { to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type); if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = up_input_ptr = s_gate_input_; } else { if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) { from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = s_gate_input_; } else { gate_input_ptr = input; } if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type); up_input_ptr = s_up_input_; } else { up_input_ptr = input; } } } int nth = config_.intermediate_size / config_.stride; backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) { int expert_idx = task_id / nth; uint64_t expert_id = expert_ids[expert_idx]; int ith = task_id % nth; #ifdef USE_NUMA void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); #endif float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); #ifdef USE_NUMA void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); #endif float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); if(config_.use_silu){ // use silu as act fn for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i]; } } else { // use relu as act fn for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i]; } } if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) { float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride; void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } }, nullptr); if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) { for (int i = 0; i < k; i++) { from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } } nth = config_.hidden_size / config_.stride; backend->do_work_stealing_job(nth, nullptr, [&](int task_id) { int ith = task_id; for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { s_output_fp32_[i] = 0; } for (int expert_idx = 0; expert_idx < k; expert_idx++) { uint64_t expert_id = expert_ids[expert_idx]; #ifdef USE_NUMA void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); #endif float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride; llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) { s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx]; } } if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) { float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride; void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type); } }, nullptr); if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) { from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type); } } void MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) { for (int i = 0; i < config_.expert_num; i++) { m_local_num_[i] = 0; } for (int i = 0; i < qlen; i++) { for (int j = 0; j < k; j++) { m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; } } uint64_t offset = 0; for (int i = 0; i < config_.expert_num; i++) { m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type); m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size; m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; offset += m_local_num_[i]; } backend->do_work_stealing_job(qlen, nullptr, [&](int i) { const void* gate_input_ptr; const void* up_input_ptr; if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } else { to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type); if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = up_input_ptr = m_gate_input_[i]; } else { if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) { from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type); gate_input_ptr = m_gate_input_[i]; } else { gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) { from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type); up_input_ptr = m_up_input_[i]; } else { up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type); } } } for (int j = 0; j < k; j++) { memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)); memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)); } }, nullptr); int stride = QK_K; int nth = config_.intermediate_size / stride; backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx]; #ifdef USE_NUMA void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); #else void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type); #endif float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); void* up_input_ptr = m_local_up_input_ptr_[expert_idx]; #ifdef USE_NUMA void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); #else void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type); #endif float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); for (int i = 0; i < m_local_num_[expert_idx]; i++) { if(config_.use_silu){ for (int j = ith * stride; j < (ith + 1) * stride; j++) { m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j]; } } else { for (int j = ith * stride; j < (ith + 1) * stride; j++) { m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j]; } } float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride; void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type); from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type); } }, nullptr); stride = QK_K; nth = config_.hidden_size / stride; backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) { uint64_t expert_idx = task_id / nth; int ith = task_id % nth; void* down_input_ptr = m_local_down_input_ptr_[expert_idx]; #ifdef USE_NUMA void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); #else void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type); #endif float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride; llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT); }, nullptr); backend->do_work_stealing_job(qlen, nullptr, [&](int i) { for (int e = 0; e < config_.hidden_size; e++) { m_output_fp32_[i][e] = 0; } for (int j = 0; j < k; j++) { for (int e = 0; e < config_.hidden_size; e++) { m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j]; } } from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type); }, nullptr); } void MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) { qlen = batch_size_tensor[0]; if (qlen < config_.group_min_len) { for (int i = 0; i < qlen; i++) { forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend); } return; } int forward_len = std::min(config_.group_max_len, qlen); forward_many(forward_len, k, expert_ids, weights, input, output, backend); batch_size_tensor[0] -= forward_len; forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), batch_size_tensor, backend); } ================================================ FILE: archive/csrc/ktransformers_ext/operators/llamafile/moe.h ================================================ /** * @Description : * @Author : chenht2022 * @Date : 2024-07-22 02:03:22 * @Version : 1.0.0 * @LastEditors : chenht2022 * @LastEditTime : 2024-07-25 10:35:10 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. **/ #ifndef CPUINFER_OPERATOR_MOE_H #define CPUINFER_OPERATOR_MOE_H #include #include #include #include #include #include "../../cpu_backend/backend.h" #include "../../cpu_backend/shared_mem_buffer.h" #include "conversion.h" #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "llama.cpp/ggml.h" #include "llamafile/sgemm.h" struct MOEConfig { int expert_num; int routed_expert_num; int hidden_size; int intermediate_size; int stride; int group_min_len; int group_max_len; bool use_silu; void* gate_proj; void* up_proj; void* down_proj; ggml_type gate_type; ggml_type up_type; ggml_type down_type; ggml_type hidden_type; MOEConfig() {} MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type) : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {} }; class MOE { public: MOE(MOEConfig); ~MOE(); void warm_up(Backend* backend); void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend); void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend); void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend); private: MOEConfig config_; void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] #ifdef USE_NUMA std::vector gate_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] std::vector up_proj_numa_; // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)] std::vector down_proj_numa_; // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)] #endif float* s_input_fp32_; // [hidden_size] uint8_t* s_gate_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* s_up_input_; // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] std::vector s_gate_output_; // [routed_expert_num, intermediate_size] std::vector s_up_output_; // [routed_expert_num, intermediate_size] std::vector s_intermediate_fp32_; // [routed_expert_num, intermediate_size] std::vector s_down_input_; // [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)] std::vector s_down_output_; // [routed_expert_num, hidden_size] float* s_output_fp32_; // [hidden_size] std::vector m_input_fp32_; // [group_max_len, hidden_size] std::vector m_gate_input_; // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] std::vector m_up_input_; // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] uint8_t* m_local_gate_input_; // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)] uint8_t* m_local_up_input_; // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)] float* m_local_gate_output_; // [routed_expert_num * group_max_len * intermediate_size] float* m_local_up_output_; // [routed_expert_num * group_max_len * intermediate_size] float* m_local_intermediate_fp32_; // [routed_expert_num * group_max_len * intermediate_size] uint8_t* m_local_down_input_; // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)] float* m_local_down_output_; // [routed_expert_num * group_max_len * hidden_size] std::vector m_output_fp32_; // [group_max_len, hidden_size] std::vector> m_local_pos_; // [group_max_len, routed_expert_num] std::vector m_local_num_; // [expert_num] std::vector m_local_gate_input_ptr_; // [expert_num] std::vector m_local_up_input_ptr_; // [expert_num] std::vector m_local_gate_output_ptr_; // [expert_num] std::vector m_local_up_output_ptr_; // [expert_num] std::vector m_local_intermediate_fp32_ptr_; // [expert_num] std::vector m_local_down_input_ptr_; // [expert_num] std::vector m_local_down_output_ptr_; // [expert_num] }; #endif ================================================ FILE: archive/csrc/ktransformers_ext/vendors/cuda.h ================================================ #pragma once #include #include #include #include #include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F #define cublasComputeType_t cudaDataType_t #endif // CUDART_VERSION < 11020 ================================================ FILE: archive/csrc/ktransformers_ext/vendors/hip.h ================================================ #pragma once #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 #include #include #include #include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N #define CUBLAS_OP_T HIPBLAS_OP_T #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended #define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx #define cublasGemmBatchedEx hipblasGemmBatchedEx #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx #define cublasHandle_t hipblasHandle_t #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags #define cudaEventDisableTiming hipEventDisableTiming #define cudaEventRecord hipEventRecord #define cudaEventSynchronize hipEventSynchronize #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount #define cudaGetDeviceProperties hipGetDeviceProperties #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaHostRegister hipHostRegister #define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync #define cudaMemcpy2DAsync hipMemcpy2DAsync #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind #define cudaMemset hipMemset #define cudaMemsetAsync hipMemsetAsync #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cuDeviceGet hipDeviceGet #define CUdevice hipDevice_t #define CUdeviceptr hipDeviceptr_t #define cuMemUnmap hipMemUnmap #define CUmemAccessDesc hipMemAccessDesc #define cuMemAddressFree hipMemAddressFree #define cuMemRelease hipMemRelease #define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t #define cuMemCreate hipMemCreate #define cuMemAddressReserve hipMemAddressReserve #define cuMemMap hipMemMap #define cuMemSetAccess hipMemSetAccess #define cuMemGetAllocationGranularity hipMemGetAllocationGranularity #define CUmemAllocationProp hipMemAllocationProp #define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams #define cudaKernelNodeParams hipKernelNodeParams #define cudaGraphExecDestroy hipGraphExecDestroy #define cudaGraphLaunch hipGraphLaunch #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure #define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult #define cudaGraphNodeType hipGraphNodeType #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel #define cudaGraphInstantiate hipGraphInstantiate #define cudaStreamEndCapture hipStreamEndCapture #define cudaGraphDestroy hipGraphDestroy #define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams #define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction #define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams #define cudaGraphNodeGetType hipGraphNodeGetType #define cudaGraphGetNodes hipGraphGetNodes #define cudaGraphExecUpdate hipGraphExecUpdate #define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed #define cudaStreamBeginCapture hipStreamBeginCapture #define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define cudaHostFn_t hipHostFn_t #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED #define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE #define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH #define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED #define __CUDA_ARCH__ 1300 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) #define GCN #endif #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) #define CDNA #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3 #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) #define RDNA2 #endif #if defined(__gfx1010__) || defined(__gfx1012__) #define RDNA1 #endif #ifndef __has_builtin #define __has_builtin(x) 0 #endif typedef hip_bfloat16 nv_bfloat16; ================================================ FILE: archive/csrc/ktransformers_ext/vendors/musa.h ================================================ #pragma once #include #include #include #include #include #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F #define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT #define CUDA_R_16F MUSA_R_16F #define CUDA_R_32F MUSA_R_32F #define cublasComputeType_t cudaDataType_t #define cublasCreate mublasCreate #define cublasDestroy mublasDestroy #define cublasGemmEx mublasGemmEx #define cublasGemmBatchedEx mublasGemmBatchedEx #define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx #define cublasHandle_t mublasHandle_t #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasStatus_to_string #define cudaDataType_t musaDataType_t #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags #define cudaEventDisableTiming musaEventDisableTiming #define cudaEventRecord musaEventRecord #define cudaEventSynchronize musaEventSynchronize #define cudaEvent_t musaEvent_t #define cudaEventDestroy musaEventDestroy #define cudaFree musaFree #define cudaFreeHost musaFreeHost #define cudaGetDevice musaGetDevice #define cudaGetDeviceCount musaGetDeviceCount #define cudaGetDeviceProperties musaGetDeviceProperties #define cudaGetErrorString musaGetErrorString #define cudaGetLastError musaGetLastError #define cudaHostRegister musaHostRegister #define cudaHostRegisterPortable musaHostRegisterPortable #define cudaHostRegisterReadOnly musaHostRegisterReadOnly #define cudaHostUnregister musaHostUnregister #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaMalloc musaMalloc #define cudaMallocHost musaMallocHost #define cudaMallocManaged musaMallocManaged #define cudaMemcpy musaMemcpy #define cudaMemcpyAsync musaMemcpyAsync #define cudaMemcpyPeerAsync musaMemcpyPeerAsync #define cudaMemcpy2DAsync musaMemcpy2DAsync #define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice #define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost #define cudaMemcpyHostToDevice musaMemcpyHostToDevice #define cudaMemcpyKind musaMemcpyKind #define cudaMemset musaMemset #define cudaMemsetAsync musaMemsetAsync #define cudaMemGetInfo musaMemGetInfo #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize #define cudaSetDevice musaSetDevice #define cudaStreamCreateWithFlags musaStreamCreateWithFlags #define cudaStreamDestroy musaStreamDestroy #define cudaStreamFireAndForget musaStreamFireAndForget #define cudaStreamNonBlocking musaStreamNonBlocking #define cudaStreamPerThread musaStreamPerThread #define cudaStreamSynchronize musaStreamSynchronize #define cudaStreamWaitEvent musaStreamWaitEvent #define cudaStream_t musaStream_t #define cudaSuccess musaSuccess // Additional mappings for MUSA virtual memory pool #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED #define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE #define CUdevice MUdevice #define CUdeviceptr MUdeviceptr #define CUmemAccessDesc MUmemAccessDesc #define CUmemAllocationProp MUmemAllocationProp #define CUmemGenericAllocationHandle MUmemGenericAllocationHandle #define cuDeviceGet muDeviceGet #define cuDeviceGetAttribute muDeviceGetAttribute #define cuMemAddressFree muMemAddressFree #define cuMemAddressReserve muMemAddressReserve #define cuMemCreate muMemCreate #define cuMemGetAllocationGranularity muMemGetAllocationGranularity #define cuMemMap muMemMap #define cuMemRelease muMemRelease #define cuMemSetAccess muMemSetAccess #define cuMemUnmap muMemUnmap #define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize #define cudaFuncSetAttribute musaFuncSetAttribute #define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms #define make_cudaExtent make_musaExtent #define make_cudaPitchedPtr make_musaPitchedPtr // Additional mappings for MUSA graphs #define CUDA_SUCCESS MUSA_SUCCESS #define CUresult MUresult #define cuGetErrorString muGetErrorString #define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure #define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction #define cudaGraphDestroy musaGraphDestroy #define cudaGraphExecDestroy musaGraphExecDestroy #define cudaGraphExec_t musaGraphExec_t #define cudaGraphExecUpdate musaGraphExecUpdate #define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult #define cudaGraphGetNodes musaGraphGetNodes #define cudaGraphInstantiate musaGraphInstantiate #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams #define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams #define cudaGraphLaunch musaGraphLaunch #define cudaGraphNodeGetType musaGraphNodeGetType #define cudaGraphNode_t musaGraphNode_t #define cudaGraphNodeType musaGraphNodeType #define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel #define cudaGraph_t musaGraph_t #define cudaKernelNodeParams musaKernelNodeParams #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamEndCapture musaStreamEndCapture typedef mt_bfloat16 nv_bfloat16; ================================================ FILE: archive/csrc/ktransformers_ext/vendors/vendor.h ================================================ #ifndef CPUINFER_VENDOR_VENDOR_H #define CPUINFER_VENDOR_VENDOR_H #ifdef USE_CUDA #include "cuda.h" #elif USE_HIP #define __HIP_PLATFORM_AMD__ #include "hip.h" #elif USE_MUSA #include "musa.h" #endif #endif // CPUINFER_VENDOR_VENDOR_H ================================================ FILE: archive/install-with-cache.sh ================================================ #!/bin/bash set -e # clear build dirs # rm -rf build # rm -rf *.egg-info # rm -rf csrc/build # rm -rf csrc/ktransformers_ext/build # rm -rf csrc/ktransformers_ext/cuda/build # rm -rf csrc/ktransformers_ext/cuda/dist # rm -rf csrc/ktransformers_ext/cuda/*.egg-info rm -rf ~/.ktransformers echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt pip install -r ktransformers/server/requirements.txt echo "Installing ktransformers" KTRANSFORMERS_FORCE_BUILD=TRUE USE_BALANCE_SERVE=1 pip install -v . --no-build-isolation pip install third_party/custom_flashinfer/ -v # SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") # echo "Copying thirdparty libs to $SITE_PACKAGES" # cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ # patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* echo "Installation completed successfully" ================================================ FILE: archive/install.bat ================================================ @echo off REM clear build dirs rmdir /S /Q ktransformers\ktransformers_ext\build rmdir /S /Q ktransformers\ktransformers_ext\cuda\build rmdir /S /Q ktransformers\ktransformers_ext\cuda\dist rmdir /S /Q ktransformers\ktransformers_ext\out del /F /Q ktransformers\ktransformers_ext\cuda\*.egg-info echo Installing python dependencies from requirements.txt pip install -r requirements-local_chat.txt echo Installing ktransformers set KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation echo Installation completed successfully ================================================ FILE: archive/install.sh ================================================ #!/bin/bash set -e # default backend DEV="cuda" # parse --dev argument while [[ "$#" -gt 0 ]]; do case $1 in --dev) DEV="$2"; shift ;; *) echo "Unknown parameter passed: $1"; exit 1 ;; esac shift done export DEV_BACKEND="$DEV" echo "Selected backend: $DEV_BACKEND" # clear build dirs rm -rf build rm -rf *.egg-info rm -rf csrc/build rm -rf csrc/ktransformers_ext/build rm -rf csrc/ktransformers_ext/cuda/build rm -rf csrc/ktransformers_ext/cuda/dist rm -rf csrc/ktransformers_ext/cuda/*.egg-info rm -rf ~/.ktransformers echo "Installing python dependencies from requirements.txt" pip install -r requirements-local_chat.txt pip install -r ktransformers/server/requirements.txt echo "Installing ktransformers" KTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation if [[ "$DEV_BACKEND" == "cuda" ]]; then echo "Installing custom_flashinfer for CUDA backend" pip install third_party/custom_flashinfer/ fi # SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") # echo "Copying thirdparty libs to $SITE_PACKAGES" # cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/ # patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython* echo "Installation completed successfully" ================================================ FILE: archive/ktransformers/__init__.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : kkk1nak0 Date : 2024-08-15 07:34:46 Version : 1.0.0 LastEditors : chenxl LastEditTime : 2025-02-15 03:53:02 ''' __version__ = "0.4.1" ================================================ FILE: archive/ktransformers/configs/config.yaml ================================================ log: dir: "logs" file: "lexllama.log" #log level: debug, info, warn, error, crit level: "debug" backup_count: -1 server: ip: 0.0.0.0 port: 10002 db: type: "sqllite" database: "server.db" host: "./" pool_size: 10 user: secret_key: "981f1dd2a44e27d68759d0252a486568ed43480b4e616a26e3af3709c3a7ce73" algorithm: "HS256" model: # type: transformers type: balance_serve # type: ktransformers name: DeepSeek-Coder-V2-Instruct path: deepseek-ai/DeepSeek-V2-Lite-Chat gguf_path: /mnt/data/models/Smallthinker-21B device: cuda:0 cache_lens: 16384 max_new_tokens: 500 web: mount: False open_cross_domain: True ext: cpu_infer: 10 long_context: max_seq_len: 32000 block_size: 128 local_windows_len: 4096 second_select_num: 32 anchor_type: DYNAMIC kv_type: FP16 dense_layer_num: 2 anchor_num: 1 preselect_block: True head_select_mode: SHARED preselect_block_count: 32 layer_step: 1 token_step: local_chat: prompt_file: "" async_server: sched_strategy: "FCFS" sched_port: 56441 sched_metrics_port: 54321 kvc2_metrics_port: 54391 max_batch_size: 4 # decode count + prefill count, in one mini batch attn: page_size: 256 chunk_size: 256 kvc2: gpu_only: true utilization_percentage: 1.0 cpu_memory_size_GB: 500 disk_path: /home/wjh/kvc ================================================ FILE: archive/ktransformers/configs/log_config.ini ================================================ [loggers] keys=root,uvicorn,uvicornError,uvicornAccess [handlers] keys=consoleHandler,fileHandler [formatters] keys=detailedFormatter [logger_root] level=INFO handlers=consoleHandler [logger_uvicorn] level=INFO handlers=consoleHandler,fileHandler qualname=uvicorn propagate=0 [logger_uvicornError] level=ERROR handlers=consoleHandler,fileHandler qualname=uvicorn.error propagate=0 [logger_uvicornAccess] level=INFO handlers=consoleHandler,fileHandler qualname=uvicorn.access propagate=0 [handler_consoleHandler] class=StreamHandler level=INFO formatter=detailedFormatter args=(sys.stdout,) [handler_fileHandler] class=logging.FileHandler level=INFO formatter=detailedFormatter args=('uvicorn_logs.log', 'a') [formatter_detailedFormatter] format=%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s datefmt=%Y-%m-%d %H:%M:%S ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py ================================================ ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py ================================================ # # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). # import torch # This is PyTorch implementation of main part of reorder_meta() # function, from tools/util/include/cutlass/util/host_reorder.h file # of CUTLASS source tree. Furthermore, CUTLASS template for sparse # GEMM decides upon layout of this matrix, and at the moment for the # sparse GEMM executed on tensor cores, this is layout described by # ColumnMajorInterleaved<2> data structure, in # include/cutlass/layout/matrix.h of CUTLASS source tree. The # reordering of meta matrix into meta_reordered matrix calculated # according to these segments of CUTLASS code is re-implemented here. # Note that this calculation produces offsets for scattering metadata # matrix elements into reordered metadata matrix elements (or, # equivalently, for gathering reordered metadata matrix element back # into metadata matrix elements). def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) # Reorder the rows, then swizzle the 2x2 blocks. group_x = 64 group_y = 32 if meta_dtype.itemsize == 2 else 16 dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + ((dst_rows % group_x) // 8) * 4) topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) dst_rows += topright - bottomleft dst_cols -= topright - bottomleft # Assumed that meta tensor is to be stored in CUTLASS # InterleavedColumnMajor layout, and reverse engineered # corresponding code to store values into this tensor. interleave = 2 cols_maj = dst_cols // interleave cols_min = dst_cols % interleave return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) # This function converts dense matrix into sparse semi-structured # representation, producing "compressed" matrix, in the layout used by # CUTLASS backend, and corresponding metadata matrix. def sparse_semi_structured_from_dense_cutlass(dense): if dense.dim() != 2: raise RuntimeError( f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 ) m, k = dense.shape device = dense.device meta_dtype = torch.int8 if dense.dtype == torch.int8: meta_dtype = torch.int32 elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: meta_dtype = torch.int16 else: raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 if quadbits_per_meta_elem not in (4, 8): raise RuntimeError( "Invalid number of elements per meta element calculated") if meta_dtype == torch.int32: if m % 16 != 0: raise RuntimeError( f"Number of rows of dense matrix {m} must be divisible by 16") else: if m % 32 != 0: raise RuntimeError( f"Number of rows of dense matrix {m} must be divisible by 32") if k % (4 * quadbits_per_meta_elem) != 0: raise RuntimeError( f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 ) if dense.dtype != torch.float: ksparse = 4 dense_4 = dense.view(-1, k // ksparse, ksparse) m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) else: ksparse = 2 dense_2 = dense.view(-1, k // ksparse, ksparse) m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) meta_ncols = k // (ksparse * quadbits_per_meta_elem) # Encoding quadruples of True/False values as follows: # [True, True, False, False] -> 0b0100 # [True, False, True, False] -> 0b1000 # [False, True, True, False] -> 0b1001 # [True, False, False, True ] -> 0b1100 # [False, True, False, True ] -> 0b1101 # [False, False, True, True ] -> 0b1110 # Thus, lower two bits in the encoding are index of the True value # at the lowest index in the quadruple, and the higher two bits in # the encoding are index of the other True value in the quadruple. # In case there are less than two True values, than False value or # values at some index or indices are considered True for the # encoding. In case there are more than two True values, then the # excess True value(s) at some indices are considered False for # the encoding. The exact encodings used for these cases are as # follows: # [False, False, False, False] -> 0b1110 # [False, False, False, True ] -> 0b1110 # [False, False, True, False] -> 0b1110 # [False, True, False, False] -> 0b1001 # [False, True, True, True ] -> 0b1101 # [True, False, False, False] -> 0b1000 # [True, False, True, True ] -> 0b1100 # [True, True, False, True ] -> 0b0100 # [True, True, True, False] -> 0b0100 # [True, True, True, True ] -> 0b0100 # These particular encodings are chosen, with the help of Espresso # logic minimizer software, for the purpose of minimization of # corresponding Boolean functions, that translate non-zero flags # into encoding bits. Note also possible choices for the first # and last of these encodings were limited only to (0b0100, # 0b1110), in order to produce valid encodings for 1:2 sparsity # case. expr0 = m0 & m1 expr1 = ~m0 & m1 expr2 = ~m0 & ~m1 bit0 = expr1 bit1 = expr2 bit2 = expr0 | expr2 | m3 bit3 = expr1 | ~m1 idxs0 = bit0 | (bit1.to(torch.int64) << 1) idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: sparse0 = dense_4.gather( -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) meta_n = meta_4.view( (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12)) elif quadbits_per_meta_elem == 8: meta = (meta_n[:, :, 0] | (meta_n[:, :, 1] << 4) | (meta_n[:, :, 2] << 8) | (meta_n[:, :, 3] << 12) | (meta_n[:, :, 4] << 16) | (meta_n[:, :, 5] << 20) | (meta_n[:, :, 6] << 24) | (meta_n[:, :, 7] << 28)) # Reorder meta tensor elements. meta_reordered = meta.new_empty( (m * meta_ncols, )) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) # This function performs reverse of the function above - it # reconstructs dense matrix from a pair of "compressed" matrix, given # in the layout used by CUTLASS backend, and accompanying metadata # matrix. def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if sparse.dim() != 2: raise RuntimeError( f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 ) m, k = sparse.shape device = sparse.device if meta_reordered.dim() != 2: raise RuntimeError( f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 ) if meta_reordered.device != device: raise RuntimeError( f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 ) meta_dtype = meta_reordered.dtype if meta_dtype not in (torch.int16, torch.int32): raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 ksparse = 4 if sparse.dtype != torch.float else 2 meta_nrows, meta_ncols = meta_reordered.shape if meta_nrows != m: raise RuntimeError( f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 ) if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: raise RuntimeError( f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 "expected according to the number of columns of meta matrix") # Undo meta tensor elements reordering. meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device) meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) # Unpack sparse tensor back to original dense tensor, using # information provided by meta tensor. Note that torch.float # datatype is handled pretty much the same as # torch.half/torch.bfloat16, as metadata for a pair of torch.float # value is encoded as if underlying 8 bytes contain four # torch.half/torch.bfloat16 values, where either first two or last # two are zeros. meta_2 = torch.empty( (m, meta_ncols, 2 * quadbits_per_meta_elem), dtype=meta_dtype, device=device, ) if quadbits_per_meta_elem == 4: meta_2[:, :, 0] = meta & 0b11 meta_2[:, :, 1] = (meta >> 2) & 0b11 meta_2[:, :, 2] = (meta >> 4) & 0b11 meta_2[:, :, 3] = (meta >> 6) & 0b11 meta_2[:, :, 4] = (meta >> 8) & 0b11 meta_2[:, :, 5] = (meta >> 10) & 0b11 meta_2[:, :, 6] = (meta >> 12) & 0b11 meta_2[:, :, 7] = (meta >> 14) & 0b11 elif quadbits_per_meta_elem == 8: meta_2[:, :, 0] = meta & 0b11 meta_2[:, :, 1] = (meta >> 2) & 0b11 meta_2[:, :, 2] = (meta >> 4) & 0b11 meta_2[:, :, 3] = (meta >> 6) & 0b11 meta_2[:, :, 4] = (meta >> 8) & 0b11 meta_2[:, :, 5] = (meta >> 10) & 0b11 meta_2[:, :, 6] = (meta >> 12) & 0b11 meta_2[:, :, 7] = (meta >> 14) & 0b11 meta_2[:, :, 8] = (meta >> 16) & 0b11 meta_2[:, :, 9] = (meta >> 18) & 0b11 meta_2[:, :, 10] = (meta >> 20) & 0b11 meta_2[:, :, 11] = (meta >> 22) & 0b11 meta_2[:, :, 12] = (meta >> 24) & 0b11 meta_2[:, :, 13] = (meta >> 26) & 0b11 meta_2[:, :, 14] = (meta >> 28) & 0b11 meta_2[:, :, 15] = (meta >> 30) & 0b11 dense_offsets = meta_2.view(-1) + ( torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( -1, 1).repeat(1, 2).view(-1) dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: dense.view(torch.half).scatter_(0, dense_offsets, sparse.view(torch.half).view(-1)) return dense.view(m, 2 * k) def mask_creator(tensor): """ Class for creating N:M sparsity masks. Masks will be created using the N:M ratio, where for every block of M weights, N will be pruned based on ranked weight value. Each mask will correspond to the given tensor. :param N: The number of weights in a group to keep :param M: The size of a weight group """ N = 2 M = 4 mask = None # for i, tensor in enumerate(tensors): if tensor.numel() % M != 0: raise ValueError( f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups") num_groups = tensor.numel() // M # N:M sparsity for linear layers tensor_temp = tensor.detach().abs().reshape(num_groups, M) index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) return mask ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py ================================================ """This file is used for /tests and /benchmarks""" from typing import Dict, List import numpy import torch # Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 # # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 # with the tensor-core format that is described here: # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 # # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501 def get_perms_24(num_bits: int): perm_list: List[int] = [] for i in range(32): perm1: List[int] = [] col = i // 4 col_o = col // 2 for block in [0, 1]: for row in [ 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) perm = numpy.array(perm_list) if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: interleave = numpy.array([0, 2, 1, 3]) else: raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = torch.from_numpy(perm) scale_perm: List[int] = [] for i in range(8): scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) scale_perm_single: List[int] = [] for i in range(8): scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) return perm, scale_perm, scale_perm_single marlin_24_perm: Dict[int, torch.Tensor] = {} marlin_24_scale_perm: Dict[int, List[int]] = {} marlin_24_scale_perm_single: Dict[int, List[int]] = {} for num_bits in [4, 8]: perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) marlin_24_perm[num_bits] = perm_24 marlin_24_scale_perm[num_bits] = scale_perm_24 marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py ================================================ """This file is used for /tests and /benchmarks""" from typing import Dict, List import numpy import torch # Precompute permutations for Marlin weight and scale shuffling # noqa: E501 # # Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 # with the tensor-core format that is described here: # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 # # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501 def get_perms(num_bits: int): perm_list: List[int] = [] for i in range(32): perm1: List[int] = [] col = i // 4 for block in [0, 1]: for row in [ 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col + 8 * block) for j in range(4): perm_list.extend([p + 256 * j for p in perm1]) perm = numpy.array(perm_list) if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: interleave = numpy.array([0, 2, 1, 3]) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = torch.from_numpy(perm) scale_perm: List[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single: List[int] = [] for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return perm, scale_perm, scale_perm_single marlin_perm: Dict[int, torch.Tensor] = {} marlin_scale_perm: Dict[int, List[int]] = {} marlin_scale_perm_single: Dict[int, List[int]] = {} for num_bits in [4, 8]: perm, scale_perm, scale_perm_single = get_perms(num_bits) marlin_perm[num_bits] = perm marlin_scale_perm[num_bits] = scale_perm marlin_scale_perm_single[num_bits] = scale_perm_single ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py ================================================ """This file is used for /tests and /benchmarks""" import random import numpy import torch from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24 import ( mask_creator, sparse_semi_structured_from_dense_cutlass) from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_24_perms import ( marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_perms import ( marlin_perm, marlin_scale_perm, marlin_scale_perm_single) from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils import ( get_pack_factor, quantize_weights, sort_weights) __cuda_arch = torch.cuda.get_device_capability() MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] def is_marlin_supported(): return __cuda_arch[0] >= 8 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" # Permute weights to 16x64 marlin tiles q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) return q_w def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) # Pack pack_factor = get_pack_factor(num_bits) orig_device = q_w.device q_w = q_w.cpu().numpy().astype(numpy.uint32) q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) return q_packed def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, scale_perm_single): if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s def marlin_quantize( w: torch.Tensor, num_bits: int, group_size: int, act_order: bool, ): size_k, size_n = w.shape # Normalize group_size if group_size == -1: group_size = size_k assert group_size <= size_k # Quantize (and apply act_order if provided) q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing sort_indices = torch.empty(0, dtype=torch.int, device=w.device) if act_order: q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, marlin_perm[num_bits]) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, marlin_scale_perm[num_bits], marlin_scale_perm_single[num_bits]) # Create result res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) return res_list def vllm_marlin_quantize( w: torch.Tensor, num_bits: int, group_size: int, act_order: bool, ): size_k, size_n = w.shape # Normalize group_size if group_size == -1: group_size = size_k assert group_size <= size_k # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing sort_indices = torch.empty(0, dtype=torch.int, device=w.device) if act_order: q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, marlin_perm[num_bits]) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, marlin_scale_perm[num_bits], marlin_scale_perm_single[num_bits]) # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) return res_list def inject_24(w, size_k, size_n): assert w.shape == (size_k, size_n) mask = mask_creator(w.t()).t().cuda().bool() return (mask * w).contiguous(), mask.contiguous() def check_24(w, num_rows_to_sample=50, _verbose=False): BLOCK_SIZE = 4 MAX_NON_ZEROS = 2 w = w.t().contiguous() print("check_24: w.shape = {}".format(w.shape)) num_rows, num_cols = w.shape sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) if _verbose: print(f"Sampled row idxs = {sampled_row_idxs}") total_segments = 0 non_24_segments = 0 for i in sampled_row_idxs: for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): total_segments += 1 block = w[i, j:j + BLOCK_SIZE] num_nonzero = torch.count_nonzero(block) if num_nonzero > MAX_NON_ZEROS: print("i = {} j = {} block = {}".format(i, j, block)) non_24_segments += 1 print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): assert q_24.shape == (size_k, size_n) # Remove zp to normalize over 0 max_q_val = (1 << num_bits) - 1 zp = (max_q_val + 1) // 2 q_24_no_zp = q_24 - zp # Compress q_24_no_zp = q_24_no_zp.t().contiguous() q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() # Restore zp q_24_comp = q_24_no_zp_comp + zp # Resize meta to its actual shape (without moving any data) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) return q_24_comp, meta def marlin_24_quantize( w: torch.Tensor, num_bits: int, group_size: int, ): size_k, size_n = w.shape # Normalize group_size if group_size == -1: group_size = size_k assert group_size <= size_k # Inject 2:4 sparsity w_24, mask_24 = inject_24(w, size_k, size_n) # Quantize w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, num_bits, group_size, act_order=False) # Compress quantized weight q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, num_bits) size_k_comp = size_k // 2 # Reformat to marlin marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, num_bits, marlin_24_perm[num_bits]) marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, marlin_24_scale_perm[num_bits], marlin_24_scale_perm_single[num_bits]) # Create result res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] for i in range(len(res_list)): res_list[i] = res_list[i].to(w.device) return res_list def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref)) class MarlinWorkspace: def __init__(self, out_features, min_thread_n, max_parallel, device): assert (out_features % min_thread_n == 0), ( "out_features = {} is undivisible by min_thread_n = {}".format( out_features, min_thread_n)) max_workspace_size = ((out_features // min_thread_n) * max_parallel) self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device=device) ================================================ FILE: archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py ================================================ """This file is used for /tests and /benchmarks""" import numpy import torch SUPPORTED_NUM_BITS = [4, 8] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] def get_pack_factor(num_bits): assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" return 32 // num_bits def permute_rows(q_w: torch.Tensor, group_size: int): orig_device = q_w.device k_size, _ = q_w.shape g_idx = torch.zeros((k_size, ), dtype=torch.int32) for i in range(k_size): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K rand_perm = torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() return ( q_w.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), ) def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, act_order: bool): orig_device = w.device size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], f"Unsupported groupsize = {group_size}" if group_size == -1: group_size = size_k assert group_size <= size_k max_q_val = 2**num_bits - 1 half_q_val = (max_q_val + 1) // 2 # Reshape to [groupsize, -1] if group_size < size_k: w = w.view((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) # Compute scale for each group s = torch.max(torch.abs(w), 0, keepdim=True)[0] s *= 2 / max_q_val # 2 => symmetric # Quantize q_w = torch.round(w / s).int() q_w += half_q_val q_w = torch.clamp(q_w, 0, max_q_val) # Restore original shapes if group_size < size_k: def reshape_w(w): w = w.reshape((group_size, -1, size_n)) w = w.permute(1, 0, 2) w = w.reshape((size_k, size_n)).contiguous() return w q_w = reshape_w(q_w) s = s.reshape((-1, size_n)).contiguous() # Apply act_order g_idx = torch.empty(0, dtype=torch.int, device=w.device) rand_perm = torch.empty(0, dtype=torch.int, device=w.device) if act_order: assert ( group_size < size_k ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) q_w, g_idx, rand_perm = permute_rows(q_w, group_size) return ( q_w.to(device=orig_device), s.to(device=orig_device), g_idx.to(device=orig_device), rand_perm.to(device=orig_device), ) def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device sort_indices = torch.argsort(g_idx).to( dtype=torch.int32) # Sort based on g_idx g_idx = g_idx[sort_indices].contiguous() q_w = q_w[sort_indices, :].contiguous() return ( q_w.to(device=orig_device), g_idx.to(device=orig_device), sort_indices.to(device=orig_device), ) def gptq_pack( q_w: torch.Tensor, num_bits: int, size_k: int, size_n: int, ): assert q_w.shape == (size_k, size_n) pack_factor = get_pack_factor(num_bits) assert size_k % pack_factor == 0 orig_device = q_w.device q_w = q_w.cpu().numpy().astype(numpy.uint32) q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) for i in range(pack_factor): q_res |= q_w[i::pack_factor, :] << num_bits * i q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) return q_res ================================================ FILE: archive/ktransformers/ktransformers_ext/triton/fp8gemm.py ================================================ # Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py from typing import Tuple import torch import triton import triton.language as tl from triton import Config @triton.jit def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): """ Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. Args: x_ptr (triton.Pointer): Pointer to the input tensor. y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. Returns: None """ pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) s = tl.max(tl.abs(x)) / 448. y = x / s y = y.to(y_ptr.dtype.element_ty) tl.store(y_ptr + offs, y) tl.store(s_ptr + pid, s) def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: """ Quantizes the input tensor `x` using block-wise quantization. Args: x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - The quantized tensor with dtype `torch.float8_e4m3fn`. - A tensor of scaling factors with dtype `torch.float32`. """ assert x.is_contiguous(), 'Input tensor must be contiguous' assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) return y, s @triton.jit def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): """ Dequantizes weights using the provided scaling factors and stores the result. Args: x_ptr (tl.pointer): Pointer to the quantized weights. s_ptr (tl.pointer): Pointer to the scaling factors. y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. M (int): Number of rows in the weight matrix. N (int): Number of columns in the weight matrix. BLOCK_SIZE (tl.constexpr): Size of the block for tiling. Returns: None """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) n = tl.cdiv(N, BLOCK_SIZE) offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs = offs_m[:, None] * N + offs_n[None, :] mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) s = tl.load(s_ptr + pid_m * n + pid_n) y = x * s tl.store(y_ptr + offs, y, mask=mask) def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: """ Dequantizes the given weight tensor using the provided scale tensor. Args: x (torch.Tensor): The quantized weight tensor of shape (M, N). s (torch.Tensor): The scale tensor of shape (M, N). block_size (int, optional): The block size to use for dequantization. Defaults to 128. Returns: torch.Tensor: The dequantized weight tensor of the same shape as `x`. Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) with torch.cuda.device(x.device): weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y fp8_gemm_configs = [ Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] ] @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) @triton.jit def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, a_s_ptr, b_s_ptr, M, N: tl.constexpr, K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): """ Performs a matrix multiplication operation on FP8 matrices with scaling factors. Args: a_ptr (tl.tensor): Pointer to the first input matrix A. b_ptr (tl.tensor): Pointer to the second input matrix B. c_ptr (tl.tensor): Pointer to the output matrix C. a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. M (int): Number of rows in matrix A and C. N (tl.constexpr): Number of columns in matrix B and C. K (tl.constexpr): Number of columns in matrix A and rows in matrix B. BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. Returns: None """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) k = tl.cdiv(K, BLOCK_SIZE_K) offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] a_s_ptrs = a_s_ptr + offs_m * k b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(k): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) a_s = tl.load(a_s_ptrs) b_s = tl.load(b_s_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K a_s_ptrs += 1 b_s_ptrs += 1 c = accumulator.to(c_ptr.dtype.element_ty) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, c, mask=mask) def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): """ Perform a matrix multiplication using FP8 precision. Args: a (torch.Tensor): The first input matrix, must be contiguous. a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. b (torch.Tensor): The second input matrix, must be contiguous. b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. Returns: torch.Tensor: The result of the matrix multiplication. """ assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K N = b.size(0) c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) return c ================================================ FILE: archive/ktransformers/local_chat.py ================================================ """ Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import os import platform import sys project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) import torch try: import torch_npu from torch_npu.contrib import transfer_to_npu from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group from ktransformers.util import utils, npu_graph_runner except: pass import torch.distributed as dist import logging from transformers import ( AutoTokenizer, AutoConfig, AutoModelForCausalLM, GenerationConfig, TextStreamer, ) import json import fire from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model from ktransformers.util import utils from ktransformers.models.custom_cache import StaticCache from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, } ktransformer_rules_dir = ( os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" ) default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", } try: torch.npu.config.allow_internal_format = True torch.npu.set_compile_mode(jit_compile=False) except: pass import sys, signal, faulthandler faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False) def local_chat( model_path: str | None = None, optimize_config_path: str = None, gguf_path: str | None = None, max_new_tokens: int = 1000, cpu_infer: int = Config().cpu_infer, use_cuda_graph: bool = True, prompt_file : str | None = None, mode: str = "normal", force_think: bool = False, chunk_size: int = 8192, device: str = "cuda", tp: int = 1, ): Config().cpu_infer = cpu_infer local_rank, world_size = setup_model_parallel(tp=tp) torch.set_grad_enabled(False) if utils.CUR_DEVICE is None: utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}" tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config.chunk_size = chunk_size npu_graph_runner.LAYER_ID = config.num_hidden_layers if mode == 'long_context': assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" torch.set_default_dtype(torch.float16) else: torch.set_default_dtype(config.torch_dtype) with torch.device("meta"): if config.architectures[0] in custom_models: print("using custom modeling_xxx.py.") if ( "Qwen2Moe" in config.architectures[0] ): # Qwen2Moe must use flash_attention_2 to avoid overflow. config._attn_implementation = "flash_attention_2" if "Llama" in config.architectures[0]: config._attn_implementation = "eager" if "Mixtral" in config.architectures[0]: config._attn_implementation = "flash_attention_2" model = custom_models[config.architectures[0]](config) else: model = AutoModelForCausalLM.from_config( config, trust_remote_code=True, attn_implementation="flash_attention_2" ) if optimize_config_path is None: if config.architectures[0] in default_optimize_rules: print("using default_optimize_rule for", config.architectures[0]) if local_rank == 0 else None optimize_config_path = default_optimize_rules[config.architectures[0]] print(f'{optimize_config_path=}') if local_rank == 0 else None else: optimize_config_path = input( "please input the path of your rule file(yaml file containing optimize rules):" ) if gguf_path is None: gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" ) optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) # 提前absorbed get_absort_weight(model, config) try: model.generation_config = GenerationConfig.from_pretrained(model_path) except Exception as e: print(f"generation config can't auto create, make default. Message: {e}") gen_config = GenerationConfig( temperature=0.6, top_p=0.95, do_sample=True ) model.generation_config = gen_config # model.generation_config = GenerationConfig.from_pretrained(model_path) if model.generation_config.pad_token_id is None: model.generation_config.pad_token_id = model.generation_config.eos_token_id model.eval() logging.basicConfig(level=logging.INFO) system = platform.system() if system == "Windows": os.system("cls") if local_rank == 0 else None else: os.system("clear") if local_rank == 0 else None print(f"{model=}") if local_rank == 0 else None batch_size, seq_length = 1, 16384 # default cache pool params device_map = model.gguf_loader.tensor_device_map static_cache = StaticCache( config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype ) torch.distributed.barrier() while True: if local_rank == 0: try: content = input("Chat: \n").strip() except KeyboardInterrupt: dist.barrier() print('Exit all rank with KeyboardInterrupt!') sys.exit(0) if content.startswith('"""'): # prefix """ # multi lines input content = content[3:] + "\n" while True: line = input("") if line.endswith('"""'): # end multi lines input line = line[:-3] # suffix """ if line: content += line + "\n" break else: content += line + "\n" if content == "": if prompt_file != None: content = open(prompt_file, "r").read() else: continue elif os.path.isfile(content): f = open(content, "r") content = f.readlines() f.close() else: content = [f"{len(content)},{max_new_tokens},{content}"] else: content = [""] for line in content: content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE) if world_size > 1: content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE) all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)] dist.all_gather(all_content_sizes, content_size) max_content_size = max([size.item() for size in all_content_sizes]) padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) padded_content_tensor[:len(content_tensor)] = content_tensor all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)] dist.all_gather(all_content_tensors, padded_content_tensor) content_tensor = all_content_tensors[0][:all_content_sizes[0].item()] line = bytes(content_tensor.cpu().numpy()).decode() parts = line.split(",") input_tokens = int(parts[0]) max_new_tokens = int(parts[1]) line = line[line.index(",", line.index(",") + 1) + 1:] messages = [{"role": "user", "content": line}] input_tensor = tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ) if force_think: token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device) input_tensor = torch.cat( [input_tensor, token_thinks], dim=1 ) if mode == 'long_context': assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim, static_cache=static_cache ) else: generated = prefill_and_generate( model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size, static_cache=static_cache ) if __name__ == "__main__": fire.Fire(local_chat) ================================================ FILE: archive/ktransformers/local_chat_test.py ================================================ """ Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import os import platform import sys project_dir = os.path.dirname(os.path.dirname(__file__)) sys.path.insert(0, project_dir) import torch import logging from transformers import ( AutoTokenizer, AutoConfig, AutoModelForCausalLM, GenerationConfig, TextStreamer, ) import json import fire from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util.utils import prefill_and_generate, get_compute_capability from ktransformers.server.config.config import Config from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, } ktransformer_rules_dir = ( os.path.dirname(os.path.abspath(__file__)) + "/optimize/optimize_rules/" ) default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml", } def local_chat( model_path: str | None = None, optimize_config_path: str = None, gguf_path: str | None = None, max_new_tokens: int = 1000, cpu_infer: int = Config().cpu_infer, use_cuda_graph: bool = True, prompt_file : str | None = None, mode: str = "normal", force_think: bool = False, chunk_prefill_size: int = 8192 ): torch.set_grad_enabled(False) Config().cpu_infer = cpu_infer tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if mode == 'long_context': assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode" torch.set_default_dtype(torch.float16) else: torch.set_default_dtype(config.torch_dtype) with torch.device("meta"): if config.architectures[0] in custom_models: print("using custom modeling_xxx.py.") if ( "Qwen2Moe" in config.architectures[0] ): # Qwen2Moe must use flash_attention_2 to avoid overflow. config._attn_implementation = "flash_attention_2" if "Llama" in config.architectures[0]: config._attn_implementation = "eager" if "Mixtral" in config.architectures[0]: config._attn_implementation = "flash_attention_2" model = custom_models[config.architectures[0]](config) else: model = AutoModelForCausalLM.from_config( config, trust_remote_code=True, attn_implementation="flash_attention_2" ) if optimize_config_path is None: if config.architectures[0] in default_optimize_rules: print("using default_optimize_rule for", config.architectures[0]) optimize_config_path = default_optimize_rules[config.architectures[0]] else: optimize_config_path = input( "please input the path of your rule file(yaml file containing optimize rules):" ) if gguf_path is None: gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):" ) optimize_and_load_gguf(model, optimize_config_path, gguf_path, config) try: model.generation_config = GenerationConfig.from_pretrained(model_path) except Exception as e: print(f"generation config can't auto create, make default. Message: {e}") gen_config = GenerationConfig( temperature=0.6, top_p=0.95, do_sample=True ) model.generation_config = gen_config # model.generation_config = GenerationConfig.from_pretrained(model_path) if model.generation_config.pad_token_id is None: model.generation_config.pad_token_id = model.generation_config.eos_token_id model.eval() logging.basicConfig(level=logging.INFO) system = platform.system() if system == "Windows": os.system("cls") else: os.system("clear") if prompt_file != None: assert os.path.isfile(prompt_file), "prompt file not exist" print(f"prompt file is {prompt_file}") content = open(prompt_file, "r").read() else: content = "Please write a piece of quicksort code in C++." print('Start Testing...(1 round)') print('Prompt:', content) while True: messages = [{"role": "user", "content": content}] input_tensor = tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ) if force_think: token_thinks = torch.tensor([tokenizer.encode("\\n",add_special_tokens=False)],device=input_tensor.device) input_tensor = torch.cat( [input_tensor, token_thinks], dim=1 ) if mode == 'long_context': assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ "please change max_seq_len in ~/.ktransformers/config.yaml" if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim ) else: generated = prefill_and_generate( model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size, ) break if __name__ == "__main__": fire.Fire(local_chat) ================================================ FILE: archive/ktransformers/models/__init__.py ================================================ ================================================ FILE: archive/ktransformers/models/ascend/custom_ascend_modeling_deepseek_v3.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import torch_npu import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.config.config import Config from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KVC2StaticCache from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model, DeepseekV3PreTrainedModel from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config import ktransformers.util.utils as utils torch.set_grad_enabled(False) torch.set_default_dtype(torch.float16) class KNPUDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): # cache: KVC2StaticCache use_cuda_graph = False def __init__( self, config: DeepseekV3Config, stream = None, default_type=torch.float16 ): super().__init__(config) self.model = DeepseekV3Model(config) self.config = config self.config.backend_type = "balance_serve" # self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.default_type = default_type self.stream = torch_npu.npu.current_stream() if stream is None else stream self.para_stream = torch_npu.npu.Stream() self.call_stream = torch_npu.npu.Stream() def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages): print('[WARN] this custom modeling do not support flash infer, skip this part...') def batch_embeddings(self, batch: ForwardBatchInput, device="npu:0", is_prefill=True): features = [] if is_prefill: start_ids = 0 seq_lens = [] for i in range(batch.minibatch.prefill_batch): assert batch.minibatch.p_kv_len[i] == batch.minibatch.p_q_len[i], \ "[ERROR] current prefill do not support chunk or prefix cache" tokens = batch.minibatch.p_tokens[start_ids: start_ids+batch.minibatch.p_q_len[i]].contiguous() start_ids += batch.minibatch.p_q_len[i] feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(self.default_type) .to(device=device) ) features.append(feature) seq_lens.append(feature.shape[0]) max_seq_len = max(seq_lens) if seq_lens else 0 padded_features = [] for feat in features: curr_len = feat.shape[0] if curr_len < max_seq_len: pad_len = max_seq_len - curr_len padded_feat = torch.nn.functional.pad( feat, (0, 0, 0, pad_len), mode='constant', value=0.0 ) padded_features.append(padded_feat) else: padded_features.append(feat) features_t = torch.stack(padded_features) else: for i in range(batch.minibatch.decode_batch): if batch.minibatch.d_tokens.dim() == 1: tokens = batch.minibatch.d_tokens.contiguous() else: tokens = batch.minibatch.d_tokens[i].contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(self.default_type) .to(device=device) ) features.append(feature) features_t = torch.stack(features) return features_t def print_callback(self, param): with torch.npu.stream(self.call_stream): hidden_states = param[0] print("########################################") print("hidden_states is ", hidden_states) print("########################################") def forward( self, batch: ForwardBatchInput | None = None, features: torch.Tensor | None = None, past_key_value: KVC2StaticCache | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, block_tables: torch.Tensor | None = None, cuda_graph_idx: int | None = -1, is_prefill: bool = True ) -> ForwardBatchOutput: # NPU use direct block table from ForwardBatchInput instead of page_idx & page_offset if features.ndim == 2: hidden_states = features.unsqueeze(0) elif features.ndim == 1: hidden_states = features.unsqueeze(0).unsqueeze(0) # (bsz, seqlen, hidden) else: hidden_states = features (bsz, q_len, hidden_size) = hidden_states.shape if is_prefill: position_ids = -1 * torch.ones(bsz, q_len).to(batch.minibatch.p_position_ids.device) bsz_real = torch.zeros(bsz).to(batch.minibatch.p_position_ids.device) # convert merged into batched start_ids = 0 for i, qlen in enumerate(batch.minibatch.p_q_len): position_ids[i, 0:qlen] = batch.minibatch.p_position_ids[start_ids:start_ids+qlen] start_ids += qlen bsz_real[i] = qlen block_tables = batch.minibatch.p_block_tables kv_len = batch.minibatch.p_kv_len[0] q_len_raw = batch.minibatch.p_q_len kv_len_raw = batch.minibatch.p_kv_len else: position_ids = batch.minibatch.d_position_ids if len(position_ids.shape) == 1: position_ids = position_ids.unsqueeze(0) block_tables = batch.minibatch.d_block_tables kv_len = batch.minibatch.d_kv_len[0] q_len_raw = None kv_len_raw = batch.minibatch.d_kv_len_list bsz_real = None for i, decode_layer in enumerate(self.model.layers): residual = hidden_states hidden_states = decode_layer.input_layernorm(hidden_states) # generate chunk_mask automatically. if is_prefill: attn_mask = -65504.0 * torch.triu(torch.ones(q_len, kv_len, device=hidden_states.device), diagonal=1) attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) # (bsz, 1, q_len, kv_len) if bsz > 1: attn_mask = attn_mask.expand(bsz, attn_mask.shape[1], attn_mask.shape[2], attn_mask.shape[3]) else: attn_mask = None # print_ex(f"####: before self_attn of layer {i}...") hidden_states, _, _ = decode_layer.self_attn(hidden_states, position_ids=position_ids, attention_mask=attn_mask, past_key_value=past_key_value, num_tokens_tensors=num_tokens_tensors, page_idx=page_idx, page_offset=page_offset, block_table=block_tables, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, is_prefill=is_prefill, stream = self.stream, ) hidden_states = residual + hidden_states # mlp residual = hidden_states hidden_states = decode_layer.post_attention_layernorm(hidden_states) # print_ex(f"####: before mlp of layer {i}...") hidden_states = decode_layer.mlp(hidden_states, self.stream, self.para_stream) hidden_states = hidden_states.squeeze(0) hidden_states = residual + hidden_states # print_ex(f"####: fill output...") forward_batch_output = ForwardBatchOutput() # with torch_npu.npu.stream(self.stream): hidden_states_without_norm = hidden_states.clone() local_logit = self.lm_head(self.model.norm(hidden_states)) for bsz in range(local_logit.size(0)): if bsz_real is not None: index = int(bsz_real[bsz].item()) result = local_logit[bsz][:index] else: result = local_logit[bsz] forward_batch_output.logits.append(result) forward_batch_output.pre_hidden_states.append(hidden_states_without_norm[bsz]) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype,): print('[WARN] this custom modeling do not support flash infer, skip this part...') ================================================ FILE: archive/ktransformers/models/ascend/custom_ascend_modeling_qwen3.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch_npu from dataclasses import dataclass from torch.nn import functional as F import torch.utils.checkpoint from ktransformers.server.config.config import Config from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KVC2Qwen3Cache from ktransformers.models.modeling_qwen3_moe import Qwen3MoePreTrainedModel, Qwen3MoeModel from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig import ktransformers.util.utils as utils from ktransformers.operators.ascend.ascend_layernorm import KQwen3FinalRMSNormNPU torch.set_grad_enabled(False) torch.set_default_dtype(torch.float16) class KNPUQwen3MoeForCausalLM(Qwen3MoePreTrainedModel): cache: "KVC2Qwen3Cache" use_cuda_graph = False def __init__( self, config: "Qwen3MoeConfig", cache: "KVC2Qwen3Cache", stream: Optional["torch_npu.npu.Stream"] = None, default_type: torch.dtype = torch.float16, ): super().__init__(config) self.model = Qwen3MoeModel(config) self.config = config self.config.backend_type = "balance_serve" self.cache = cache self.vocab_size = config.vocab_size self.model.to(torch.float16) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.default_type = default_type self.stream = torch_npu.npu.current_stream() if stream is None else stream self.para_stream = torch_npu.npu.Stream() self.call_stream = torch_npu.npu.Stream() if hasattr(self.model, "embed_tokens"): self.model.embed_tokens.weight.data = self.model.embed_tokens.weight.data.to(torch.float16) if hasattr(self.model, "norm"): self.model.norm.weight.data = self.model.norm.weight.data.to(torch.float16) if getattr(self.model.norm, "bias", None) is not None: self.model.norm.bias.data = self.model.norm.bias.data.to(torch.float16) try: orig_norm = self.model.norm self.model.norm = KQwen3FinalRMSNormNPU(orig_norm) except Exception as e: print(f"[INIT][WARN] replace model.norm failed: {e}", flush=True) def init_wrapper(self): print("[WARN] KNPUQwen3MoeForCausalLM does not use flashinfer wrapper on NPU, skip init_wrapper...") # --------------------------------------------------- # Embedding:support prefill / decode modes # --------------------------------------------------- def batch_embeddings( self, batch: "ForwardBatchInput", device: str = "npu:0", is_prefill: bool = True, ) -> torch.Tensor: features = [] if is_prefill: start_ids = 0 seq_lens = [] for i in range(batch.minibatch.prefill_batch): qlen = int(batch.minibatch.p_q_len[i]) kvlen = int(batch.minibatch.p_kv_len[i]) if kvlen < qlen: raise AssertionError( f"[ERROR] p_kv_len({kvlen}) < p_q_len({qlen}) " f"for prefill idx={i}, this should not happen" ) tokens = batch.minibatch.p_tokens[start_ids: start_ids + qlen].contiguous() start_ids += qlen feat = ( self.model.embed_tokens(tokens.to(torch.device("cpu"))) .to(self.default_type) .to(device=device) ) features.append(feat) seq_lens.append(qlen) max_seq_len = max(seq_lens) if seq_lens else 0 # Pad the current chunk to the maximum q_len with [bsz, max_q_len, hidden]. padded_features = [] for feat in features: curr_len = feat.shape[0] if curr_len < max_seq_len: pad_len = max_seq_len - curr_len padded_feat = torch.nn.functional.pad( feat, (0, 0, 0, pad_len), mode="constant", value=0.0, ) padded_features.append(padded_feat) else: padded_features.append(feat) features_t = torch.stack(padded_features, dim=0) # [bsz, max_seq_len, hidden] else: for i in range(batch.minibatch.decode_batch): if batch.minibatch.d_tokens.dim() == 1: tokens = batch.minibatch.d_tokens.contiguous() else: tokens = batch.minibatch.d_tokens[i].contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device("cpu"))) .to(self.default_type) .to(device=device) ) features.append(feature) features_t = torch.stack(features) # [decode_bsz, decode_q_len, hidden] return features_t def forward( self, batch: Optional["ForwardBatchInput"] = None, features: torch.Tensor | None = None, cache=None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, block_tables: torch.Tensor | None = None, cuda_graph_idx: int | None = 0, is_prefill: bool = True, ) -> "ForwardBatchOutput": try: is_capturing = torch.npu.is_current_stream_capturing() except Exception: is_capturing = False # features: [bsz, q_len, hidden] if features.ndim == 2: hidden_states = features.unsqueeze(0) elif features.ndim == 1: hidden_states = features.unsqueeze(0).unsqueeze(0) else: hidden_states = features bsz, q_len, hidden_size = hidden_states.shape minibatch = batch.minibatch if is_prefill: device_pos = minibatch.p_position_ids.device position_ids = -1 * torch.ones( bsz, q_len, dtype=minibatch.p_position_ids.dtype, device=device_pos, ) bsz_real = torch.zeros(bsz, dtype=torch.int32, device=device_pos) start_ids = 0 for i, qlen in enumerate(minibatch.p_q_len): position_ids[i, :qlen] = minibatch.p_position_ids[start_ids:start_ids + qlen] start_ids += int(qlen.item()) bsz_real[i] = qlen block_tables = minibatch.p_block_tables kv_len = minibatch.p_kv_len[0] q_len_raw = minibatch.p_q_len kv_len_raw = minibatch.p_kv_len kv_len_tensor = kv_len_raw else: position_ids = minibatch.d_position_ids if position_ids.dim() == 1: position_ids = position_ids.unsqueeze(0) block_tables = minibatch.d_block_tables kv_len = minibatch.d_kv_len[0] q_len_raw = None kv_len_tensor = minibatch.d_kv_len_list bsz_real = None # ==================== layer loop ==================== for i, decode_layer in enumerate(self.model.layers): # ---------- Attention Block ---------- attn_residual = hidden_states hidden_states = decode_layer.input_layernorm(hidden_states) attn_out = decode_layer.self_attn( hidden_states, past_key_value=self.cache, position_ids=position_ids, num_tokens_tensors=num_tokens_tensors, page_idx=page_idx, page_offset=page_offset, block_table=block_tables, q_len_raw=q_len_raw, kv_len_raw=kv_len_tensor, is_prefill=is_prefill, stream=self.stream, ) hidden_states = attn_residual + attn_out # ---------- MLP Block ---------- mlp_residual = hidden_states hidden_states = decode_layer.post_attention_layernorm(hidden_states) mlp_in = hidden_states mlp_out = decode_layer.mlp( mlp_in, num_tokens_tensors, cuda_graph_idx, ) if isinstance(mlp_out, tuple): moe_y = mlp_out[0] else: moe_y = mlp_out hidden_states = mlp_residual + moe_y forward_batch_output = ForwardBatchOutput() hidden_states_without_norm = hidden_states.clone() normed = self.model.norm(hidden_states) local_logit = self.lm_head(normed) B_out = local_logit.size(0) for b in range(B_out): if (bsz_real is not None) and (not is_capturing): valid_len = int(bsz_real[b].item()) result = local_logit[b, :valid_len] pre_h = hidden_states_without_norm[b, :valid_len] else: result = local_logit[b] pre_h = hidden_states_without_norm[b] forward_batch_output.logits.append(result) forward_batch_output.pre_hidden_states.append(pre_h) return forward_batch_output def flash_infer_attn_plan( self, batch: "ForwardBatchInput", bsz_tensors: torch.Tensor, num_tokens_tensors: torch.Tensor, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0, ): print("[WARN] KNPUQwen3MoeForCausalLM on NPU does not support flashinfer, skip flash_infer_attn_plan...") ================================================ FILE: archive/ktransformers/models/configuration_deepseek.py ================================================ # Adapted from # https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/configuration_deepseek.py # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} class DeepseekV2Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepSeek-V2. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 102400): Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DeepseekV2Model`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. moe_intermediate_size (`int`, *optional*, defaults to 1407): Dimension of the MoE representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. n_shared_experts (`int`, *optional*, defaults to None): Number of shared experts, None means dense model. n_routed_experts (`int`, *optional*, defaults to None): Number of routed experts, None means dense model. routed_scaling_factor (`float`, *optional*, defaults to 1.0): Scaling factor or routed experts. topk_method (`str`, *optional*, defaults to `gready`): Topk method used in routed gate. n_group (`int`, *optional*, defaults to None): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to None): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). num_experts_per_tok (`int`, *optional*, defaults to None): Number of selected experts, None means dense model. moe_layer_freq (`int`, *optional*, defaults to 1): The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. first_k_dense_replace (`int`, *optional*, defaults to 0): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to False): Whether to normalize the weights of the routed experts. scoring_func (`str`, *optional*, defaults to 'softmax'): Method of computing expert weights. aux_loss_alpha (`float`, *optional*, defaults to 0.001): Auxiliary loss weight coefficient. seq_aux = (`bool`, *optional*, defaults to True): Whether to compute the auxiliary loss for each individual sample. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): Padding token id. bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. pretraining_tp (`int`, *optional*, defaults to 1): Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. ```python >>> from transformers import DeepseekV2Model, DeepseekV2Config >>> # Initializing a Deepseek-V2 style configuration >>> configuration = DeepseekV2Config() >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "deepseek_v2" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=102400, hidden_size=4096, intermediate_size=11008, moe_intermediate_size = 1407, num_hidden_layers=30, num_attention_heads=32, num_key_value_heads=32, n_shared_experts = None, n_routed_experts = None, ep_size = 1, routed_scaling_factor = 1.0, kv_lora_rank = 512, q_lora_rank = 1536, qk_rope_head_dim = 64, v_head_dim = 128, qk_nope_head_dim = 128, topk_method = 'gready', n_group = None, topk_group = None, num_experts_per_tok = None, moe_layer_freq = 1, first_k_dense_replace = 0, norm_topk_prob = False, scoring_func = 'softmax', aux_loss_alpha = 0.001, seq_aux = True, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=None, bos_token_id=100000, eos_token_id=100001, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, cpu_quant=None, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim self.topk_method = topk_method self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob self.scoring_func = scoring_func self.aux_loss_alpha = aux_loss_alpha self.seq_aux = seq_aux # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.cpu_quant = cpu_quant super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) ================================================ FILE: archive/ktransformers/models/configuration_deepseek_v3.py ================================================ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} class DeepseekV3Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepSeek-V3. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 129280): Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DeepseekV3Model`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. moe_intermediate_size (`int`, *optional*, defaults to 1407): Dimension of the MoE representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. num_nextn_predict_layers (`int`, *optional*, defaults to 1): Number of nextn predict layers in the DeepSeekV3 Model. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. n_shared_experts (`int`, *optional*, defaults to None): Number of shared experts, None means dense model. n_routed_experts (`int`, *optional*, defaults to None): Number of routed experts, None means dense model. routed_scaling_factor (`float`, *optional*, defaults to 1.0): Scaling factor or routed experts. topk_method (`str`, *optional*, defaults to `gready`): Topk method used in routed gate. n_group (`int`, *optional*, defaults to None): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to None): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). num_experts_per_tok (`int`, *optional*, defaults to None): Number of selected experts, None means dense model. moe_layer_freq (`int`, *optional*, defaults to 1): The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. first_k_dense_replace (`int`, *optional*, defaults to 0): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to False): Whether to normalize the weights of the routed experts. scoring_func (`str`, *optional*, defaults to 'softmax'): Method of computing expert weights. aux_loss_alpha (`float`, *optional*, defaults to 0.001): Auxiliary loss weight coefficient. seq_aux = (`bool`, *optional*, defaults to True): Whether to compute the auxiliary loss for each individual sample. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): Padding token id. bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. ```python >>> from transformers import DeepseekV3Model, DeepseekV3Config >>> # Initializing a Deepseek-V3 style configuration >>> configuration = DeepseekV3Config() >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=129280, hidden_size=7168, intermediate_size=18432, moe_intermediate_size = 2048, num_hidden_layers=61, num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, n_shared_experts = 1, n_routed_experts = 256, ep_size = 1, routed_scaling_factor = 2.5, kv_lora_rank = 512, q_lora_rank = 1536, qk_rope_head_dim = 64, v_head_dim = 128, qk_nope_head_dim = 128, topk_method = 'noaux_tc', n_group = 8, topk_group = 4, num_experts_per_tok = 8, moe_layer_freq = 1, first_k_dense_replace = 3, norm_topk_prob = True, scoring_func = 'sigmoid', hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=None, bos_token_id=0, eos_token_id=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers self.num_nextn_predict_layers = num_nextn_predict_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim self.topk_method = topk_method self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob self.scoring_func = scoring_func # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) ================================================ FILE: archive/ktransformers/models/configuration_glm4_moe.py ================================================ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_glm4_moe.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation class Glm4MoeConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 151552): Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Glm4MoeModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 10944): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 46): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 96): Number of attention heads for each attention layer in the Transformer encoder. partial_rotary_factor (`float`, *optional*, defaults to 0.5): The factor of the partial rotary position. num_key_value_heads (`int`, *optional*, defaults to 8): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details, check out [this paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 131072): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. `factor` (`float`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. `original_max_position_embeddings` (`int`, *optional*): Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during pretraining. `attention_factor` (`float`, *optional*): Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. `beta_fast` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32. `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. `short_factor` (`list[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `long_factor` (`list[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `low_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. moe_intermediate_size (`int`, *optional*, defaults to 1408): Intermediate size of the routed expert. num_experts_per_tok (`int`, *optional*, defaults to 8): number of experts per token. n_shared_experts (`int`, *optional*, defaults to 1): Number of shared experts. n_routed_experts (`int`, *optional*, defaults to 128): Number of routed experts. routed_scaling_factor (`float`, *optional*, defaults to 1.0): Scaling factor or routed experts. n_group (`int`, *optional*, defaults to 1): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to 1): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). first_k_dense_replace (`int`, *optional*, defaults to 1): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the topk probabilities. use_qk_norm (`bool`, *optional*, defaults to `False`): Whether to use query-key normalization in the attention ```python >>> from transformers import Glm4MoeModel, Glm4MoeConfig >>> # Initializing a Glm4Moe style configuration >>> configuration = Glm4MoeConfig() >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration >>> model = Glm4MoeModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "glm4_moe" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Glm4Moe` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.*.gate_proj": "colwise", "layers.*.mlp.experts.*.up_proj": "colwise", "layers.*.mlp.experts.*.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } def __init__( self, vocab_size=151552, hidden_size=4096, intermediate_size=10944, num_hidden_layers=46, num_attention_heads=96, partial_rotary_factor=0.5, num_key_value_heads=8, hidden_act="silu", max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-5, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, moe_intermediate_size=1408, num_experts_per_tok=8, n_shared_experts=1, n_routed_experts=128, routed_scaling_factor=1.0, n_group=1, topk_group=1, first_k_dense_replace=1, norm_topk_prob=True, use_qk_norm=False, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.partial_rotary_factor = partial_rotary_factor self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) # MoE arguments self.moe_intermediate_size = moe_intermediate_size self.num_experts_per_tok = num_experts_per_tok self.n_group = n_group self.topk_group = topk_group self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts self.routed_scaling_factor = routed_scaling_factor self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob self.use_qk_norm = use_qk_norm super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) __all__ = ["Glm4MoeConfig"] ================================================ FILE: archive/ktransformers/models/configuration_llama.py ================================================ # coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """LLaMA model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation class LlamaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the LLaMA-7B. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`LlamaModel`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 11008): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, Llama 2 up to 4096, CodeLlama up to 16384. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. pad_token_id (`int`, *optional*): Padding token id. bos_token_id (`int`, *optional*, defaults to 1): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 2): End of stream token id. pretraining_tp (`int`, *optional*, defaults to 1): Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. `factor` (`float`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. `original_max_position_embeddings` (`int`, *optional*): Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during pretraining. `attention_factor` (`float`, *optional*): Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. `beta_fast` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32. `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. `short_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `long_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `low_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. mlp_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. ```python >>> from transformers import LlamaModel, LlamaConfig >>> # Initializing a LLaMA llama-7b style configuration >>> configuration = LlamaConfig() >>> # Initializing a model from the llama-7b style configuration >>> model = LlamaModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "llama" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=32000, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=None, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, pretraining_tp=1, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, attention_dropout=0.0, mlp_bias=False, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.pretraining_tp = pretraining_tp self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) ================================================ FILE: archive/ktransformers/models/configuration_qwen2_moe.py ================================================ # coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Qwen2MoE model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class Qwen2MoeConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B"). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen2MoeModel`] hidden_size (`int`, *optional*, defaults to 2048): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 5632): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 24): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. num_key_value_heads (`int`, *optional*, defaults to 16): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 32768): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. use_sliding_window (`bool`, *optional*, defaults to `False`): Whether to use sliding window attention. sliding_window (`int`, *optional*, defaults to 4096): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. decoder_sparse_step (`int`, *optional*, defaults to 1): The frequency of the MoE layer. moe_intermediate_size (`int`, *optional*, defaults to 1408): Intermediate size of the routed expert. shared_expert_intermediate_size (`int`, *optional*, defaults to 5632): Intermediate size of the shared expert. num_experts_per_tok (`int`, *optional*, defaults to 4): Number of selected experts. num_experts (`int`, *optional*, defaults to 60): Number of routed experts. norm_topk_prob (`bool`, *optional*, defaults to `False`): Whether to normalize the topk probabilities. output_router_logits (`bool`, *optional*, defaults to `False`): Whether or not the router logits should be returned by the model. Enabeling this will also allow the model to output the auxiliary loss, including load balancing loss and router z-loss. router_aux_loss_coef (`float`, *optional*, defaults to 0.001): The aux loss factor for the total loss. mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock The list contains layer index, from 0 to num_layers-1 if we have num_layers layers If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. ```python >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig >>> # Initializing a Qwen2MoE style configuration >>> configuration = Qwen2MoeConfig() >>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration >>> model = Qwen2MoeModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "qwen2_moe" keys_to_ignore_at_inference = ["past_key_values"] def __init__( self, vocab_size=151936, hidden_size=2048, intermediate_size=5632, num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16, hidden_act="silu", max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=1408, shared_expert_intermediate_size=5632, num_experts_per_tok=4, num_experts=60, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window if use_sliding_window else None self.max_window_layers = max_window_layers self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.attention_dropout = attention_dropout # MoE arguments self.decoder_sparse_step = decoder_sparse_step self.moe_intermediate_size = moe_intermediate_size self.shared_expert_intermediate_size = shared_expert_intermediate_size self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.norm_topk_prob = norm_topk_prob self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) ================================================ FILE: archive/ktransformers/models/configuration_qwen3_moe.py ================================================ # coding=utf-8 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Qwen3MoE model configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging logger = logging.get_logger(__name__) class Qwen3MoeConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Qwen3MoeModel`] hidden_size (`int`, *optional*, defaults to 2048): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 6144): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 24): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer encoder. num_key_value_heads (`int`, *optional*, defaults to 4): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 32768): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. `factor` (`float`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. `original_max_position_embeddings` (`int`, *optional*): Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during pretraining. `attention_factor` (`float`, *optional*): Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. `beta_fast` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32. `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. `short_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `long_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `low_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. use_sliding_window (`bool`, *optional*, defaults to `False`): Whether to use sliding window attention. sliding_window (`int`, *optional*, defaults to 4096): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. decoder_sparse_step (`int`, *optional*, defaults to 1): The frequency of the MoE layer. moe_intermediate_size (`int`, *optional*, defaults to 768): Intermediate size of the routed expert. num_experts_per_tok (`int`, *optional*, defaults to 8): Number of selected experts. num_experts (`int`, *optional*, defaults to 128): Number of routed experts. norm_topk_prob (`bool`, *optional*, defaults to `False`): Whether to normalize the topk probabilities. output_router_logits (`bool`, *optional*, defaults to `False`): Whether or not the router logits should be returned by the model. Enabeling this will also allow the model to output the auxiliary loss, including load balancing loss and router z-loss. router_aux_loss_coef (`float`, *optional*, defaults to 0.001): The aux loss factor for the total loss. mlp_only_layers (`List[int]`, *optional*, defaults to `[]`): Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock The list contains layer index, from 0 to num_layers-1 if we have num_layers layers If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. ```python >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig >>> # Initializing a Qwen3MoE style configuration >>> configuration = Qwen3MoeConfig() >>> # Initializing a model from the Qwen3-15B-A2B" style configuration >>> model = Qwen3MoeModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "qwen3_moe" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `Qwen3Moe` base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } def __init__( self, vocab_size=151936, hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=4, hidden_act="silu", max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=768, num_experts_per_tok=8, num_experts=128, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window if use_sliding_window else None self.max_window_layers = max_window_layers self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) # MoE arguments self.decoder_sparse_step = decoder_sparse_step self.moe_intermediate_size = moe_intermediate_size self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.norm_topk_prob = norm_topk_prob self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) __all__ = ["Qwen3MoeConfig"] ================================================ FILE: archive/ktransformers/models/configuration_qwen3_next.py ================================================ # coding=utf-8 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Qwen3-Next model configuration""" from transformers.configuration_utils import PretrainedConfig, layer_type_validation from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging logger = logging.get_logger(__name__) class Qwen3NextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a Qwen3-Next model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the model. Defines the number of different tokens that can be represented by the `inputs_ids`. hidden_size (`int`, *optional*, defaults to 2048): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 5632): Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 48): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 16): Number of attention heads for each attention layer in the Transformer encoder. num_key_value_heads (`int`, *optional*, defaults to 2): This is the number of key_value heads that should be used to implement Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. hidden_act (`str`, *optional*, defaults to `"silu"`): The non-linear activation function in the decoder. max_position_embeddings (`int`, *optional*, defaults to 32768): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly. Expected contents: `rope_type` (`str`): The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation. `factor` (`float`, *optional*): Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length. `original_max_position_embeddings` (`int`, *optional*): Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during pretraining. `attention_factor` (`float`, *optional*): Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation. If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value. `beta_fast` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp function. If unspecified, it defaults to 32. `beta_slow` (`float`, *optional*): Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp function. If unspecified, it defaults to 1. `short_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to short contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `long_factor` (`List[float]`, *optional*): Only used with 'longrope'. The scaling factor to be applied to long contexts (< `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2 `low_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE partial_rotary_factor (`float`, *optional*, defaults to 0.25): Percentage of the query and keys which will have rotary embedding. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. head_dim (`int`, *optional*, defaults to 256): Projection weights dimension in multi-head attention. linear_conv_kernel_dim (`int`, *optional*, defaults to 4): Kernel size of the convolution used in linear attention layers. linear_key_head_dim (`int`, *optional*, defaults to 128): Dimension of each key head in linear attention. linear_value_head_dim (`int`, *optional*, defaults to 128): Dimension of each value head in linear attention. linear_num_key_heads (`int`, *optional*, defaults to 16): Number of key heads used in linear attention layers. linear_num_value_heads (`int`, *optional*, defaults to 32): Number of value heads used in linear attention layers. decoder_sparse_step (`int`, *optional*, defaults to 1): The frequency of the MoE layer. moe_intermediate_size (`int`, *optional*, defaults to 512): Intermediate size of the routed expert. shared_expert_intermediate_size (`int`, *optional*, defaults to 512): Intermediate size of the shared expert. num_experts_per_tok (`int`, *optional*, defaults to 10): Number of selected experts. num_experts (`int`, *optional*, defaults to 512): Number of routed experts. norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the topk probabilities. output_router_logits (`bool`, *optional*, defaults to `False`): Whether or not the router logits should be returned by the model. Enabling this will also allow the model to output the auxiliary loss, including load balancing loss and router z-loss. router_aux_loss_coef (`float`, *optional*, defaults to 0.001): The aux loss factor for the total loss. mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock The list contains layer index, from 0 to num_layers-1 if we have num_layers layers If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. layer_types (`list[str]`, *optional*): Types of each layer (attention or linear). ```python >>> from transformers import Qwen3NextModel, Qwen3NextConfig >>> # Initializing a Qwen3Next style configuration >>> configuration = Qwen3NextConfig() >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration >>> model = Qwen3NextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ``` """ model_type = "qwen3_next" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.*.gate_proj": "colwise", "layers.*.mlp.experts.*.up_proj": "colwise", "layers.*.mlp.experts.*.down_proj": "rowwise", "layers.*.mlp.shared_experts.gate_proj": "colwise", "layers.*.mlp.shared_experts.up_proj": "colwise", "layers.*.mlp.shared_experts.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } def __init__( self, vocab_size=151936, hidden_size=2048, intermediate_size=5632, num_hidden_layers=48, num_attention_heads=16, num_key_value_heads=2, hidden_act="silu", max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-6, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, partial_rotary_factor=0.25, attention_bias=False, attention_dropout=0.0, head_dim=256, linear_conv_kernel_dim=4, linear_key_head_dim=128, linear_value_head_dim=128, linear_num_key_heads=16, linear_num_value_heads=32, decoder_sparse_step=1, moe_intermediate_size=512, shared_expert_intermediate_size=512, num_experts_per_tok=10, num_experts=512, norm_topk_prob=True, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=[], layer_types=None, **kwargs, ): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.partial_rotary_factor = partial_rotary_factor self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.head_dim = head_dim rope_config_validation(self) self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ "linear_attention" if bool((i + 1) % 4) else "full_attention" for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) # linear attention part self.linear_conv_kernel_dim = linear_conv_kernel_dim self.linear_key_head_dim = linear_key_head_dim self.linear_value_head_dim = linear_value_head_dim self.linear_num_key_heads = linear_num_key_heads self.linear_num_value_heads = linear_num_value_heads # MoE arguments self.decoder_sparse_step = decoder_sparse_step self.moe_intermediate_size = moe_intermediate_size self.shared_expert_intermediate_size = shared_expert_intermediate_size self.num_experts_per_tok = num_experts_per_tok self.num_experts = num_experts self.norm_topk_prob = norm_topk_prob self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef self.mlp_only_layers = mlp_only_layers __all__ = ["Qwen3NextConfig"] ================================================ FILE: archive/ktransformers/models/configuration_smallthinker.py ================================================ # coding=utf-8 from transformers.configuration_utils import PretrainedConfig class SmallthinkerConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`SmallthinkerModel`]. It is used to instantiate a Smallthinker model according to the specified arguments, defining the model architecture. The default values for each of the parameters are the same as the ones used in the original Smallthinker 4B model. General configs: - model_type: "smallthinker" - model_name - num_hidden_layers - hidden_size Tokenizer configs: - pad_token_id - bos_token_id - eos_token_id Embedding configs: - vocab_size RMSNorm configs: - rms_norm_eps Attention configs: - num_attention_heads - num_key_value_heads - head_dim - use_cache - use_qk_norm - rope_layout: array of 0 or 1s, 0 for nope, 1 for rope - rope_theta - max_position_embeddings - sliding_window_layout: array of 0 or 1s, 0 for normal attention, 1 for SWA - sliding_window_size General FFN configs: - moe_layer_layout: array of 0 or 1s, 0 for dense layer, 1 for MoE layer Dense FFN configs: - dense_ffn_hidden_size MoE FFN configs: - moe_num_primary_experts - moe_shared_primary_experts - moe_ffn_hidden_size - moe_enable_early_router: Use attention output as router input if true - moe_primary_router_use_sigmoid: Use normalized sigmoid - moe_num_active_primary_experts - moe_enable_secondary_experts - moe_num_secondary_experts - moe_secondary_expert_size LM Head configs: - tie_word_embeddings Visibility configs: - profile_sparsity Other configs: - initializer_range """ def __init__(self, model_type = "smallthinker", model_name="smallthinker_4b_base", num_hidden_layers=32, hidden_size=1536, pad_token_id=None, bos_token_id=151643, eos_token_id=[151643,151645], vocab_size=151936, rms_norm_eps=1e-6, num_attention_heads=12, num_key_value_heads=2, head_dim=128, use_cache=True, use_qk_norm=False, rope_layout=[1]*32, rope_theta=1e6, max_position_embeddings=4096 * 32, sliding_window_layout=[0]*32, sliding_window_size=4096, moe_layer_layout=[1]*32, dense_ffn_hidden_size=4096, moe_num_primary_experts=32, moe_shared_primary_experts=0, moe_ffn_hidden_size=768, moe_enable_early_router=True, moe_primary_router_apply_softmax=False, moe_num_active_primary_experts=4, moe_enable_secondary_experts=False, moe_num_secondary_experts=0, moe_secondary_expert_size=0, tie_word_embeddings=True, initializer_range=0.02, **kwargs, ): moe_layer_layout = [1]*num_hidden_layers # Configuration sanitizers assert num_attention_heads % num_key_value_heads == 0, "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads" assert len(rope_layout) == num_hidden_layers, "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers" assert len(sliding_window_layout) == num_hidden_layers, "[Smallthinker config sanitizer] sliding_window_layout must have the same length as num_hidden_layers" assert len(moe_layer_layout) == num_hidden_layers, "[Smallthinker config sanitizer] moe_layer_layout must have the same length as num_hidden_layers" if any(moe_layer_layout): assert moe_num_primary_experts != 0, "[Smallthinker config sanitizer] moe_num_primary_experts must be set non-zero if there is any MoE layer" assert moe_ffn_hidden_size != 0, "[Smallthinker config sanitizer] moe_ffn_hidden_size must be set non-zero if there is any MoE layer" assert moe_num_active_primary_experts != 0, "[Smallthinker config sanitizer] moe_num_active_primary_experts must be set non-zero if there is any MoE layer" if moe_enable_secondary_experts: assert moe_num_secondary_experts != 0, "[Smallthinker config sanitizer] moe_num_secondary_experts must be set non-zero if moe_enable_secondary_experts is True" assert moe_secondary_expert_size != 0, "[Smallthinker config sanitizer] moe_secondary_expert_size must be set non-zero if moe_enable_secondary_experts is True" assert moe_num_secondary_experts * moe_secondary_expert_size == moe_ffn_hidden_size, "[Smallthinker config sanitizer] moe_num_secondary_experts * moe_secondary_expert_size must equal moe_ffn_hidden_size" if not all(moe_layer_layout): assert dense_ffn_hidden_size != 0, "[Smallthinker config sanitizer] dense_ffn_hidden_size must be set non-zero if there is any dense FFN layer" # General configs self.model_type = model_type self.model_name = model_name self.num_hidden_layers = num_hidden_layers self.hidden_size = hidden_size # Tokenizer configs self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id # Embedding configs self.vocab_size = vocab_size # RMSNorm configs self.rms_norm_eps = rms_norm_eps # Attention configs self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.head_dim = head_dim self.use_cache = use_cache self.use_qk_norm = use_qk_norm self.rope_layout = rope_layout self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.sliding_window_layout = sliding_window_layout self.sliding_window_size = sliding_window_size # General FFN configs self.moe_layer_layout = moe_layer_layout # Dense FFN configs self.dense_ffn_hidden_size = dense_ffn_hidden_size # MoE FFN configs self.moe_num_primary_experts = moe_num_primary_experts self.moe_shared_primary_experts = moe_shared_primary_experts self.moe_ffn_hidden_size = moe_ffn_hidden_size self.num_experts_per_tok = moe_num_active_primary_experts self.moe_intermediate_size = moe_ffn_hidden_size self.moe_enable_early_router = moe_enable_early_router self.moe_primary_router_apply_softmax = moe_primary_router_apply_softmax self.moe_num_active_primary_experts = moe_num_active_primary_experts self.moe_enable_secondary_experts = moe_enable_secondary_experts self.moe_num_secondary_experts = moe_num_secondary_experts self.moe_secondary_expert_size = moe_secondary_expert_size # Logging configs # self.output_router_logits = False # Other configs self.initializer_range = initializer_range super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) self._attn_implementation = "eager" # SDPA is not allowed for now # if self._attn_implementation != "flash_attention_2": # raise NotImplementedError("SDPA impl is buggy for now. NEVER TRY TO USE IT.") __all__ = ["SmallthinkerConfig"] ================================================ FILE: archive/ktransformers/models/custom_cache.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.1.0 ''' # Adapted from # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/cache_utils.py # Copyright 2018- The Hugging Face team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. import torch import torch.nn as nn import transformers from transformers import Cache, PretrainedConfig from typing import List, Optional, Dict, Any, Tuple try: import torch_npu from ktransformers.util import utils from ktransformers.server.balance_serve.inference.forward_batch import ForwardMiniBatchCombine, ForwardMiniBatchSplit use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False from transformers.models.llama.modeling_llama import LlamaDecoderLayer from ktransformers.server.balance_serve.settings import sched_ext class StaticCache(transformers.StaticCache): """ Static Cache class to be used with `torch.compile(model)`. Parameters: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. max_cache_len (`int`): The maximum sequence length with which the model will be used. device (`torch.device` or `dict`): The device on which the cache should be initialized. Should be the same as the layer. If a `dict`, it should contain the `device` key with the device name as the value. dtype (*optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. """ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None: Cache.__init__(self, layer_class_to_replicate=LlamaDecoderLayer) self._max_batch_size = max_batch_size if use_torch_npu: self.position = [0] self._max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads if config.architectures[0] == "DeepseekV3ForCausalLM": self.head_dim = config.qk_rope_head_dim else: self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) self.dtype = dtype if dtype is not None else torch.float32 self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self._max_cache_len, self.head_dim) if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM": # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically if use_torch_npu: self.page_size = 128 self.page_size_tensor = torch.tensor( self.page_size, dtype=torch.int32, ).npu() self.max_pages_per_batch = (self._max_cache_len + self.page_size - 1) // self.page_size self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size * self._max_batch_size else: self.page_size = 64 self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) self.kv_lora_rank = config.kv_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim # TODO: support real page table self.page_table_map = dict() self.page_table_list = [] for idx in range(config.num_hidden_layers): if isinstance(device, dict): target_device = device[f"blk.{idx}.self_attn"]["generate_device"] else: target_device = device if target_device not in self.page_table_map: if use_torch_npu: page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device) for seq_id in range(max_batch_size): page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device) else: page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device) for seq_id in range(max_batch_size): page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device) self.page_table_map[target_device] = page_table self.page_table_list.append(self.page_table_map[target_device]) self.is_MLA = True self.is_page = True else: key_shape = cache_shape value_shape = cache_shape self.is_MLA = False self.past_tokens = [] self.num_hidden_layers = config.num_hidden_layers for idx in range(self.num_hidden_layers): # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. if isinstance(device, dict): target_device = device[f"blk.{idx}.self_attn"]["generate_device"] else: target_device = device if self.is_MLA: new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device) new_layer_value_cache = None torch._dynamo.mark_static_address(new_layer_key_cache) else: new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device) new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device) torch._dynamo.mark_static_address(new_layer_key_cache) torch._dynamo.mark_static_address(new_layer_value_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) self.past_tokens.append(0) @property def max_batch_size(self): return self._max_batch_size @property def max_cache_len(self): return self._max_cache_len def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. It is VERY important to index using a tensor, otherwise you introduce a copy to the device. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input to know how where to write in the cache. Return: A tuple containing the updated key and value states. """ cache_position = cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] self.past_tokens[layer_idx] += cache_position.size(0) #print(cache_position) if self.is_MLA: if use_torch_npu: page_idx = cache_position // self.page_size_tensor page_offset = cache_position % self.page_size_tensor page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1) page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1) page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch page_idx = page_idx + page_idx_offset.unsqueeze(1) combined = torch.cat([key_states, value_states], dim=-1) combined = combined.contiguous() # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) k_out[page_idx, page_offset] = combined else: page_idx = cache_position // self.page_size page_offset = cache_position % self.page_size # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim) k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states return k_out, self.page_table_list[layer_idx] else: k_out[:, :, cache_position] = key_states v_out[:, :, cache_position] = value_states return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: deprecate this function in favor of `cache_position` return self.past_tokens[layer_idx] def change_seq_length(self, bias: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: deprecate this function in favor of `cache_position` for layer_idx in range(self.num_hidden_layers): self.past_tokens[layer_idx] += bias def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.max_cache_len def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int: return 0 def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() if self.value_cache[layer_idx] is not None: self.value_cache[layer_idx].zero_() self.past_tokens[layer_idx] = 0 if use_torch_npu: self.position = [0] def remove_suffix(self, start_pos): for layer_idx in range(len(self.key_cache)): # In-place ops prevent breaking the static address if self.is_MLA: k_cache = self.key_cache[layer_idx] k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_() else: self.key_cache[layer_idx][..., start_pos:, :].zero_() self.value_cache[layer_idx][..., start_pos:, :].zero_() self.past_tokens[layer_idx] = start_pos def get_max_cache_shape(self) -> Tuple[int, int, int, int]: """Returns the maximum shape of the cache.""" return self.max_cache_len class KVC2StaticCache: """ Static Cache class connect with KVC2 remind: page_idx & page_offset info need to refs to forward batching, only contains KV Block Tensor here """ def __init__(self, config: PretrainedConfig, max_batch_size, page_size: int = 256, dtype=torch.bfloat16, device=None) -> None: super().__init__() self.config = config self.dtype = dtype self.device = torch.device("npu:0") self.kv_lora_rank = config.kv_lora_rank self.max_batch_size = max_batch_size self.page_size = page_size self.k_caches = [] self.v_caches = [] self.num_hidden_layers = config.num_hidden_layers self.is_MLA = True if config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"] else False # kv cache stored in kvc2 # self.past_tokens = [] def load(self, inference_context): # assert self.is_MLA and len(inference_context.k_cache) == 1, "currently only support MLA and Cache Pool TP=1" from ktransformers.util.utils import get_current_device for i in range(self.config.num_hidden_layers): new_layer_key_cache = inference_context.k_cache[int(torch.distributed.get_rank())][i].to(get_current_device()) torch._dynamo.mark_static_address(new_layer_key_cache) self.k_caches.append( new_layer_key_cache # [TP_idx, layer_idx, page_idx, page_size, kv_head_num, kv_head_dim] ) self.v_caches.append(None) self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1] # page_len * page_size def update( self, combined: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): must have page_idx (`torch.Tensor`): & page_offset (`torch.Tensor`) & cache_position (`torch.Tensor`) Return: A tuple containing the updated key and value states. """ page_idx, page_offset = cache_kwargs.get("page_idx"), cache_kwargs.get("page_offset") if page_idx is None or page_offset is None: raise ValueError('[ERROR] block info:page_idx & page_offset missing!') k_out = self.k_caches[layer_idx] assert self.is_MLA, "currently only support DeepSeekV3 on NPU balance server" if page_idx.dim() == 1: page_idx_tmp = page_idx.unsqueeze(0) page_offset_tmp = page_offset.unsqueeze(0) else: page_idx_tmp = page_idx page_offset_tmp = page_offset k_out[page_idx_tmp, page_offset_tmp] = combined return k_out, page_idx def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" raise ValueError('kvc2 cache pool no longer hold seq_length info, refer to forward batching') def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int: return 0 def change_seq_length(self, bias: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model.""" raise ValueError('kvc2 cache pool no longer hold seq_length info, refer to forward batching') def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.max_cache_len def reset(self, inference_context): assert self.is_MLA and len(inference_context.k_cache) == 1, "currently only support MLA and Cache Pool TP=1" self.k_caches = [] self.v_caches = [] for i in range(self.config.num_hidden_layers): self.k_caches.append( inference_context.k_cache[0][i] ) self.v_caches.append(None) self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1] # page_len * page_size def get_page_table(self, mini_batch, bsz_tensors: torch.tensor = None, is_prefill=True): if is_prefill: # TODO add padding support q_lens = [mini_batch.p_q_len[idx] for idx in range(mini_batch.prefill_batch)] page_local_idx = -1 * torch.ones(mini_batch.prefill_batch, max(q_lens), dtype=mini_batch.p_position_ids.dtype, device=mini_batch.p_position_ids.device) page_offset = -1 * torch.ones_like(page_local_idx) # convert merged into batched start_ids = 0 for i in range(mini_batch.prefill_batch): page_offset[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] % self.page_size page_local_idx[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] // self.page_size for j in range(q_lens[i]): # get global page idx index by local page idx from block table, as followed decode page_local_idx[i, j] = mini_batch.p_block_tables[i, page_local_idx[i, j]] start_ids += q_lens[i] page_idx = page_local_idx # only padding will cause page_local_idx/page_offset still have -1 value # you can use following code as check # indices = torch.where(page_offset == -1) # assert not indices[0].numel() > 0, 'there still have un-calculated page_idx value' else: page_local_idx = mini_batch.d_position_ids // self.page_size page_offset = mini_batch.d_position_ids % self.page_size for i in range(mini_batch.decode_batch): page_local_idx[i] = mini_batch.d_block_tables[i, page_local_idx[i]] page_idx = page_local_idx return page_idx, page_offset class KDeepSeekV3Cache(nn.Module): def __init__( self, config: PretrainedConfig, page_size: int = 256, dtype=torch.bfloat16, device=torch.device("cuda:0"), ): super().__init__() self.config = config self.dtype = dtype self.device = device self.kv_lora_rank = config.kv_lora_rank self.page_size = page_size self.k_caches = [] self.v_caches = [] def load(self, inference_context: "sched_ext.InferenceContext"): for i in range(self.config.num_hidden_layers): self.k_caches.append( inference_context.k_cache[0][i] ) self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1] def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, page_idx: torch.Tensor, page_offset: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. It is VERY important to index using a tensor, otherwise you introduce a copy to the device. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input to know how where to write in the cache. Return: A tuple containing the updated key and value states. """ k_out = self.k_caches[layer_idx] k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:]) k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:]) return k_out def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor): page_offset = cache_position % self.page_size page_idx_local = cache_position // self.page_size query_ids = torch.zeros_like(cache_position) for i in range(len(q_indptr) - 1): start_idx = q_indptr[i] end_idx = q_indptr[i + 1] query_ids[start_idx:end_idx] = i page_idx = torch.zeros_like(page_idx_local) for i in range(bsz_tensors[0]): query_id = query_ids[i] local_block = page_idx_local[i] start_block = kv_indptr[query_id] if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]: page_idx[i] = kv_indices[start_block + local_block] return page_idx, page_offset class KGQACache(nn.Module): def __init__( self, config: PretrainedConfig, page_size: int = 256, dtype=torch.bfloat16, device=torch.device("cuda:0"), ): super().__init__() self.config = config self.dtype = dtype self.device = device self.page_size = page_size self.k_caches = [] self.v_caches = [] def load(self, inference_context: "sched_ext.InferenceContext"): print(self.config.num_hidden_layers) for i in range(self.config.num_hidden_layers): self.k_caches.append( inference_context.k_cache[0][i] ) self.v_caches.append( inference_context.v_cache[0][i] ) self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1] def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor): page_offset = cache_position % self.page_size page_idx_local = cache_position // self.page_size query_ids = torch.zeros_like(cache_position) for i in range(len(q_indptr) - 1): start_idx = q_indptr[i] end_idx = q_indptr[i + 1] query_ids[start_idx:end_idx] = i page_idx = torch.zeros_like(page_idx_local) for i in range(bsz_tensors[0]): query_id = query_ids[i] local_block = page_idx_local[i] start_block = kv_indptr[query_id] if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]: page_idx[i] = kv_indices[start_block + local_block] return page_idx, page_offset def get_k_cache(self, layer_idx): return self.k_caches[layer_idx] def get_v_cache(self, layer_idx): return self.v_caches[layer_idx] class KVC2Qwen3Cache(nn.Module): def __init__(self, config, max_batch_size, page_size=256, dtype=torch.bfloat16, device=None): super().__init__() self.config = config self.max_batch_size = max_batch_size self.page_size = page_size self.dtype = dtype self.device = device if device else torch.device("npu:0") self.num_layers = config.num_hidden_layers self.num_kv_heads = config.num_key_value_heads self.head_dim = config.head_dim self.k_caches = [] self.v_caches = [] # ------------------------- 绑定到底层 kvc2 pool ------------------------- def load(self, inference_context): from ktransformers.util.utils import get_current_device dev = get_current_device() self.k_caches = [] self.v_caches = [] rank = ( torch.distributed.get_rank() if (torch.distributed.is_available() and torch.distributed.is_initialized()) else 0 ) for i in range(self.num_layers): k_buf = inference_context.k_cache[rank][i].to(dev).to(self.dtype) v_buf = inference_context.v_cache[rank][i].to(dev).to(self.dtype) torch._dynamo.mark_static_address(k_buf) torch._dynamo.mark_static_address(v_buf) self.k_caches.append(k_buf) self.v_caches.append(v_buf) # num_pages * page_size self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1] # ------------------------- 写 KV ------------------------- @torch.no_grad() def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ): if cache_kwargs is None: raise ValueError("[KVC2Qwen3Cache] cache_kwargs must contain page_idx & page_offset") page_idx: Optional[torch.Tensor] = cache_kwargs.get("page_idx", None) page_offset: Optional[torch.Tensor] = cache_kwargs.get("page_offset", None) if page_idx is None or page_offset is None: raise ValueError("[KVC2Qwen3Cache] page_idx & page_offset are required in cache_kwargs") k_out = self.k_caches[layer_idx] v_out = self.v_caches[layer_idx] # -------- 1) 修正维度顺序:[B, KvH, Q, D] -> [B, Q, KvH, D] -------- if key_states.dim() == 4 and key_states.shape[1] == self.num_kv_heads: key_states = key_states.transpose(1, 2).contiguous() value_states = value_states.transpose(1, 2).contiguous() if key_states.shape != value_states.shape: raise ValueError( f"[KVC2Qwen3Cache] key_states.shape {key_states.shape} " f"!= value_states.shape {value_states.shape}" ) if key_states.dim() != 4: raise ValueError( f"[KVC2Qwen3Cache] expect key_states dim=4, got {key_states.dim()} " f"(shape={key_states.shape})" ) bsz, q_len, kv_heads, head_dim = key_states.shape if kv_heads != self.num_kv_heads or head_dim != self.head_dim: raise ValueError( f"[KVC2Qwen3Cache] KV shape mismatch: " f"got num_kv_heads={kv_heads}, head_dim={head_dim}, " f"expected num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim}" ) # -------- 2) flatten page_idx / page_offset 为一维 -------- page_idx = page_idx.reshape(-1) page_offset = page_offset.reshape(-1) # -------- 3) flatten KV,并强制 dtype 与 cache 对齐 -------- val_dtype = k_out.dtype flat_k = key_states.to(val_dtype).reshape(-1, kv_heads, head_dim) flat_v = value_states.to(val_dtype).reshape(-1, kv_heads, head_dim) # -------- 4) 真正写入 K / V -------- # k_out / v_out: [num_pages, page_size, num_kv_heads, head_dim] k_out[page_idx, page_offset] = flat_k v_out[page_idx, page_offset] = flat_v # ------------------------- get K/V ------------------------- def get_k_cache(self, layer_idx): return self.k_caches[layer_idx] def get_v_cache(self, layer_idx): return self.v_caches[layer_idx] # ------------------------- page table 计算 ------------------------- def get_page_table( self, mini_batch, bsz_tensors: torch.Tensor = None, is_prefill: bool = True, ): if is_prefill: # prefill: merged positions => batched (B, T_chunk) q_lens = [int(mini_batch.p_q_len[idx]) for idx in range(mini_batch.prefill_batch)] if len(q_lens) == 0: return None, None max_q_len = max(q_lens) page_local_idx = -1 * torch.ones( mini_batch.prefill_batch, max_q_len, dtype=mini_batch.p_position_ids.dtype, device=mini_batch.p_position_ids.device, ) page_offset = -1 * torch.ones_like(page_local_idx) start_ids = 0 for i in range(mini_batch.prefill_batch): cur_len = q_lens[i] pos = mini_batch.p_position_ids[start_ids:start_ids + cur_len] # global pos of this chunk # local block + offset by page_size page_offset[i, 0:cur_len] = pos % self.page_size page_local_idx[i, 0:cur_len] = pos // self.page_size # local block -> global page id via block_tables for j in range(cur_len): blk = page_local_idx[i, j] page_local_idx[i, j] = mini_batch.p_block_tables[i, blk] start_ids += cur_len page_idx = page_local_idx else: # decode: decode_batch = 当前 step 的 batch_size, 每条样本通常 1 个 token page_local_idx = mini_batch.d_position_ids // self.page_size page_offset = mini_batch.d_position_ids % self.page_size for i in range(mini_batch.decode_batch): blk = page_local_idx[i] page_local_idx[i] = mini_batch.d_block_tables[i, blk] page_idx = page_local_idx return page_idx, page_offset ================================================ FILE: archive/ktransformers/models/custom_modeling_deepseek_v2.py ================================================ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KDeepSeekV3Cache from ktransformers.models.modeling_deepseek import DeepseekV2Model, DeepseekV2PreTrainedModel from ktransformers.models.configuration_deepseek import DeepseekV2Config torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): kv_cache: KDeepSeekV3Cache use_cuda_graph = False def __init__( self, config, kv_cache, ): super().__init__(config) self.model = DeepseekV2Model(config) self.config = config self.kv_cache = kv_cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages): self.use_cuda_graph = use_cuda_graph self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device) self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.workspace_buffer, use_cuda_graph=use_cuda_graph, qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf, backend = "fa2", ) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): if self.model.transfer_map is not None and i in self.model.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.model.transfer_map[i] if cur_device not in self.model.stream_device_map: self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.model.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.model.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.model.transfer_map[i], non_blocking=True ) batch.minibatch.position_ids = ( batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) if batch.minibatch.position_ids is not None else None ) hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.kv_cache, position_ids=batch.minibatch.position_ids, wrapper=self.wrapper, bsz_tensors=num_tokens_tensors, cache_position=batch.minibatch.positions, batch_indices=batch.minibatch.batch_indices, kv_indices=batch.minibatch.kv_indices, kv_indptr=batch.minibatch.kv_indptr, kv_last_page_len=batch.minibatch.kv_last_page_len, q_indptr=batch.minibatch.q_indptr, page_idx=page_idx, page_offset=page_offset ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) if i < 3: hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) else: hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors) hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() assert batch.batch_size == 1 with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states[batch.minibatch.logits_start], num_tokens_tensors, residual[batch.minibatch.logits_start])[0]) # local_logit = local_logit[batch.minibatch.logits_start] forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype,): minibatch = batch.minibatch self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type) ================================================ FILE: archive/ktransformers/models/custom_modeling_deepseek_v3.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KDeepSeekV3Cache from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model, DeepseekV3PreTrainedModel from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): cache: KDeepSeekV3Cache use_cuda_graph = False def __init__( self, config: DeepseekV3Config, cache, ): super().__init__(config) self.model = DeepseekV3Model(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages): self.use_cuda_graph = use_cuda_graph self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device) self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device) self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device) self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.workspace_buffer, use_cuda_graph=use_cuda_graph, qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf, bsz_tensor=self.bsz_tensor_buf, backend = "fa2", ) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = -1 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): # can't use now, only one flashinfer wrapper if self.model.transfer_map is not None and i in self.model.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.model.transfer_map[i] if cur_device not in self.model.stream_device_map: self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.model.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.model.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.model.transfer_map[i], non_blocking=True ) batch.minibatch.position_ids = ( batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) if batch.minibatch.position_ids is not None else None ) hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.cache, position_ids=batch.minibatch.position_ids, wrapper=self.wrapper, num_tokens_tensors=num_tokens_tensors, page_idx=page_idx, page_offset=page_offset ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) if i < self.config.first_k_dense_replace: hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) else: hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx) hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_heads: int, head_dim_ckv: int, head_dim_kpe: int, page_size: int, causal: bool, sm_scale: float, q_data_type: torch.dtype, kv_data_type: torch.dtype,): minibatch = batch.minibatch self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors) ================================================ FILE: archive/ktransformers/models/custom_modeling_glm4_moe.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KGQACache from ktransformers.models.modeling_glm4_moe import Glm4MoeModel, Glm4MoePreTrainedModel from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel): cache: KGQACache use_cuda_graph = False def __init__( self, config: Glm4MoeConfig, cache, ): super().__init__(config) self.model = Glm4MoeModel(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = 0 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0)) with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.cache, freqs_cis, wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, position_ids=batch.minibatch.position_ids ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) if i < self.model.config.first_k_dense_replace: hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors) else: hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors, cuda_graph_idx) # hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0 ): minibatch = batch.minibatch self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) ================================================ FILE: archive/ktransformers/models/custom_modeling_qwen2_moe.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KGQACache from ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel): cache: KGQACache use_cuda_graph = False def __init__( self, config: Qwen2MoeConfig, cache, ): super().__init__(config) self.model = Qwen2MoeModel(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = 0 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): if self.model.transfer_map is not None and i in self.model.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.model.transfer_map[i] if cur_device not in self.model.stream_device_map: self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.model.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.model.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.model.transfer_map[i], non_blocking=True ) batch.minibatch.position_ids = ( batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) if batch.minibatch.position_ids is not None else None ) hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.cache, position_ids=batch.minibatch.position_ids, wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, page_idx=page_idx, page_offset=page_offset ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx) hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0 ): minibatch = batch.minibatch self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors,num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) ================================================ FILE: archive/ktransformers/models/custom_modeling_qwen3_moe.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KGQACache from ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel): cache: KGQACache use_cuda_graph = False def __init__( self, config: Qwen3MoeConfig, cache = None, ): super().__init__(config) self.model = Qwen3MoeModel(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = 0 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): if self.model.transfer_map is not None and i in self.model.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.model.transfer_map[i] if cur_device not in self.model.stream_device_map: self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.model.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.model.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.model.transfer_map[i], non_blocking=True ) batch.minibatch.position_ids = ( batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True) if batch.minibatch.position_ids is not None else None ) hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.cache, position_ids=batch.minibatch.position_ids, wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, page_idx=page_idx, page_offset=page_offset ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx) hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0 ): minibatch = batch.minibatch self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) ================================================ FILE: archive/ktransformers/models/custom_modeling_qwen3_next.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KGQACache from ktransformers.models.modeling_qwen3_next import Qwen3NextModel, Qwen3NextPreTrainedModel from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KQwen3NextForCausalLM(Qwen3NextPreTrainedModel): cache: KGQACache use_cuda_graph = False def __init__( self, config: Qwen3NextConfig, cache = None, ): super().__init__(config) self.model = Qwen3NextModel(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.attn = [None] * 100 self.conv_states = [None for _ in range(config.num_hidden_layers)] self.recurrent_states = [None for _ in range(config.num_hidden_layers)] def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def reset_conv_states(self): for i in range(self.config.num_hidden_layers): self.conv_states[i] = None self.recurrent_states[i] = None def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = 0 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() q_len = features[0].size(0) if q_len > 1: self.reset_conv_states() hidden_states = features[0] self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0)) residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): hidden_states = hidden_states.contiguous().clone() # 断开别名 + 连续 residual = residual.contiguous().clone() hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = hidden_states.contiguous() residual = residual.contiguous() if self.config.layer_types[i] != "linear_attention": hidden_states = decode_layer.self_attn(hidden_states, self.cache, freqs_cis, wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors) else: hs = hidden_states.unsqueeze(0).contiguous().clone() hs = decode_layer.linear_attn(hs, self.conv_states, self.recurrent_states, bsz_tensors=num_tokens_tensors) hidden_states = hs.squeeze(0).contiguous() hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) hs2 = hidden_states.unsqueeze(0).contiguous().clone() hidden_states = decode_layer.mlp(hs2, num_tokens_tensors, cuda_graph_idx).squeeze(0).contiguous() if not torch.isfinite(hidden_states).all(): raise RuntimeError(f"NaN after layer {i}") # print(f"Layer {i} output: {hidden_states}") forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0 ): minibatch = batch.minibatch self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) ================================================ FILE: archive/ktransformers/models/custom_modeling_smallthinker.py ================================================ """ Date: 2024-11-06 10:05:11 LastEditors: djw LastEditTime: 2024-11-13 07:50:51 """ import math from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.models.custom_cache import KGQACache from ktransformers.models.modeling_smallthinker import SmallthinkerModel, SmallthinkerPreTrainedModel from ktransformers.models.configuration_smallthinker import SmallthinkerConfig from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) import flashinfer class KSmallThinkerForCausalLM(SmallthinkerPreTrainedModel): cache: KGQACache use_cuda_graph = False def __init__( self, config: SmallthinkerConfig, cache, ): super().__init__(config) self.model = SmallthinkerModel(config) self.config = config self.cache = cache self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.attn = [None] * 100 def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0): self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device) def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): features = [] for i in range(batch.batch_size): tokens = batch.minibatch.tokens.contiguous() feature = ( self.model.embed_tokens(tokens.to(torch.device('cpu'))) .to(torch.bfloat16) .to(device=device) ) features.append(feature) return features def forward( self, batch: ForwardBatchInput | None = None, features: List[torch.Tensor] | None = None, bsz_tensors: torch.Tensor | None = None, num_tokens_tensors: torch.Tensor | None = None, page_idx: torch.Tensor | None = None, page_offset: torch.Tensor | None = None, cuda_graph_idx: int | None = 0 ) -> ForwardBatchOutput: current_stream = torch.cuda.current_stream() forward_batch_output = ForwardBatchOutput() hidden_states = features[0] self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0]) freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0)) with torch.cuda.stream(current_stream): residual = torch.zeros_like(hidden_states) for i, decode_layer in enumerate(self.model.layers): router_input = hidden_states hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual) hidden_states = decode_layer.self_attn(hidden_states, self.cache, freqs_cis if self.model.rope_layout[i] else None, wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, position_ids=batch.minibatch.position_ids ) hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual) if not self.config.moe_layer_layout[i]: hidden_states = decode_layer.block_sparse_moe(hidden_states, num_tokens_tensors) else: hidden_states = decode_layer.block_sparse_moe(router_input, hidden_states, num_tokens_tensors, cuda_graph_idx) # hidden_states = hidden_states.squeeze(0) forward_batch_output = ForwardBatchOutput() with torch.cuda.stream(current_stream): local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors) forward_batch_output.logits.append(local_logit) return forward_batch_output def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors, num_q_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool, q_data_type: torch.dtype, kv_data_type: torch.dtype, cuda_graph_idx: int = 0 ): minibatch = batch.minibatch self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type) ================================================ FILE: archive/ktransformers/models/modeling_deepseek.py ================================================ # coding=utf-8 ''' Description : Author : Boxin Zhang Version : 0.1.0 ''' # Adapted from # https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch DeepSeek model.""" import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available from .configuration_deepseek import DeepseekV2Config import torch.distributed as dist import numpy as np if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): if not is_torch_greater_or_equal_than_1_13: import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) ) return ( indices, cu_seqlens, max_seqlen_in_batch, ) class DeepseekV2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ DeepseekV2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return (self.weight * hidden_states).to(input_dtype) ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->DeepseekV2 class DeepseekV2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): raise NotImplementedError("LinearScalingRotaryEmbedding is not supported now.") self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): raise NotImplementedError("DynamicNTKScalingRotaryEmbedding is not supported now.") self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( num_rotations, dim, base=10000, max_position_embeddings=2048 ): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 2 * math.log(base) ) # Find dim range bounds based on rotations def yarn_find_correction_range( low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 ): low = math.floor( yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) ) high = math.ceil( yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) ) return max(low, 0), min(high, dim - 1) # Clamp values just in case def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def yarn_linear_ramp_mask(min, max, dim): if min == max: max += 0.001 # Prevent singularity linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1, mscale=1, mscale_all_dim=0, ): nn.Module.__init__(self) self.scaling_factor = scaling_factor self.original_max_position_embeddings = original_max_position_embeddings self.beta_fast = beta_fast self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base freq_extra = 1.0 / ( self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) freq_inter = 1.0 / ( self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings, ) inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( device=device, dtype=torch.float32 ) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) self._mscale = float( yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) ) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos()* self._mscale sin = emb.sin()* self._mscale return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) b, h, s, d = k.shape k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class DeepseekV2MLP(nn.Module): def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size self.intermediate_size = ( config.intermediate_size if intermediate_size is None else intermediate_size ) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): act = self.act_fn(self.gate_proj(x)) * self.up_proj(x) down_proj = self.down_proj(act) return down_proj class MoEGate(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.scoring_func = config.scoring_func self.alpha = config.aux_loss_alpha self.seq_aux = config.seq_aux self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob self.gating_dim = config.hidden_size self.weight = nn.Parameter( torch.empty((self.n_routed_experts, self.gating_dim)) ) self.reset_parameters() def reset_parameters(self) -> None: import torch.nn.init as init init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear( hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) if self.scoring_func == "softmax": scores = logits.softmax(dim=-1, dtype=torch.float32) else: raise NotImplementedError( f"insupportable scoring function for MoE gating: {self.scoring_func}" ) ### select top-k experts if self.topk_method == "greedy": topk_weight, topk_idx = torch.topk( scores, k=self.top_k, dim=-1, sorted=False ) elif self.topk_method == "group_limited_greedy": group_scores = ( scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values ) # [n, n_group] group_idx = torch.topk( group_scores, k=self.topk_group, dim=-1, sorted=False )[ 1 ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) .expand( bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group ) .reshape(bsz * seq_len, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] topk_weight, topk_idx = torch.topk( tmp_scores, k=self.top_k, dim=-1, sorted=False ) ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator else: topk_weight = topk_weight * self.routed_scaling_factor ### expert-level computation auxiliary loss if self.training and self.alpha > 0.0: scores_for_aux = scores aux_topk = self.top_k # always compute aux loss based on the naive greedy topk method topk_idx_for_aux_loss = topk_idx.view(bsz, -1) if self.seq_aux: scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) ce = torch.zeros( bsz, self.n_routed_experts, device=hidden_states.device ) ce.scatter_add_( 1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), ).div_(seq_len * aux_topk / self.n_routed_experts) aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( dim=1 ).mean() * self.alpha else: mask_ce = F.one_hot( topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts ) ce = mask_ce.float().mean(0) Pi = scores_for_aux.mean(0) fi = ce * self.n_routed_experts aux_loss = (Pi * fi).sum() * self.alpha else: aux_loss = None return topk_idx, topk_weight, aux_loss class AddAuxiliaryLoss(torch.autograd.Function): """ The trick function of adding auxiliary (aux) loss, which includes the gradient of the aux loss during backpropagation. """ @staticmethod def forward(ctx, x, loss): assert loss.numel() == 1 ctx.dtype = loss.dtype ctx.required_aux_loss = loss.requires_grad return x @staticmethod def backward(ctx, grad_output): grad_loss = None if ctx.required_aux_loss: grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) return grad_output, grad_loss class DeepseekV2MoE(nn.Module): """ A mixed expert module containing shared experts. """ def __init__(self, config): super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok if hasattr(config, "ep_size") and config.ep_size > 1: assert config.ep_size == dist.get_world_size() self.ep_size = config.ep_size self.experts_per_rank = config.n_routed_experts // config.ep_size self.ep_rank = dist.get_rank() self.experts = nn.ModuleList( [ ( DeepseekV2MLP( config, intermediate_size=config.moe_intermediate_size ) if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank else None ) for i in range(config.n_routed_experts) ] ) else: self.ep_size = 1 self.experts_per_rank = config.n_routed_experts self.ep_rank = 0 self.experts = nn.ModuleList( [ DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) for i in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( config=config, intermediate_size=intermediate_size ) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) flat_topk_idx = topk_idx.view(-1) if self.training: hidden_states = hidden_states.repeat_interleave( self.num_experts_per_tok, dim=0 ) y = torch.empty_like(hidden_states) for i, expert in enumerate(self.experts): y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = y.view(*orig_shape) y = AddAuxiliaryLoss.apply(y, aux_loss) else: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) return y @torch.no_grad() def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens_shape = sorted_tokens.shape if self.ep_size > 1: tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) tokens_per_expert_group = tokens_per_expert.new_empty( tokens_per_expert.shape[0] ) dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) output_splits = ( tokens_per_expert_group.view(self.ep_size, -1) .sum(1) .cpu() .numpy() .tolist() ) gathered_tokens = sorted_tokens.new_empty( tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] ) input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() dist.all_to_all( list(gathered_tokens.split(output_splits)), list(sorted_tokens.split(input_split_sizes)), ) tokens_per_expert_post_gather = tokens_per_expert_group.view( self.ep_size, self.experts_per_rank ).sum(dim=0) gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) s = 0 for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): gatherd_idxs[s : s + k] = i % self.experts_per_rank s += k gatherd_idxs = gatherd_idxs.argsort() sorted_tokens = gathered_tokens[gatherd_idxs] tokens_per_expert = tokens_per_expert_post_gather tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) if self.ep_size > 1: new_x = torch.empty_like(outs) new_x[gatherd_idxs] = outs gathered_tokens = new_x.new_empty(*sorted_tokens_shape) dist.all_to_all( list(gathered_tokens.split(input_split_sizes)), list(new_x.split(output_splits)), ) outs = gathered_tokens new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim self.is_causal = True if self.q_lora_rank is None: self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.q_head_dim, bias=False ) else: self.q_a_proj = nn.Linear( self.hidden_size, config.q_lora_rank, bias=config.attention_bias ) self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) self.q_b_proj = nn.Linear( config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False ) self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias, ) self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) self.kv_b_proj = nn.Linear( config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias, ) self._init_rope() self.softmax_scale = self.q_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] if mscale_all_dim: mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = DeepseekV2RotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "yarn": kwargs = { key: self.config.rope_scaling[key] for key in [ "original_max_position_embeddings", "beta_fast", "beta_slow", "mscale", "mscale_all_dim", ] if key in self.config.rope_scaling } self.rotary_emb = DeepseekV2YarnRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, **kwargs, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return ( tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) .transpose(1, 2) .contiguous() ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) attn_weights = ( torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale ) if attention_mask is not None: attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 class DeepseekV2FlashAttention2(DeepseekV2Attention): """ DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # DeepseekV2FlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) # overwrite attention_mask with padding_mask attention_mask = kwargs.pop("padding_mask") output_attentions = False bsz, q_len, _ = hidden_states.size() q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] kv_seq_len = value_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (DeepseekV2RMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype elif torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_a_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, softmax_scale=self.softmax_scale, ) if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim ).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, position_ids, dropout=0.0, softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. # Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input( attn_output_unpad, indices_q, batch_size, query_length ) else: if query_length == 1: position_ids = position_ids.to(dtype=torch.int32).squeeze(1) attn_output = flash_attn_with_kvcache( query_states, key_states, value_states, cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=causal, ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, attention_mask ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) ATTENTION_CLASSES = { "eager": DeepseekV2Attention, "flash_attention_2": DeepseekV2FlashAttention2, } class DeepseekV2DecoderLayer(nn.Module): def __init__(self, config: DeepseekV2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( config=config, layer_idx=layer_idx ) self.mlp = ( DeepseekV2MoE(config) if ( config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0 ) else DeepseekV2MLP(config) ) self.input_layernorm = DeepseekV2RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = DeepseekV2RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs DeepseekV2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`DeepseekV2Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", DeepseekV2_START_DOCSTRING, ) class DeepseekV2PreTrainedModel(PreTrainedModel): config_class = DeepseekV2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() DeepseekV2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", DeepseekV2_START_DOCSTRING, ) class DeepseekV2Model(DeepseekV2PreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] Args: config: DeepseekV2Config """ def __init__(self, config: DeepseekV2Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = nn.ModuleList( [ DeepseekV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." ) use_cache = False past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = DeepseekV2Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states[:,-1:,:]).float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, use_cache=True, **kwargs, ): past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: if isinstance(past_key_values, Cache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None else None ) cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) elif use_cache: cache_position = cache_position[-input_length:] model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "cache_position": cache_position, } ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past ), ) return reordered_past @add_start_docstrings( """ The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, DeepseekV2_START_DOCSTRING, ) class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = DeepseekV2Model(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError( "Cannot handle batch sizes > 1 if no padding token is defined." ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = ( torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 ).to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[ torch.arange(batch_size, device=logits.device), sequence_lengths ] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct( pooled_logits.view(-1, self.num_labels), labels.view(-1) ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) ================================================ FILE: archive/ktransformers/models/modeling_deepseek_v3.py ================================================ # coding=utf-8 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch DeepSeek model.""" import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available from .configuration_deepseek_v3 import DeepseekV3Config import torch.distributed as dist import numpy as np if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): if not is_torch_greater_or_equal_than_1_13: import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV3Config" def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) ) return ( indices, cu_seqlens, max_seqlen_in_batch, ) class DeepseekV3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ DeepseekV3RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) class DeepseekV3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) self.max_seq_len_cached = None def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( num_rotations, dim, base=10000, max_position_embeddings=2048 ): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 2 * math.log(base) ) # Find dim range bounds based on rotations def yarn_find_correction_range( low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 ): low = math.floor( yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) ) high = math.ceil( yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) ) return max(low, 0), min(high, dim - 1) # Clamp values just in case def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def yarn_linear_ramp_mask(min, max, dim): if min == max: max += 0.001 # Prevent singularity linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1, mscale=1, mscale_all_dim=0, ): self.scaling_factor = scaling_factor self.original_max_position_embeddings = original_max_position_embeddings self.beta_fast = beta_fast self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len dim = self.dim freq_extra = 1.0 / ( self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) freq_inter = 1.0 / ( self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings, ) inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( device=device, dtype=torch.float32 ) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) _mscale = float( yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) ) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer( "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False ) self.register_buffer( "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False ) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) b, h, s, d = k.shape k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class DeepseekV3MLP(nn.Module): def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size self.intermediate_size = ( config.intermediate_size if intermediate_size is None else intermediate_size ) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class MoEGate(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.scoring_func = config.scoring_func self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob self.gating_dim = config.hidden_size self.weight = nn.Parameter( torch.empty((self.n_routed_experts, self.gating_dim)) ) if self.topk_method == "noaux_tc": self.e_score_correction_bias = nn.Parameter( torch.empty((self.n_routed_experts)) ) self.reset_parameters() def reset_parameters(self) -> None: import torch.nn.init as init init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear( hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) if self.scoring_func == "sigmoid": scores = logits.sigmoid() else: raise NotImplementedError( f"insupportable scoring function for MoE gating: {self.scoring_func}" ) ### select top-k experts if self.topk_method == "noaux_tc": #assert not self.training scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) ) # [n, n_group] group_idx = torch.topk( group_scores, k=self.topk_group, dim=-1, sorted=False )[ 1 ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) .expand( bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group ) .reshape(bsz * seq_len, -1) ) # [n, e] tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] _, topk_idx = torch.topk( tmp_scores, k=self.top_k, dim=-1, sorted=False ) topk_weight = scores.gather(1, topk_idx) else: raise NotImplementedError( f"insupportable TopK function for MoE gating: {self.topk_method}" ) ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor return topk_idx, topk_weight class DeepseekV3MoE(nn.Module): """ A mixed expert module containing shared experts. """ def __init__(self, config): super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok if hasattr(config, "ep_size") and config.ep_size > 1: assert config.ep_size == dist.get_world_size() self.ep_size = config.ep_size self.experts_per_rank = config.n_routed_experts // config.ep_size self.ep_rank = dist.get_rank() self.experts = nn.ModuleList( [ ( DeepseekV3MLP( config, intermediate_size=config.moe_intermediate_size ) if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank else None ) for i in range(config.n_routed_experts) ] ) else: self.ep_size = 1 self.experts_per_rank = config.n_routed_experts self.ep_rank = 0 self.experts = nn.ModuleList( [ DeepseekV3MLP( config, intermediate_size=config.moe_intermediate_size ) for i in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV3MLP( config=config, intermediate_size=intermediate_size ) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) flat_topk_idx = topk_idx.view(-1) if not self.training: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) return y @torch.no_grad() def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens_shape = sorted_tokens.shape if self.ep_size > 1: tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) tokens_per_expert_group = tokens_per_expert.new_empty( tokens_per_expert.shape[0] ) dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) output_splits = ( tokens_per_expert_group.view(self.ep_size, -1) .sum(1) .cpu() .numpy() .tolist() ) gathered_tokens = sorted_tokens.new_empty( tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] ) input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() dist.all_to_all( list(gathered_tokens.split(output_splits)), list(sorted_tokens.split(input_split_sizes)), ) tokens_per_expert_post_gather = tokens_per_expert_group.view( self.ep_size, self.experts_per_rank ).sum(dim=0) gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) s = 0 for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): gatherd_idxs[s : s + k] = i % self.experts_per_rank s += k gatherd_idxs = gatherd_idxs.argsort() sorted_tokens = gathered_tokens[gatherd_idxs] tokens_per_expert = tokens_per_expert_post_gather tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) if self.ep_size > 1: new_x = torch.empty_like(outs) new_x[gatherd_idxs] = outs gathered_tokens = new_x.new_empty(*sorted_tokens_shape) dist.all_to_all( list(gathered_tokens.split(input_split_sizes)), list(new_x.split(output_splits)), ) outs = gathered_tokens new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3 class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim self.is_causal = True if self.q_lora_rank is None: self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.q_head_dim, bias=False ) else: self.q_a_proj = nn.Linear( self.hidden_size, config.q_lora_rank, bias=config.attention_bias ) self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) self.q_b_proj = nn.Linear( config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False ) self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias, ) self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) self.kv_b_proj = nn.Linear( config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias, ) self._init_rope() self.softmax_scale = self.q_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] if mscale_all_dim: mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = DeepseekV3RotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "yarn": kwargs = { key: self.config.rope_scaling[key] for key in [ "original_max_position_embeddings", "beta_fast", "beta_slow", "mscale", "mscale_all_dim", ] if key in self.config.rope_scaling } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, **kwargs, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return ( tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) .transpose(1, 2) .contiguous() ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) attn_weights = ( torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale ) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) assert attention_mask is not None if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 class DeepseekV3FlashAttention2(DeepseekV3Attention): """ DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # DeepseekV3FlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) # overwrite attention_mask with padding_mask attention_mask = kwargs.pop("padding_mask") output_attentions = False bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] kv_seq_len = value_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (DeepseekV3RMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype elif torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = ( self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype ) logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, softmax_scale=self.softmax_scale, ) if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim ).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input( attn_output_unpad, indices_q, batch_size, query_length ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, attention_mask ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) ATTENTION_CLASSES = { "eager": DeepseekV3Attention, "flash_attention_2": DeepseekV3FlashAttention2, } class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( config=config, layer_idx=layer_idx ) self.mlp = ( DeepseekV3MoE(config) if ( config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0 ) else DeepseekV3MLP(config) ) self.input_layernorm = DeepseekV3RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = DeepseekV3RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, is_prefill: Optional[bool] = False, **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, is_prefill=is_prefill, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs DeepseekV3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`DeepseekV3Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", DeepseekV3_START_DOCSTRING, ) class DeepseekV3PreTrainedModel(PreTrainedModel): config_class = DeepseekV3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() DeepseekV3_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", DeepseekV3_START_DOCSTRING, ) class DeepseekV3Model(DeepseekV3PreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] Args: config: DeepseekV3Config """ def __init__(self, config: DeepseekV3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = nn.ModuleList( [ DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device, ) position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = ( attention_mask if (attention_mask is not None and 0 in attention_mask) else None ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length, ) # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = DeepseekV3Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, is_prefill: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, is_prefill=is_prefill, ) hidden_states = outputs[0] if use_torch_npu: hidden_states_without_norm = outputs[-1] logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[:,-1:,:]) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: if use_torch_npu: output = (logits,) + outputs[1:] + (hidden_states_without_norm,) else: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as # input) if ( attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1] ): input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past ), ) return reordered_past @add_start_docstrings( """ The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, DeepseekV3_START_DOCSTRING, ) class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = DeepseekV3Model(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError( "Cannot handle batch sizes > 1 if no padding token is defined." ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = ( torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 ).to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[ torch.arange(batch_size, device=logits.device), sequence_lengths ] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct( pooled_logits.view(-1, self.num_labels), labels.view(-1) ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) ================================================ FILE: archive/ktransformers/models/modeling_glm4_moe.py ================================================ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_glm4_moe.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Optional, Union import torch import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin # from transformers.integrations import use_kernel_forward_from_hub from transformers.masking_utils import create_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack # from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple from transformers.utils import auto_docstring, can_return_tuple # from transformers.utils.generic import check_model_inputs from .configuration_glm4_moe import Glm4MoeConfig def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, # **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] # Apply rotary embeddings on the first half or full tensor q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed class Glm4MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape) value_states = self.v_proj(hidden_states).view(hidden_shape) if self.use_qk_norm: # main diff from Llama query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Glm4MoeMLP(nn.Module): def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class Glm4MoeTopkRouter(nn.Module): def __init__(self, config: Glm4MoeConfig): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.n_group = config.n_group self.topk_group = config.topk_group self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) @torch.no_grad() def get_topk_indices(self, scores): scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) .topk(2, dim=-1)[0] .sum(dim=-1) ) group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) .expand(-1, self.n_group, self.n_routed_experts // self.n_group) .reshape(-1, self.n_routed_experts) ) scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] return topk_indices def forward(self, hidden_states): hidden_states = hidden_states.view(-1, self.config.hidden_size) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) scores = router_logits.sigmoid() topk_indices = self.get_topk_indices(scores) topk_weights = scores.gather(1, topk_indices) if self.norm_topk_prob: denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 topk_weights /= denominator topk_weights = topk_weights * self.routed_scaling_factor return topk_indices, topk_weights # @use_kernel_forward_from_hub("RMSNorm") class Glm4MoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Glm4MoeRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.hidden_size = hidden_size self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Glm4MoeMoE(nn.Module): """ A mixed expert module containing shared experts. """ def __init__(self, config): super().__init__() self.config = config self.experts = nn.ModuleList( [ Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts) ] ) self.gate = Glm4MoeTopkRouter(config) self.shared_experts = Glm4MoeMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): r""" CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused to not have to do a loop here (deepseek has 256 experts soooo yeah). """ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) expert_mask = expert_mask.permute(2, 0, 1) for expert_idx in range(len(self.experts)): expert = self.experts[expert_idx] mask = expert_mask[expert_idx] token_indices, weight_indices = torch.where(mask) if token_indices.numel() > 0: expert_weights = topk_weights[token_indices, weight_indices] expert_input = hidden_states[token_indices] expert_output = expert(expert_input) weighted_output = expert_output * expert_weights.unsqueeze(-1) final_hidden_states.index_add_(0, token_indices, weighted_output) # in original deepseek, the output of the experts are gathered once we leave this module # thus the moe module is itelsf an IsolatedParallel module # and all expert are "local" meaning we shard but we don't gather return final_hidden_states.type(hidden_states.dtype) def forward(self, hidden_states): residuals = hidden_states orig_shape = hidden_states.shape topk_indices, topk_weights = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) return hidden_states class Glm4MoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Glm4MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx) if layer_idx >= config.first_k_dense_replace: self.mlp = Glm4MoeMoE(config) else: self.mlp = Glm4MoeMLP(config) self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC # **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @auto_docstring class Glm4MoePreTrainedModel(PreTrainedModel): config: Glm4MoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Glm4MoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_static_cache = False _supports_attention_backend = True _can_record_outputs = { "hidden_states": Glm4MoeDecoderLayer, "attentions": Glm4MoeAttention, } def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, Glm4MoeRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, Glm4MoeTopkRouter): module.weight.data.normal_(mean=0.0, std=std) class Glm4MoeRotaryEmbedding(nn.Module): def __init__(self, config: Glm4MoeConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @auto_docstring class Glm4MoeModel(Glm4MoePreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"] def __init__(self, config: Glm4MoeConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Glm4MoeRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value # @check_model_inputs @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, # **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position: torch.Tensor = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) @auto_docstring class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = Glm4MoeModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, # **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM >>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, # **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"] ================================================ FILE: archive/ktransformers/models/modeling_llama.py ================================================ # coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_llama import LlamaConfig logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) class LlamaRotaryEmbedding(nn.Module): def __init__( self, dim=None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, rope_type="default", config: Optional[LlamaConfig] = None, ): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.device = device self.scaling_factor = scaling_factor self.rope_type = rope_type self.config = config # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} if config is None: logger.warning_once( "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " "`config` argument. All other arguments will be removed in v4.45" ) self.rope_kwargs = { "rope_type": rope_type, "factor": scaling_factor, "dim": dim, "base": base, "max_position_embeddings": max_position_embeddings, } self.rope_type = rope_type self.max_seq_len_cached = max_position_embeddings self.original_max_seq_len = max_position_embeddings else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, **self.rope_kwargs ) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 # seq_len = position_ids[0, -1] + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) self.register_buffer( "inv_freq", inv_freq, persistent=False ) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if ( seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len ): # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): # if "dynamic" in self.rope_type: # self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = ( self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = ( device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ) with torch.autocast(device_type=device_type, enabled=False): freqs = ( inv_freq_expanded.float() @ position_ids_expanded.float() ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__(self, *args, **kwargs): logger.warning_once( "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." ) kwargs["rope_type"] = "linear" super().__init__(*args, **kwargs) class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, *args, **kwargs): logger.warning_once( "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " "__init__)." ) kwargs["rope_type"] = "dynamic" super().__init__(*args, **kwargs) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.up_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.down_proj = nn.Linear( self.intermediate_size, self.hidden_size, bias=config.mlp_bias ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): if self.config.pretraining_tp > 1: slice = self.intermediate_size // self.config.pretraining_tp gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) up_proj_slices = self.up_proj.weight.split(slice, dim=0) down_proj_slices = self.down_proj.weight.split(slice, dim=1) gate_proj = torch.cat( [ F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp) ], dim=-1, ) up_proj = torch.cat( [ F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp) ], dim=-1, ) intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) down_proj = [ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) ] down_proj = sum(down_proj) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.hidden_size, self.hidden_size, bias=config.attention_bias ) # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: key_value_slicing = ( self.num_key_value_heads * self.head_dim ) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [ F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) ] query_states = torch.cat(query_states, dim=-1) key_states = [ F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) ] key_states = torch.cat(key_states, dim=-1) value_states = [ F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) ] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) if self.config.pretraining_tp > 1: attn_output = attn_output.split( self.hidden_size // self.config.pretraining_tp, dim=2 ) o_proj_slices = self.o_proj.weight.split( self.hidden_size // self.config.pretraining_tp, dim=1 ) attn_output = sum( [ F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp) ] ) else: attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class LlamaFlashAttention2(LlamaAttention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" ) output_attentions = False bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class LlamaSdpaAttention(LlamaAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from LlamaAttention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value LLAMA_ATTENTION_CLASSES = { "eager": LlamaAttention, "flash_attention_2": LlamaFlashAttention2, "sdpa": LlamaSdpaAttention, } class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( config=config, layer_idx=layer_idx ) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs LLAMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() LLAMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaModel(LlamaPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] Args: config: LlamaConfig """ def __init__(self, config: LlamaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = nn.ModuleList( [ LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False if ( use_cache and not isinstance(past_key_values, Cache) and not self.training ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError( "Custom 4D attention mask should be passed in inverted form with max==0`" ) causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device, ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange( target_length, device=device ) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand( input_tensor.shape[0], 1, -1, -1 ) if attention_mask is not None: causal_mask = ( causal_mask.clone() ) # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = ( causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[ :, :, :, :mask_length ].masked_fill(padding_mask, min_dtype) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask class LlamaForCausalLM(LlamaPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split( self.vocab_size // self.config.pretraining_tp, dim=0 ) logits = [ F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp) ] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) # logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif ( input_ids.shape[1] != cache_position.shape[0] ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = { "input_ids": input_ids.contiguous() } # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, LLAMA_START_DOCSTRING, ) class LlamaForSequenceClassification(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = LlamaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError( "Cannot handle batch sizes > 1 if no padding token is defined." ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = ( torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 ) sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[ torch.arange(batch_size, device=logits.device), sequence_lengths ] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct( pooled_logits.view(-1, self.num_labels), labels.view(-1) ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, LLAMA_START_DOCSTRING, ) class LlamaForQuestionAnswering(LlamaPreTrainedModel): base_model_prefix = "transformer" # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama def __init__(self, config): super().__init__(config) self.transformer = LlamaModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.transformer.embed_tokens def set_input_embeddings(self, value): self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() total_loss = None if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: start_positions = start_positions.squeeze(-1).to(start_logits.device) if len(end_positions.size()) > 1: end_positions = end_positions.squeeze(-1).to(end_logits.device) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions = start_positions.clamp(0, ignored_index) end_positions = end_positions.clamp(0, ignored_index) loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss,) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( """ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, LLAMA_START_DOCSTRING, ) class LlamaForTokenClassification(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = LlamaModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ================================================ FILE: archive/ktransformers/models/modeling_mixtral.py ================================================ # coding=utf-8 ''' Description : Author : kkk1nak0 Date : 2024-07-29 02:58:57 Version : 1.0.0 LastEditors : kkk1nak0 LastEditTime : 2024-08-02 06:08:34 ''' # Adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Mixtral model.""" import inspect import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, logging, replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available from transformers.models.mixtral.configuration_mixtral import MixtralConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_varlen_func, flash_attn_func, flash_attn_with_kvcache from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): if not is_torch_greater_or_equal_than_1_13: import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MixtralConfig" def load_balancing_loss_func( gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None ) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. attention_mask (`torch.Tensor`, None): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. num_experts (`int`, *optional*): Number of experts Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral class MixtralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ MixtralRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb # TODO @longjie no longer copied from Mistral after static cache def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers". """ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = MixtralRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralFlashAttention2(MixtralAttention): """ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) use_sliding_windows = ( _flash_supports_window_size and getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and self.config.use_sliding_window ) if not _flash_supports_window_size: logger.warning_once( "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" " make sure to upgrade flash-attn library." ) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and cache_has_contents ): slicing_tokens = 1 - self.config.sliding_window past_key = past_key_value[self.layer_idx][0] past_value = past_key_value[self.layer_idx][1] past_key = past_key[:, :, slicing_tokens:, :].contiguous() past_value = past_value[:, :, slicing_tokens:, :].contiguous() if past_key.shape[-2] != self.config.sliding_window - 1: raise ValueError( f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" f" {past_key.shape}" ) if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails # for bsz == 1, avoid using slice to capture cuda graph if cache_position is not None and q_len > 1: key_states = key_states[:, :, : cache_position[-1] + 1, :] value_states = value_states[:, :, : cache_position[-1] + 1, :] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, sliding_window=getattr(self.config, "sliding_window", None), is_causal=self.is_causal, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, q_len, position_ids, dropout, sliding_window, is_causal, softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout """ # Decide whether to use SWA or not by layer index. # if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: # use_sliding_windows = False use_sliding_windows = False # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, attention_mask, q_len ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens if not use_sliding_windows: attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, ) else: attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) else: if not use_sliding_windows: if q_len == 1: position_ids = position_ids.to(dtype=torch.int32).squeeze(1) attn_output = flash_attn_with_kvcache( query_states, key_states, value_states, cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=is_causal, ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=is_causal, ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=is_causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) return attn_output # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape # On the first iteration we need to properly re-create the padding mask # by slicing it on the proper place if kv_seq_len != attention_mask.shape[-1]: attention_mask_num_tokens = attention_mask.shape[-1] attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): """ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from MixtralAttention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value MIXTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, "flash_attention_2": MixtralFlashAttention2, "sdpa": MixtralSdpaAttention, } class MixtralBlockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states class MixtralSparseMoeBlock(nn.Module): """ This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accomodate imbalanced assignments of tokens to experts, whereas standard MoE either (1) drop tokens at the cost of reduced performance or (2) set capacity factor to number of experts and thus waste computation and memory on padding. """ def __init__(self, config): super().__init__() self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) # Jitter parameters self.jitter_noise = config.router_jitter_noise def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class MixtralDecoderLayer(nn.Module): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, router_logits = self.block_sparse_moe(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) if output_router_logits: outputs += (router_logits,) return outputs MIXTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`MixtralConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral class MixtralPreTrainedModel(PreTrainedModel): config_class = MixtralConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() MIXTRAL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] Args: config: MixtralConfig """ def __init__(self, config: MixtralConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self._attn_implementation = config._attn_implementation self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value # Ignore copy @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False use_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class MixtralForCausalLM(MixtralPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = MixtralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) # Ignore copy def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, MixtralForCausalLM >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] if output_router_logits: output = (aux_loss,) + output return (loss,) + output if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, output_router_logits=False, position_ids=None, use_cache=True, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, } ) return model_inputs @add_start_docstrings( """ The Mixtral Model transformer with a sequence classification head on top (linear layer). [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, MIXTRAL_START_DOCSTRING, ) # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL class MixtralForSequenceClassification(MixtralPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = MixtralModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, MIXTRAL_START_DOCSTRING, ) # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL class MixtralForTokenClassification(MixtralPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = MixtralModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ================================================ FILE: archive/ktransformers/models/modeling_qwen2_moe.py ================================================ # coding=utf-8 ''' Description : Author : Boxin Zhang Version : 0.1.0 ''' # Adapted from # https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # Copyright (c) 2024 by KVCache.AI, All Rights Reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Qwen2MoE model.""" import inspect import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, ) from transformers.modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" _CONFIG_FOR_DOC = "Qwen2MoeConfig" # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func def load_balancing_loss_func( gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None ) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. attention_mask (`torch.Tensor`, None): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. num_experts (`int`, *optional*): Number of experts Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe class Qwen2MoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2MoeRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe class Qwen2MoeMLP(nn.Module): def __init__(self, config, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe class Qwen2MoeAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers". """ def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2MoeRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe class Qwen2MoeFlashAttention2(Qwen2MoeAttention): """ Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom config.max_window_layers layers. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_sliding_windows = ( _flash_supports_window_size and getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and self.config.use_sliding_window ) if not _flash_supports_window_size: logger.warning_once( "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" " make sure to upgrade flash-attn library." ) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and cache_has_contents ): slicing_tokens = 1 - self.config.sliding_window past_key = past_key_value[self.layer_idx][0] past_value = past_key_value[self.layer_idx][1] past_key = past_key[:, :, slicing_tokens:, :].contiguous() past_value = past_value[:, :, slicing_tokens:, :].contiguous() if past_key.shape[-2] != self.config.sliding_window - 1: raise ValueError( f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" f" {past_key.shape}" ) if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails # for bsz == 1, avoid using slice to capture cuda graph if cache_position is not None and q_len > 1: key_states = key_states[:, :, : cache_position[-1] + 1, :] value_states = value_states[:, :, : cache_position[-1] + 1, :] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, use_sliding_windows=use_sliding_windows, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, position_ids, dropout=0.0, softmax_scale=None, use_sliding_windows=False, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`float`): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) use_sliding_windows (`bool`, *optional*): Whether to activate sliding window attention. """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Decide whether to use SWA or not by layer index. if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: use_sliding_windows = False # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens if not use_sliding_windows: attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) else: attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: if not use_sliding_windows: if query_length == 1: position_ids = position_ids.to(dtype=torch.int32).squeeze(1) attn_output = flash_attn_with_kvcache( query_states, key_states, value_states, cache_seqlens=position_ids, softmax_scale=softmax_scale, causal=causal, ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) return attn_output # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape # On the first iteration we need to properly re-create the padding mask # by slicing it on the proper place if kv_seq_len != attention_mask.shape[-1]: attention_mask_num_tokens = attention_mask.shape[-1] attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `Qwen2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from Qwen2MoeAttention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value QWEN2MOE_ATTENTION_CLASSES = { "eager": Qwen2MoeAttention, "flash_attention_2": Qwen2MoeFlashAttention2, "sdpa": Qwen2MoeSdpaAttention, } class Qwen2MoeSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob # gating self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = nn.ModuleList( [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] ) self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class Qwen2MoeDecoderLayer(nn.Module): def __init__(self, config: Qwen2MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen2MoeSparseMoeBlock(config) else: self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size) self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if isinstance(hidden_states, tuple): hidden_states, router_logits = hidden_states else: router_logits = None hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) if output_router_logits: outputs += (router_logits,) return outputs QWEN2MOE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen2MoeConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", QWEN2MOE_START_DOCSTRING, ) class Qwen2MoePreTrainedModel(PreTrainedModel): config_class = Qwen2MoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() QWEN2MOE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", QWEN2MOE_START_DOCSTRING, ) class Qwen2MoeModel(Qwen2MoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] Args: config: Qwen2MoeConfig """ def __init__(self, config: Qwen2MoeConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self._attn_implementation = config._attn_implementation self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False use_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits and layer_outputs[-1] is not None: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = Qwen2MoeModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] if output_router_logits: output = (aux_loss,) + output return (loss,) + output if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, use_cache=True, **kwargs, ): past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: if isinstance(past_key_values, Cache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None else None ) cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) elif use_cache: cache_position = cache_position[-input_length:] model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "cache_position": cache_position, } ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), ) return reordered_past @add_start_docstrings( """ The Qwen2MoE Model transformer with a sequence classification head on top (linear layer). [`Qwen2MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, QWEN2MOE_START_DOCSTRING, ) # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = Qwen2MoeModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, QWEN2MOE_START_DOCSTRING, ) # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = Qwen2MoeModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ================================================ FILE: archive/ktransformers/models/modeling_qwen3_moe.py ================================================ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_qwen3_moe.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter # from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS # from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.modeling_utils import PreTrainedModel # from transformers.processing_utils import Unpack from transformers.utils import ( # LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg from .configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRotaryEmbedding logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B" _CONFIG_FOR_DOC = "Qwen3MoeConfig" def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Qwen3MoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Qwen3MoeConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.rotary_emb = Qwen2MoeRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) self.sliding_window = config.sliding_window if not ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): self.sliding_window = None def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, # **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward # if self.config._attn_implementation != "eager": # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): # logger.warning_once( # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " # 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' # ) # else: # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, # diff with Llama # **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Qwen3MoeMLP(nn.Module): def __init__(self, config, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class Qwen3MoeSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob # gating self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = nn.ModuleList( [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: # only diff with mixtral sparse moe block! routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class Qwen3MoeRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen3MoeRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen3MoeDecoderLayer(nn.Module): def __init__(self, config: Qwen3MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Qwen3MoeAttention(config, layer_idx) self.mlp = Qwen3MoeMLP(config) self.self_attn = Qwen3MoeAttention(config, layer_idx) if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock(config) else: self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size) self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC # **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) if isinstance(hidden_states, tuple): hidden_states, router_logits = hidden_states else: router_logits = None hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if output_router_logits: outputs += (router_logits,) return outputs def _compute_default_rope_parameters( config: Optional[Qwen3MoeConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ if config is not None and len(rope_kwargs) > 0: raise ValueError( "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" ) if len(rope_kwargs) > 0: base = rope_kwargs["base"] dim = rope_kwargs["dim"] elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 dim = int(config.head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) return inv_freq, attention_factor class Qwen3MoeRotaryEmbedding(nn.Module): def __init__(self, config: Qwen3MoeConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.scaling_factor = 1.0 self.dim = config.head_dim self.max_position_embeddings = config.max_position_embeddings self.base = config.rope_theta inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = _compute_default_rope_parameters(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) QWEN3_MOE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen3MoeConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", QWEN3_MOE_START_DOCSTRING, ) class Qwen3MoePreTrainedModel(PreTrainedModel): config_class = Qwen3MoeConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen3MoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() QWEN3_MOE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", QWEN3_MOE_START_DOCSTRING, ) class Qwen3MoeModel(Qwen3MoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3MoeDecoderLayer`] Args: config: Qwen3MoeConfig """ def __init__(self, config: Qwen3MoeConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, # **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, # **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) output = MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) return output if return_dict else output.to_tuple() def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, past_key_values=past_key_values, ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, config: Qwen3MoeConfig, past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. config (`Qwen3MoeConfig`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: sliding_attend_mask = torch.arange(target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... class KwargsForCausalLM(): ... def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, int]: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits: Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. num_experts: Number of experts top_k: The number of experts to route per-token, can be also interpreted as the `top-k` routing parameter. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = Qwen3MoeModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, # **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, # **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] if output_router_logits: output = (aux_loss,) + output return (loss,) + output if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) @add_start_docstrings( """ The Qwen3Moe Model transformer with a sequence classification head on top (linear layer). [`Qwen3MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, QWEN3_MOE_START_DOCSTRING, ) class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = Qwen3MoeModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: last_non_pad_token = -1 elif input_ids is not None: # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) else: last_non_pad_token = -1 logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ The Qwen3Moe Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, QWEN3_MOE_START_DOCSTRING, ) class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = Qwen3MoeModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.config) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @add_start_docstrings( """ The Qwen3Moe Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, QWEN3_MOE_START_DOCSTRING, ) class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): base_model_prefix = "transformer" def __init__(self, config): super().__init__(config) self.transformer = Qwen3MoeModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.transformer.embed_tokens def set_input_embeddings(self, value): self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1).contiguous() end_logits = end_logits.squeeze(-1).contiguous() loss = None if start_positions is not None and end_positions is not None: loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((loss,) + output) if loss is not None else output return QuestionAnsweringModelOutput( loss=loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = [ "Qwen3MoeForCausalLM", "Qwen3MoeForQuestionAnswering", "Qwen3MoeModel", "Qwen3MoePreTrainedModel", "Qwen3MoeForSequenceClassification", "Qwen3MoeForTokenClassification", ] ================================================ FILE: archive/ktransformers/models/modeling_qwen3_next.py ================================================ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/qwen3_next/modular_qwen3_next.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_qwen3_next.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache from transformers.generation import GenerationMixin from transformers.masking_utils import create_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import ( GenericForQuestionAnswering, GenericForSequenceClassification, GenericForTokenClassification, GradientCheckpointingLayer, ) from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import OutputRecorder, check_model_inputs try: from transformers.utils.import_utils import ( is_causal_conv1d_available, is_flash_linear_attention_available, ) except ImportError: is_causal_conv1d_available = lambda: False try: from transformers.utils.import_utils import ( is_flash_linear_attention_available, ) except ImportError: is_flash_linear_attention_available = lambda: False from .configuration_qwen3_next import Qwen3NextConfig if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None if is_flash_linear_attention_available(): from fla.modules import FusedRMSNormGated from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule else: chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None FusedRMSNormGated = None logger = logging.get_logger(__name__) class Qwen3NextRMSNormGated(nn.Module): def __init__(self, hidden_size, eps=1e-6, **kwargs): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states, gate=None): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) # Norm before gate hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) hidden_states = self.weight * hidden_states.to(input_dtype) hidden_states = hidden_states * F.silu(gate.to(torch.float32)) return hidden_states.to(input_dtype) class Qwen3NextDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention cache (which has a constant shape regardless of seq_len). This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. """ is_compileable = False def __init__(self, config: Qwen3NextConfig): super().__init__() self.layer_types = config.layer_types self.transformer_layers = [ i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" ] self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference self.conv_states = [None for _ in range(config.num_hidden_layers)] self.recurrent_states = [None for _ in range(config.num_hidden_layers)] self.key_cache = [None for _ in range(config.num_hidden_layers)] self.value_cache = [None for _ in range(config.num_hidden_layers)] def __len__(self): return len(self.layer_types) def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: if self.key_cache[layer_idx] is None: self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) return self.key_cache[layer_idx], self.value_cache[layer_idx] def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx] is not None: device = self.key_cache[layer_idx].device beam_idx = beam_idx.to(device) self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) if self.conv_states[layer_idx] is not None: device = self.conv_states[layer_idx].device beam_idx = beam_idx.to(device) self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # take any layer that contains cache and not empty tensor layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: return 0 return self.key_cache[layer_idx].shape[-2] def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: """ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for the given layer at `layer_idx`. The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. """ kv_offset = 0 query_length = cache_position.shape[0] past_seen_tokens = self.get_seq_length(layer_idx) kv_length = query_length + past_seen_tokens return kv_length, kv_offset @property def has_previous_state(self): """We have a previous state if the last linear (conv) layer was already updated.""" return self.conv_states[self.last_linear_layer] is not None class Qwen3NextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: Qwen3NextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class Qwen3NextRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.hidden_size = dim self.variance_epsilon = eps self.eps = eps self.weight = nn.Parameter(torch.zeros(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()) # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 output = output * (1.0 + self.weight.float()) return output.type_as(x) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Removes the interleaving of cos and sin from GLM Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] # Apply rotary embeddings on the first half or full tensor q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Qwen3NextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Qwen3NextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! self.k_norm = Qwen3NextRMSNorm( self.head_dim, eps=config.rms_norm_eps ) # thus post q_norm does not need reshape @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states, gate = torch.chunk( self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 ) gate = gate.reshape(*input_shape, -1) query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = attn_output * torch.sigmoid(gate) attn_output = self.o_proj(attn_output) return attn_output, attn_weights def apply_mask_to_padding_states(hidden_states, attention_mask): """ Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 """ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) return hidden_states is_fast_path_available = all( (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) ) def torch_causal_conv1d_update( hidden_states, conv_state, weight, bias=None, activation=None, ): _, hidden_size, seq_len = hidden_states.shape state_len = conv_state.shape[-1] hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) conv_state.copy_(hidden_states_new[:, :, -state_len:]) out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) out = F.silu(out[:, :, -seq_len:]) out = out.to(hidden_states.dtype) return out def torch_chunk_gated_delta_rule( query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=False, ): initial_dtype = query.dtype if use_qk_l2norm_in_kernel: query = F.normalize(query, p=2, dim=-1) key = F.normalize(key, p=2, dim=-1) query, key, value, beta, g = [ x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] batch_size, sequence_length, num_heads, k_head_dim = key.shape v_head_dim = value.shape[-1] pad_size = (chunk_size - num_heads % chunk_size) % chunk_size query = F.pad(query, (0, 0, 0, pad_size)) key = F.pad(key, (0, 0, 0, pad_size)) value = F.pad(value, (0, 0, 0, pad_size)) beta = F.pad(beta, (0, pad_size)) g = F.pad(g, (0, pad_size)) tot_heads = num_heads + pad_size scale = 1 / (query.shape[-1] ** 0.5) query = query * scale v_beta = value * beta.unsqueeze(-1) k_beta = key * beta.unsqueeze(-1) # reshape to chunks query, key, value, k_beta, v_beta = [ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) ] g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) # chunk decay g = g.cumsum(dim=-1) decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) for i in range(1, chunk_size): row = attn[..., i, :i].clone() sub = attn[..., :i, :i].clone() attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) value = attn @ v_beta k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) last_recurrent_state = ( torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value) if initial_state is None else initial_state.to(value) ) core_attn_out = torch.zeros_like(value) mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) # for each chunk for i in range(0, tot_heads // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state core_attn_out[:, :, i] = attn_inter + attn @ v_new last_recurrent_state = ( last_recurrent_state * g[:, :, i, -1, None, None].exp() + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new ) if not output_final_state: last_recurrent_state = None core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) core_attn_out = core_attn_out[:, :, :num_heads] core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) return core_attn_out, last_recurrent_state def torch_recurrent_gated_delta_rule( query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False ): initial_dtype = query.dtype if use_qk_l2norm_in_kernel: query = F.normalize(query, p=2, dim=-1) key = F.normalize(key, p=2, dim=-1) query, key, value, beta, g = [ x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] batch_size, sequence_length, num_heads, k_head_dim = key.shape v_head_dim = value.shape[-1] scale = 1 / (query.shape[-1] ** 0.5) query = query * scale core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value) last_recurrent_state = ( torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value) if initial_state is None else initial_state.to(value) ) for i in range(num_heads): q_t = query[:, :, i] k_t = key[:, :, i] v_t = value[:, :, i] g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) beta_t = beta[:, :, i].unsqueeze(-1) last_recurrent_state = last_recurrent_state * g_t kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) delta = (v_t - kv_mem) * beta_t last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) if not output_final_state: last_recurrent_state = None core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) return core_attn_out, last_recurrent_state class Qwen3NextGatedDeltaNet(nn.Module): def __init__(self, config: Qwen3NextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.num_v_heads = config.linear_num_value_heads self.num_k_heads = config.linear_num_key_heads self.head_k_dim = config.linear_key_head_dim self.head_v_dim = config.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = layer_idx self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps self.config = config # QKV self.conv_dim = self.key_dim * 2 + self.value_dim self.conv1d = nn.Conv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_kernel_size, groups=self.conv_dim, padding=self.conv_kernel_size - 1, ) # projection of the input hidden states projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 projection_size_ba = self.num_v_heads * 2 self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False) self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False) # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) A = torch.empty(self.num_v_heads).uniform_(0, 16) self.A_log = nn.Parameter(torch.log(A)) self.norm = ( Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) if FusedRMSNormGated is None else FusedRMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, activation=self.activation, device=torch.cuda.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(), ) ) self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) self.causal_conv1d_fn = causal_conv1d_fn self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule if not is_fast_path_available: logger.warning_once( "The fast path is not available because one of the required library is not installed. Falling back to " "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" " https://github.com/Dao-AILab/causal-conv1d" ) def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads, 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, ) new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) mixed_ba = mixed_ba.view(*new_tensor_shape_ba) split_arg_list_qkvz = [ self.head_k_dim, self.head_k_dim, (self.num_v_heads // self.num_k_heads * self.head_v_dim), (self.num_v_heads // self.num_k_heads * self.head_v_dim), ] split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim) b = b.reshape(b.size(0), b.size(1), self.num_v_heads) a = a.reshape(a.size(0), a.size(1), self.num_v_heads) return query, key, value, z, b, a def forward( self, hidden_states: torch.Tensor, cache_params: Optional[Qwen3NextDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = ( cache_params is not None and cache_params.has_previous_state and seq_len == 1 and cache_position is not None ) # getting projected states from cache if it exists if cache_params is not None: conv_state = cache_params.conv_states[self.layer_idx] recurrent_state = cache_params.recurrent_states[self.layer_idx] projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) if use_precomputed_states: # 2. Convolution sequence transformation # NOTE: the conv state is updated in `causal_conv1d_update` mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) else: if cache_params is not None: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) cache_params.conv_states[self.layer_idx] = conv_state if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, seq_idx=None, ) else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, [ self.key_dim, self.key_dim, self.value_dim, ], dim=-1, ) query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) beta = b.sigmoid() # If the model is loaded in fp16, without the .float() here, A might be -inf g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) if self.num_v_heads // self.num_k_heads > 1: query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) if not use_precomputed_states: core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, initial_state=None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) # Update cache if cache_params is not None: cache_params.recurrent_states[self.layer_idx] = last_recurrent_state z_shape_og = z.shape # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) output = self.out_proj(core_attn_out) return output class Qwen3NextMLP(nn.Module): def __init__(self, config, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob # gating self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = nn.ModuleList( [Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] ) self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class Qwen3NextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3NextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size # token mixer self.layer_type = config.layer_types[layer_idx] if self.layer_type == "linear_attention": self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention(config, layer_idx) if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3NextSparseMoeBlock(config) else: self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size) self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> torch.FloatTensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Token Mixer if self.layer_type == "linear_attention": hidden_states = self.linear_attn( hidden_states=hidden_states, cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, ) elif self.layer_type == "full_attention": # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) # For the MoE layers, we need to unpack if isinstance(hidden_states, tuple): hidden_states, _ = hidden_states hidden_states = residual + hidden_states return hidden_states class Qwen3NextPreTrainedModel(PreTrainedModel): config: Qwen3NextConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen3NextDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] _can_record_outputs = { "router_logits": OutputRecorder(Qwen3NextSparseMoeBlock, index=1), "hidden_states": Qwen3NextDecoderLayer, "attentions": Qwen3NextAttention, } _is_stateful = True def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Qwen3NextGatedDeltaNet): module.dt_bias.data.fill_(1.0) module.A_log.data.uniform_(0, 16).log_() class Qwen3NextModel(Qwen3NextPreTrainedModel): def __init__(self, config: Qwen3NextConfig): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.layers = nn.ModuleList( [Qwen3NextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3NextRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @check_model_inputs @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> MoeModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = Qwen3NextDynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, position_embeddings=position_embeddings, attention_mask=layer_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = self.norm(hidden_states) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) def _update_linear_attn_mask(self, attention_mask, cache_position): """ NOTE: Left-padding is used for linear attention mask. No need for zeroing states when 1. Cached forward 2. Attending to all inputs """ linear_attn_mask = attention_mask if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): linear_attn_mask = None return linear_attn_mask def load_balancing_loss_func( gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, int]: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits: Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. num_experts: Number of experts top_k: The number of experts to route per-token, can be also interpreted as the `top-k` routing parameter. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts @auto_docstring class Qwen3NextForCausalLM(Qwen3NextPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = Qwen3NextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Qwen3NextDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import AutoTokenizer, Qwen3NextForCausalLM >>> model = Qwen3NextForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_router_logits=output_router_logits, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) class Qwen3NextForSequenceClassification(GenericForSequenceClassification, Qwen3NextPreTrainedModel): pass class Qwen3NextForTokenClassification(GenericForTokenClassification, Qwen3NextPreTrainedModel): pass class Qwen3NextForQuestionAnswering(GenericForQuestionAnswering, Qwen3NextPreTrainedModel): base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` __all__ = [ "Qwen3NextForCausalLM", "Qwen3NextForQuestionAnswering", "Qwen3NextModel", "Qwen3NextPreTrainedModel", "Qwen3NextForSequenceClassification", "Qwen3NextForTokenClassification", ] ================================================ FILE: archive/ktransformers/models/modeling_smallthinker.py ================================================ # coding=utf-8 from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import can_return_tuple, is_torch_flex_attn_available, logging from .configuration_smallthinker import SmallthinkerConfig if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask from transformers.integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) class SmallthinkerHierarchicalMLP(nn.Module): def __init__(self, config: SmallthinkerConfig): super().__init__() self.config = config self.hidden_dim = config.hidden_size self.ffn_dim = config.moe_ffn_hidden_size self.moe_enable_secondary_experts = config.moe_enable_secondary_experts if self.moe_enable_secondary_experts: self.num_secondary_experts = config.moe_num_secondary_experts self.secondary_expert_size = config.moe_secondary_expert_size self.secondary_gate = nn.Linear(self.hidden_dim, self.num_secondary_experts, bias=False) self.up = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.down = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) def forward(self, secondary_gate_input: torch.Tensor, hidden_states: torch.Tensor): if self.moe_enable_secondary_experts: secondary_gate_logits = F.sigmoid(self.secondary_gate(secondary_gate_input)) > 0.5 secondary_gate_mask = secondary_gate_logits.unsqueeze(-1) current_hidden_states = self.up(hidden_states) * F.relu(self.gate(hidden_states)) activated_output = current_hidden_states batch_size, intermediate_size = activated_output.shape if self.moe_enable_secondary_experts: num_groups = intermediate_size // self.secondary_expert_size activated_output = activated_output.view(batch_size, num_groups, self.secondary_expert_size) output = activated_output * secondary_gate_mask else: output = activated_output current_hidden_states = output.view(batch_size, -1) current_hidden_states = self.down(current_hidden_states) return current_hidden_states class SmallthinkerMoeBlock(nn.Module): def __init__(self, config: SmallthinkerConfig): super().__init__() self.hidden_dim = config.hidden_size self.num_primary_experts = config.moe_num_primary_experts self.enable_early_router = config.moe_enable_early_router self.moe_primary_router_apply_softmax = config.moe_primary_router_apply_softmax self.num_active_primary_experts = config.moe_num_active_primary_experts self.primary_router = nn.Linear(self.hidden_dim, self.num_primary_experts, bias=False) self.experts = nn.ModuleList([SmallthinkerHierarchicalMLP(config) for _ in range(self.num_primary_experts)]) def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape # Flatten the tokens into (bs * sl, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim) router_input = router_input.view(-1, hidden_dim) # Primary router logits: (bs * sl, n_experts) if self.enable_early_router: router_logits = self.primary_router(router_input) else: router_logits = self.primary_router(hidden_states) router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1) if self.moe_primary_router_apply_softmax: routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) else: routing_weights = F.sigmoid(router_logits) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) # Prepare the final tensor final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_primary_experts).permute(2, 1, 0) expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist() for expert_idx in expert_hitted: expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) # current_router_input = router_input[None, top_x].reshape(-1, hidden_dim) current_state = hidden_states[top_x].reshape(-1, hidden_dim) current_router_input = router_input[top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_router_input, current_state) * routing_weights[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class SmallthinkerDenseMlpBlock(nn.Module): def __init__(self, config: SmallthinkerConfig): super().__init__() hidden_dim = config.hidden_size ffn_dim = config.dense_ffn_hidden_size self.up = nn.Linear(hidden_dim, ffn_dim, bias=False) self.gate = nn.Linear(hidden_dim, ffn_dim, bias=False) self.down = nn.Linear(ffn_dim, hidden_dim, bias=False) # Offer unified interface for SmallthinkerMoeBlock and SmallthinkerDenseMlpBlock, though router_input is not used here def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = self.up(hidden_states) * F.relu(self.gate(hidden_states)) current_hidden_states = self.down(current_hidden_states) return current_hidden_states, None class SmallthinkerRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ SmallthinkerRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.hidden_size = hidden_size self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class SmallthinkerAttention(nn.Module): def __init__(self, config: SmallthinkerConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx # For KVCache management self.head_dim = config.head_dim self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.sliding_window = config.sliding_window_size if config.sliding_window_layout[layer_idx] else None self.use_qk_norm = config.use_qk_norm def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if self.use_qk_norm: raise NotImplementedError("use_qk_norm is not implemented yet") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if position_embeddings: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) else: cos, sin = None, None if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation == "sdpa": raise NotImplementedError("SDPA impl is buggy for now. NEVER TRY TO USE IT.") if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0, scaling=self.scaling, sliding_window=self.sliding_window, # main diff with Llama **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class SmallthinkerDecoderLayer(nn.Module): def __init__(self, config: SmallthinkerConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = SmallthinkerAttention(config, layer_idx) self.block_sparse_moe = SmallthinkerMoeBlock(config) if config.moe_layer_layout[layer_idx] else SmallthinkerDenseMlpBlock(config) self.input_layernorm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ # print(f"hidden states, shape {hidden_states.shape}: {hidden_states}") # debug print residual = hidden_states router_input = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, router_logits = self.block_sparse_moe(router_input, hidden_states) hidden_states = residual + hidden_states # SYNC after_moe_residual_value=hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if output_router_logits: outputs += (router_logits,) return outputs class SmallthinkerRotaryEmbedding(nn.Module): def __init__(self, config: SmallthinkerConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class SmallthinkerPreTrainedModel(PreTrainedModel): config_class = SmallthinkerConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["SmallthinkerDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, SmallthinkerRMSNorm): module.weight.data.fill_(1.0) # @auto_docstring class SmallthinkerModel(SmallthinkerPreTrainedModel): def __init__(self, config: SmallthinkerConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [SmallthinkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SmallthinkerRotaryEmbedding(config=config) self.gradient_checkpointing = False self.rope_layout = config.rope_layout # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @can_return_tuple # @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> MoeModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) # print("atten mask:", attention_mask) # debug print causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # print("causal mask:", causal_mask) # debug print hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, position_embeddings if self.rope_layout[layer_idx] else None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings if self.rope_layout[layer_idx] else None, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) def _update_causal_mask( self, attention_mask: Union[torch.Tensor, "BlockMask"], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Smallthinker. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype = input_tensor.dtype min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, past_key_values=past_key_values, ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, cache_position: torch.Tensor, batch_size: int, config: SmallthinkerConfig, past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. config (`SmallthinkerConfig`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device ) diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( -1, 1 ) if config.get_text_config().sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( cache_position.reshape(-1, 1) - config.get_text_config().sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class KwargsForCausalLM(FlashAttentionKwargs): ... def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, int]: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits: Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. num_experts: Number of experts top_k: The number of experts to route per-token, can be also interpreted as the `top-k` routing parameter. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts # @auto_docstring class SmallThinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = SmallthinkerModel(config) self.vocab_size = config.vocab_size # Handle tie / untie word embeddings self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # self.num_experts = config.num_local_experts # self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple # @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> MoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Example: ```python >>> from transformers import AutoTokenizer, SmallThinkerForCausalLM >>> model = SmallThinkerForCausalLM.from_pretrained("mistralai/Smallthinker-8x7B-v0.1") >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Smallthinker-8x7B-v0.1") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: MoeModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) # No such functions for now # #@auto_docstring( # custom_intro=""" # The Smallthinker Model transformer with a sequence classification head on top (linear layer). # [`SmallthinkerForSequenceClassification`] uses the last token in order to do the classification, as other causal models # (e.g. GPT-2) do. # Since it does classification on the last token, it requires to know the position of the last token. If a # `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If # no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the # padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in # each row of the batch). # """ # ) # class SmallthinkerForSequenceClassification(SmallthinkerPreTrainedModel): # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # self.model = SmallthinkerModel(config) # self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # # Initialize weights and apply final processing # self.post_init() # def get_input_embeddings(self): # return self.model.embed_tokens # def set_input_embeddings(self, value): # self.model.embed_tokens = value # @can_return_tuple # #@auto_docstring # def forward( # self, # input_ids: Optional[torch.LongTensor] = None, # attention_mask: Optional[torch.Tensor] = None, # position_ids: Optional[torch.LongTensor] = None, # past_key_values: Optional[Cache] = None, # inputs_embeds: Optional[torch.FloatTensor] = None, # labels: Optional[torch.LongTensor] = None, # use_cache: Optional[bool] = None, # output_attentions: Optional[bool] = None, # output_hidden_states: Optional[bool] = None, # ) -> SequenceClassifierOutputWithPast: # r""" # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., # config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). # """ # transformer_outputs: BaseModelOutputWithPast = self.model( # input_ids, # attention_mask=attention_mask, # position_ids=position_ids, # past_key_values=past_key_values, # inputs_embeds=inputs_embeds, # use_cache=use_cache, # output_attentions=output_attentions, # output_hidden_states=output_hidden_states, # ) # hidden_states = transformer_outputs.last_hidden_state # logits = self.score(hidden_states) # if input_ids is not None: # batch_size = input_ids.shape[0] # else: # batch_size = inputs_embeds.shape[0] # if self.config.pad_token_id is None and batch_size != 1: # raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") # if self.config.pad_token_id is None: # last_non_pad_token = -1 # elif input_ids is not None: # # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id # non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) # token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) # last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) # else: # last_non_pad_token = -1 # logger.warning_once( # f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " # "unexpected if using padding tokens in conjunction with `inputs_embeds.`" # ) # pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] # loss = None # if labels is not None: # loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) # return SequenceClassifierOutputWithPast( # loss=loss, # logits=pooled_logits, # past_key_values=transformer_outputs.past_key_values, # hidden_states=transformer_outputs.hidden_states, # attentions=transformer_outputs.attentions, # ) # #@auto_docstring # class SmallthinkerForTokenClassification(SmallthinkerPreTrainedModel): # def __init__(self, config): # super().__init__(config) # self.num_labels = config.num_labels # self.model = SmallthinkerModel(config) # if getattr(config, "classifier_dropout", None) is not None: # classifier_dropout = config.classifier_dropout # elif getattr(config, "hidden_dropout", None) is not None: # classifier_dropout = config.hidden_dropout # else: # classifier_dropout = 0.1 # self.dropout = nn.Dropout(classifier_dropout) # self.score = nn.Linear(config.hidden_size, config.num_labels) # # Initialize weights and apply final processing # self.post_init() # def get_input_embeddings(self): # return self.model.embed_tokens # def set_input_embeddings(self, value): # self.model.embed_tokens = value # @can_return_tuple # #@auto_docstring # def forward( # self, # input_ids: Optional[torch.LongTensor] = None, # attention_mask: Optional[torch.Tensor] = None, # position_ids: Optional[torch.LongTensor] = None, # past_key_values: Optional[Cache] = None, # inputs_embeds: Optional[torch.FloatTensor] = None, # labels: Optional[torch.LongTensor] = None, # use_cache: Optional[bool] = None, # output_attentions: Optional[bool] = None, # output_hidden_states: Optional[bool] = None, # ) -> TokenClassifierOutput: # r""" # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., # config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). # """ # outputs: BaseModelOutputWithPast = self.model( # input_ids, # attention_mask=attention_mask, # position_ids=position_ids, # past_key_values=past_key_values, # inputs_embeds=inputs_embeds, # use_cache=use_cache, # output_attentions=output_attentions, # output_hidden_states=output_hidden_states, # ) # sequence_output = outputs.last_hidden_state # sequence_output = self.dropout(sequence_output) # logits = self.score(sequence_output) # loss = None # if labels is not None: # loss = self.loss_function(logits, labels, self.config) # return TokenClassifierOutput( # loss=loss, # logits=logits, # hidden_states=outputs.hidden_states, # attentions=outputs.attentions, # ) # #@auto_docstring # class SmallthinkerForQuestionAnswering(SmallthinkerPreTrainedModel): # base_model_prefix = "model" # def __init__(self, config): # super().__init__(config) # self.qa_outputs = nn.Linear(config.hidden_size, 2) # self.model = SmallthinkerModel(config) # diff with Llama: transformer->model # # Initialize weights and apply final processing # self.post_init() # def get_input_embeddings(self): # return self.model.embed_tokens # def set_input_embeddings(self, value): # self.model.embed_tokens = value # @can_return_tuple # #@auto_docstring # def forward( # self, # input_ids: Optional[torch.LongTensor] = None, # attention_mask: Optional[torch.Tensor] = None, # position_ids: Optional[torch.LongTensor] = None, # past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, # inputs_embeds: Optional[torch.FloatTensor] = None, # start_positions: Optional[torch.LongTensor] = None, # end_positions: Optional[torch.LongTensor] = None, # output_attentions: Optional[bool] = None, # output_hidden_states: Optional[bool] = None, # **kwargs, # ) -> QuestionAnsweringModelOutput: # r""" # start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): # Labels for position (index) of the start of the labelled span for computing the token classification loss. # Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence # are not taken into account for computing the loss. # end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): # Labels for position (index) of the end of the labelled span for computing the token classification loss. # Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence # are not taken into account for computing the loss. # """ # outputs: BaseModelOutputWithPast = self.model( # input_ids, # attention_mask=attention_mask, # position_ids=position_ids, # past_key_values=past_key_values, # inputs_embeds=inputs_embeds, # output_attentions=output_attentions, # output_hidden_states=output_hidden_states, # ) # sequence_output = outputs.last_hidden_state # logits = self.qa_outputs(sequence_output) # start_logits, end_logits = logits.split(1, dim=-1) # start_logits = start_logits.squeeze(-1).contiguous() # end_logits = end_logits.squeeze(-1).contiguous() # loss = None # if start_positions is not None and end_positions is not None: # loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) # return QuestionAnsweringModelOutput( # loss=loss, # start_logits=start_logits, # end_logits=end_logits, # hidden_states=outputs.hidden_states, # attentions=outputs.attentions, # ) __all__ = [ "SmallThinkerForCausalLM", "SmallthinkerForQuestionAnswering", "SmallthinkerModel", "SmallthinkerPreTrainedModel", "SmallthinkerForSequenceClassification", "SmallthinkerForTokenClassification", ] if __name__ == "__main__": from transformers import AutoTokenizer, AutoModelForCausalLM test_config = SmallthinkerConfig() tokenizer = AutoTokenizer.from_pretrained("./qwen-tokenizer") text = "Once upon a day" tokens = tokenizer.encode_plus( text,add_special_tokens=True,return_tensors='pt') # print(tokens) test_model = AutoModelForCausalLM.from_pretrained(".").cuda() output = test_model.generate(tokens) otokens = tokenizer.decode(output[0]) # print(otokens) ================================================ FILE: archive/ktransformers/operators/RoPE.py ================================================ """ Description : Author : Boxin Zhang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ from torch import nn from transformers import ROPE_INIT_FUNCTIONS from ktransformers.models.modeling_llama import ( LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, ) from ktransformers.models.modeling_deepseek_v3 import ( DeepseekV3RotaryEmbedding ) from ktransformers.models.modeling_deepseek import ( DeepseekV2YarnRotaryEmbedding, DeepseekV2RotaryEmbedding, yarn_get_mscale, yarn_linear_ramp_mask, yarn_find_correction_range ) from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import InferenceState from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding import torch # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, orig_module.max_position_embeddings, orig_module.base ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.dim, self.orig_module.max_position_embeddings, self.orig_module.base, self.device, ) class RotaryEmbeddingV3(BaseInjectedModule): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.generate_device = generate_device self.prefill_device = prefill_device @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def load(self): self._init( dim=self.config.qk_rope_head_dim, max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, device=self.device, ) def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0): self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, orig_module.max_position_embeddings, orig_module.base, None, orig_module.scaling_factor, orig_module.rope_type, orig_module.config, ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.dim, self.orig_module.max_position_embeddings, self.orig_module.base, self.device, self.orig_module.scaling_factor, self.orig_module.rope_type, self.orig_module.config, ) class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, orig_module.max_position_embeddings, orig_module.base, None, # device orig_module.scaling_factor, orig_module.original_max_position_embeddings, orig_module.beta_fast, orig_module.beta_slow, orig_module.mscale, orig_module.mscale_all_dim, ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.dim, self.orig_module.max_position_embeddings, self.orig_module.base, self.generate_device, self.orig_module.scaling_factor, self.orig_module.original_max_position_embeddings, self.orig_module.beta_fast, self.orig_module.beta_slow, self.orig_module.mscale, self.orig_module.mscale_all_dim, ) # class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding): # def __init__( # self, # key: str, # gguf_loader: GGUFLoader, # config: PretrainedConfig, # orig_module: nn.Module, # # device: str = "cuda", # generate_device: str = "cuda", # prefill_device: str = "cuda", # **kwargs, # ): # BaseInjectedModule.__init__( # self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs # ) # self.generate_device = generate_device # self.prefill_device = prefill_device # def load(self): # # TODO support perlayer prefill # self.orig_module.__init__( # self.config, # device=self.generate_device # ) # return class YarnRotaryEmbeddingV3(BaseInjectedModule): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): kwargs = { key: self.config.rope_scaling[key] for key in [ "original_max_position_embeddings", "beta_fast", "beta_slow", "mscale", "mscale_all_dim", ] if key in self.config.rope_scaling } self._init( dim=self.config.qk_rope_head_dim, max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, device=self.device, scaling_factor=self.config.rope_scaling["factor"], **kwargs, ) @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos()* self._mscale sin = emb.sin()* self._mscale return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def _init( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1, mscale=1, mscale_all_dim=0, ): self.original_max_position_embeddings = original_max_position_embeddings self.beta_fast = beta_fast self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base freq_extra = 1.0 / ( self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) freq_inter = 1.0 / ( self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings, ) inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( device=device, dtype=torch.float32 ) self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self._mscale = float( yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) ) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings class DynamicNTKScalingRotaryEmbedding( BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding ): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( orig_module.dim, orig_module.max_position_embeddings, orig_module.base, None, # device orig_module.scaling_factor, orig_module.rope_type, orig_module.config, ) def load(self): self.orig_module.__init__( self.orig_module.dim, self.orig_module.max_position_embeddings, self.orig_module.base, self.orig_module.device, self.orig_module.scaling_factor, self.orig_module.rope_type, self.orig_module.config, ) class RotaryEmbeddingV4(BaseInjectedModule): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, generate_device, **kwargs ) self.generate_device = generate_device self.prefill_device = prefill_device @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def load(self): self._init( dim=self.config.qk_rope_head_dim, max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, device=self.device, ) def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0): self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( config, ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.config ) class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( config ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.config, device = self.generate_device, ) @torch.no_grad() def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs ) self.orig_module.__init__( config ) self.generate_device = generate_device self.prefill_device = prefill_device def load(self): self.orig_module.__init__( self.orig_module.config, device = self.generate_device, ) @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # print(inv_freq_expanded.device) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ================================================ FILE: archive/ktransformers/operators/__init__.py ================================================ ================================================ FILE: archive/ktransformers/operators/ascend/ascend_attention.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import warnings from typing import Optional, Tuple import torch import torch_npu from torch import nn import torch.nn.functional as F from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention, apply_rotary_pos_emb from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import get_compute_capability, get_use_npu_graph, get_current_device from ktransformers.models.custom_cache import StaticCache from ktransformers.server.balance_serve.inference.forward_batch import ForwardMiniBatchSplit from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, allredeuce_warpper, get_tensor_parallel_group from ktransformers.util.vendors import device_manager, GPUVendor from ktransformers.util import utils def apply_rotary_pos_emb_fusion(q, k, cos, sin, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) b, h, s, d = k.shape k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) q_embed = torch_npu.npu_rotary_mul(q, cos, sin) k_embed = torch_npu.npu_rotary_mul(k, cos, sin) return q_embed, k_embed class MatMulOps(object): def execute(self, x_input): """ :param x, weight, quant_bia, deq_scale :return: """ quant_out = x_input[0] weight = x_input[1] quant_bia = x_input[2] deq_scale = x_input[3] return [torch_npu.npu_quant_matmul(quant_out, weight.T, deq_scale, bias=quant_bia, output_dtype=torch.float16)] class DynamicQuantOps(object): """ :param x :return """ def execute(self, x_input): out = torch.empty_like(x_input[0], dtype=torch.int8) torch_npu._npu_quantize_per_tensor(x_input[0], x_input[1], x_input[2], out) return [out] class KDeepseekV2AttentionW8A8A2(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None class PageKVWrapper(object): """ wrap the difference of KV Cache and Block info between offline model & direct serving & sched serving succession should keep the function api """ def __init__(self, past_key_value: StaticCache): self.kv_cache = past_key_value self.page_size = self.kv_cache.page_size self.position = self.kv_cache.position self.page_idx = None # staticKV can get from itself self.page_offset = None def update(self, compressed_kv, k_pe, layer_idx, cache_kwargs): return self.kv_cache.update(compressed_kv, k_pe, layer_idx, cache_kwargs) def get_usable_length(self, kv_seq_len, layer_idx): return self.kv_cache.get_usable_length(kv_seq_len, layer_idx) def get_seq_length(self, layer_idx): return self.kv_cache.get_seq_length(layer_idx) def get_block_table(self, layer_idx): return self.kv_cache.page_table_list[layer_idx] def init_page_kv_wrapper(self, past_key_value: StaticCache): self.page_kv_wrapper = self.PageKVWrapper(past_key_value) def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, absorb_for_prefill: bool = False, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = config.chunk_size self.mla_wrapper = None self.page_kv_wrapper = None self.absorb_for_prefill = absorb_for_prefill self.use_merge = os.getenv("USE_MERGE", "0") tp = get_tensor_parallel_size() if tp > 1: self.num_heads //= tp if self.use_merge == "0": self.elewise_quant = DynamicQuantOps() self.matmulDequant_operation = MatMulOps() self.matmulDequant_operation_aclnn = MatMulOps() elif self.use_merge == "1": print("--Use torch npu FA OP !--") else: print("--Use default op !--") self.sparse_mode = 0 @allredeuce_warpper def forward_chunck( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[StaticCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, is_prefill: bool = True, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0] q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight, self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0] q_a_proj_out = self.q_a_layernorm(q_a_proj_out) q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0] q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight, self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0] q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0] compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight, self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0] compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv_seq_len = k_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += self.page_kv_wrapper.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin) # update KV if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs["page_idx"] = self.page_kv_wrapper.page_idx cache_kwargs["page_offset"] = self.page_kv_wrapper.page_offset k_pe = k_pe.transpose(1, 2) # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank] compressed_kv_with_k_pe, _ = self.page_kv_wrapper.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) if is_prefill: compressed_kv_prefill = compressed_kv.clone() # clone for prefill infer k_pe_prefill = k_pe.clone() compressed_kv, k_pe = torch.split( compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) weight_uk = self.q_absorb weight_uv = self.out_absorb # ATB-MLA-FA+PA if self.use_merge == "0" and is_prefill: # if self.layer_idx == 0: # print(self.page_kv_wrapper.get_seq_length(self.layer_idx) # self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position) current_sqenLen = self.page_kv_wrapper.get_seq_length(self.layer_idx) attention_mask = attention_mask[0, :, :, :current_sqenLen].squeeze(0).squeeze(0) # FIXME this is wrong in random choose pages for sched, currently just use kv without history # compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:current_sqenLen,:] # k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:current_sqenLen,:] compressed_kv = compressed_kv_prefill.transpose(1,2).contiguous() k_pe = k_pe_prefill.transpose(1,2).contiguous() k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1) k_up = torch.matmul(compressed_kv, weight_uk.mT) v_up = torch.matmul(compressed_kv, weight_uv) qTensor = torch.cat((q_nope, q_pe), dim=-1).transpose(1, 2).contiguous().view( bsz, q_len, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim)) kTensor = torch.cat((k_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view( bsz, current_sqenLen, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim)) vTensor = torch.cat((v_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view( bsz, current_sqenLen, self.num_heads, (self.v_head_dim + self.qk_rope_head_dim)) seq_len_data = [q_len] * bsz infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score( qTensor, kTensor, vTensor, atten_mask = attention_mask.type(torch.int8), actual_seq_lengths = seq_len_data, scale = self.softmax_scale, num_heads = self.num_heads, num_key_value_heads = self.num_heads, input_layout = "BSND") attn_output = infer_attention_output[..., :self.v_head_dim] if tuple(attn_output.size()) != (bsz, q_len, self.num_heads, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.v_head_dim)}, but is" f" {tuple(attn_output.size())}" ) attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale])[0] return attn_output, None, past_key_value elif self.use_merge == "0" and not is_prefill: return self.forward_paged(q_pe=q_pe, q_nope=q_nope, compressed_kv_with_k_pe=compressed_kv_with_k_pe, past_key_value=past_key_value, cache_position=cache_position) if self.use_merge == "1": k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1) k_up = torch.matmul(compressed_kv, weight_uk.mT) v_up = torch.matmul(compressed_kv, weight_uv) qTensor = torch.cat((q_nope, q_pe), dim=-1) kTensor = torch.cat((k_up, k_pe_repeated), dim=-1) vTensor = torch.cat((v_up, k_pe_repeated), dim=-1) if q_len != 1: attn_output = torch_npu.npu_prompt_flash_attention( qTensor, kTensor, vTensor, num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD") else: attn_output = torch_npu.npu_incre_flash_attention( qTensor, kTensor, vTensor, num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout="BNSD") attn_output = attn_output[:, :, :, :self.v_head_dim] else: q_nope = torch.matmul(q_nope, self.q_absorb) attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale compressed_kv = compressed_kv.squeeze(1) """ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) assert attention_mask is not None """ if attention_mask is not None: """ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) """ attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(q_pe.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) attn_output = torch.matmul(attn_output, self.out_absorb) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def forward_paged( self, q_pe: torch.Tensor, q_nope: torch.Tensor, compressed_kv_with_k_pe: torch.Tensor, past_key_value: Optional[StaticCache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # if self.layer_idx == 1: # print(self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position) bsz, _, q_len, _ = q_nope.size() q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb) # torch.Size([1, 128, 1, 512]) compressed_kv = compressed_kv_with_k_pe.permute(0, 2, 1, 3) kvCache = compressed_kv[:, :, :, :self.kv_lora_rank].contiguous() kRopeCache = compressed_kv[:, :, :, self.kv_lora_rank:].contiguous() if get_use_npu_graph(): from ktransformers.util.npu_graph_runner import get_or_create_runner npu_graph_runner = get_or_create_runner(get_current_device()) stream = npu_graph_runner.main_stream if npu_graph_runner.past_key_value is None: npu_graph_runner.past_key_value = past_key_value if npu_graph_runner.workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, kvCache, kvCache, query_rope=q_pe, key_rope=kRopeCache, num_heads=self.num_heads, num_key_value_heads=1, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=self.page_kv_wrapper.get_block_table(self.layer_idx), block_size=self.page_kv_wrapper.page_size, actual_seq_lengths_kv=self.page_kv_wrapper.position, sparse_mode = self.sparse_mode) npu_graph_runner.workspace = workspace attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=get_current_device()) softmax_lse = torch.empty(1, dtype=torch.float16, device=get_current_device()) torch_npu.npu_fused_infer_attention_score.out( q_nope, kvCache, kvCache, workspace=npu_graph_runner.workspace, query_rope=q_pe, key_rope=kRopeCache, num_heads=self.num_heads, num_key_value_heads=1, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=self.page_kv_wrapper.get_block_table(self.layer_idx), block_size=self.page_kv_wrapper.page_size, actual_seq_lengths_kv=self.page_kv_wrapper.position, sparse_mode = self.sparse_mode, out=[attn_output, softmax_lse]) else: attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( q_nope, kvCache, kvCache, query_rope=q_pe, key_rope=kRopeCache, num_heads=self.num_heads, num_key_value_heads=1, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=self.page_kv_wrapper.get_block_table(self.layer_idx), block_size=self.page_kv_wrapper.page_size, actual_seq_lengths_kv=self.page_kv_wrapper.position, sparse_mode = self.sparse_mode ) attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb) attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale])[0] return attn_output, None, past_key_value def forward_windows( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[StaticCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, is_prefill: bool = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) self.init_page_kv_wrapper(past_key_value) bsz, q_len, _ = hidden_states.size() if q_len <= self.chunck_size: return self.forward_chunck( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, is_prefill, **kwargs ) assert output_attentions == False, "output_attentions is not supported when using chunked attention" attn_output = None cur_idx = 0 while cur_idx < q_len: if attention_mask is not None: chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] else: # generate chunk_mask automatically. self.attn_mask = \ torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ if self.attn_mask is None \ else self.attn_mask self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = \ -65504.0 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \ [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))] self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -65504.0 self.attn_mask[:, :, :, :cur_idx] = 0 chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx)) cur_output, _, _ = self.forward_chunck( hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], chunk_mask, position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], past_key_value, output_attentions, use_cache, cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], **kwargs ) cur_idx += self.chunck_size if attn_output is None: attn_output = cur_output else: attn_output = torch.cat((attn_output, cur_output), dim=-2) return attn_output, None, past_key_value def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[StaticCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, is_prefill: bool = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # TODO: remove cache_position since it do not support multi-batch return self.forward_windows( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, is_prefill, **kwargs, ) class KDeepseekV2AttentionW8A8A2Serve(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1024, absorb_for_prefill: bool = False, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) # self.chunck_size = chunck_size self.absorb_for_prefill = absorb_for_prefill self.elewise_quant = DynamicQuantOps() self.matmulDequant_operation = MatMulOps() self.matmulDequant_operation_aclnn = MatMulOps() # tp切分 tp = get_tensor_parallel_size() if tp > 1: self.num_heads //= tp self.sparse_mode = 0 def print_callback(self, param): with torch.npu.stream(torch.npu.Stream(device="npu:0")): hidden_states, position_ids, cache_position, page_idx, page_offset, block_table = param print("########################################") print("hidden_states is ", hidden_states) print("position_ids is ", position_ids) print("cache_position is ", cache_position) print("page_idx is ", page_idx) print("page_offset is ", page_offset) print("block_table is ", block_table) print("########################################") @allredeuce_warpper def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[StaticCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, is_prefill: Optional[bool] = None, page_idx: Optional[torch.Tensor] = None, page_offset: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, q_len_raw: Optional[torch.Tensor] = None, kv_len_raw: Optional[torch.Tensor] = None, stream: Optional[torch.npu.Stream] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: def create_causal_mask(q_lens, kv_lens): q_lens = torch.tensor(q_lens) kv_lens = torch.tensor(kv_lens) bsz = q_lens.size(0) max_q_len = q_lens.max().item() max_kv_len = kv_lens.max().item() # causal mask [max_q_len, max_kv_len] base_causal = torch.tril(torch.ones((max_q_len, max_kv_len), dtype=torch.bool)) # mask initialize: [bsz, max_q_len, max_kv_len] to False mask = torch.zeros((bsz, max_q_len, max_kv_len), dtype=torch.bool) for i in range(bsz): ql, kl = q_lens[i].item(), kv_lens[i].item() # copy base_causal to mask mask[i, :ql, :kl] = base_causal[:ql, :kl] return mask bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0] q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight, self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0] q_a_proj_out = self.q_a_layernorm(q_a_proj_out) q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0] q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight, self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0] q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0] compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight, self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0] compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv_seq_len = k_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin) # update KV compressed_kv_prefill, k_pe_prefill = None, None if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs["page_idx"], cache_kwargs["page_offset"] = page_idx, page_offset k_pe = k_pe.transpose(1, 2) # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank] combined = torch.cat([compressed_kv, k_pe], dim=-1) # shape: [batch_size, num_heads, 2*self.kv_lora_rank] # combined = combined.contiguous() compressed_kv_with_k_pe, _ = past_key_value.update(combined, self.layer_idx, cache_kwargs) if is_prefill: compressed_kv_prefill = compressed_kv.clone() k_pe_prefill = k_pe.clone() weight_uk = self.q_absorb weight_uv = self.out_absorb if is_prefill: kTensor_list = [] vTensor_list = [] qTensor_list = [] attention_mask_list = [] seq_len_data = [] kv_len_list = [] for sample_idx in range(bsz): current_q_len = q_len_raw[sample_idx].item() if (q_len_raw is not None and sample_idx < len(q_len_raw)) else hidden_states.shape[1] current_kv_len = kv_len_raw[sample_idx].item() if (kv_len_raw is not None and sample_idx < len(kv_len_raw)) else current_q_len current_q_len = max(1, current_q_len) current_kv_len = max(1, current_kv_len) seq_len_data.append(current_q_len) kv_len_list.append(current_kv_len) if attention_mask is not None: mask_sample = attention_mask[ sample_idx:sample_idx+1, :, :, :current_kv_len ].squeeze(0).squeeze(0) if mask_sample.shape[0] < current_q_len: mask_sample = torch.nn.functional.pad(mask_sample, (0, 0, 0, current_q_len - mask_sample.shape[0]), value=1) elif mask_sample.shape[0] > current_q_len: mask_sample = mask_sample[:current_q_len, :] if mask_sample.shape[1] < current_kv_len: mask_sample = torch.nn.functional.pad(mask_sample, (0, current_kv_len - mask_sample.shape[1]), value=1) elif mask_sample.shape[1] > current_kv_len: mask_sample = mask_sample[:, :current_kv_len] mask_sample = torch.where( (mask_sample > -1e-6) & (mask_sample < 1e-6), torch.tensor(0, device=mask_sample.device, dtype=torch.int8), torch.tensor(1, device=mask_sample.device, dtype=torch.int8) ) else: mask_sample = torch.ones((current_q_len, current_kv_len), device=hidden_states.device, dtype=torch.int8) valid_len = min(current_q_len, current_kv_len) mask_sample[:, :valid_len] = 0 attention_mask_list.append(mask_sample) compressed_kv_sample = compressed_kv_prefill[sample_idx:sample_idx+1, :current_q_len, ...].transpose(1, 2).contiguous() k_pe_sample = k_pe_prefill[sample_idx:sample_idx+1, :current_q_len, ...].transpose(1, 2).contiguous() k_pe_repeated_sample = k_pe_sample.repeat(1, self.num_heads, 1, 1) q_nope_sample = q_nope[sample_idx:sample_idx+1, :, :current_q_len, :].contiguous() q_pe_sample = q_pe[sample_idx:sample_idx+1, :, :current_q_len, :].contiguous() q_concat_sample = torch.cat((q_nope_sample, q_pe_sample), dim=-1) q_transposed_sample = q_concat_sample.transpose(1, 2).contiguous() qTensor_sample = q_transposed_sample.view(current_q_len, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim) qTensor_list.append(qTensor_sample) k_up_sample = torch.matmul(compressed_kv_sample, weight_uk.mT) k_concat_sample = torch.cat((k_up_sample, k_pe_repeated_sample), dim=-1) k_transposed_sample = k_concat_sample.transpose(1, 2).contiguous() kTensor_sample = k_transposed_sample.view(current_kv_len, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim) kTensor_list.append(kTensor_sample) v_up_sample = torch.matmul(compressed_kv_sample, weight_uv) v_concat_sample = torch.cat((v_up_sample, k_pe_repeated_sample), dim=-1) v_transposed_sample = v_concat_sample.transpose(1, 2).contiguous() vTensor_sample = v_transposed_sample.view(current_kv_len, self.num_heads, self.v_head_dim + self.qk_rope_head_dim) vTensor_list.append(vTensor_sample) max_kv_len = max(kv_len_list) max_q_len = max(seq_len_data) qTensor = torch.nn.utils.rnn.pad_sequence(qTensor_list, batch_first=True, padding_value=0.0).contiguous() kTensor = torch.nn.utils.rnn.pad_sequence(kTensor_list, batch_first=True, padding_value=0.0).contiguous() vTensor = torch.nn.utils.rnn.pad_sequence(vTensor_list, batch_first=True, padding_value=0.0).contiguous() attention_mask = ~create_causal_mask(seq_len_data, kv_len_list).to(qTensor.device) infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score( qTensor, kTensor, vTensor, atten_mask = attention_mask.type(torch.int8), actual_seq_lengths = seq_len_data, scale = self.softmax_scale, num_heads = self.num_heads, num_key_value_heads = self.num_heads, input_layout = "BSND") attn_output = infer_attention_output[..., :self.v_head_dim] if tuple(attn_output.size()) != (bsz, max_q_len, self.num_heads, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, max_q_len, self.num_heads, self.v_head_dim)}, but is {tuple(attn_output.size())}" ) attn_output = attn_output.contiguous().view(bsz, max_q_len, self.num_heads * self.v_head_dim) attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale])[0] return attn_output, None, past_key_value else: return self.forward_paged(q_pe = q_pe, q_nope = q_nope, compressed_kv_with_k_pe = compressed_kv_with_k_pe, past_key_value = past_key_value, cache_position = cache_position, block_table = block_table, page_size = past_key_value.page_size, q_len_raw = q_len_raw, kv_len_raw = kv_len_raw, stream = stream) @allredeuce_warpper def forward_paged( self, q_pe: torch.Tensor, q_nope: torch.Tensor, compressed_kv_with_k_pe: torch.Tensor, past_key_value: Optional[StaticCache] = None, cache_position: Optional[torch.LongTensor] = None, block_table: Optional[torch.Tensor] = None, page_size: Optional[int] = None, q_len_raw: Optional[torch.Tensor] = None, kv_len_raw: Optional[torch.Tensor] = None, stream: Optional[torch.npu.Stream] = None, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # if self.layer_idx == 0: # print(self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position) bsz, _, q_len, _ = q_nope.size() # print(f"{q_nope.size()=}") q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb) # torch.size([1, 128, 1, 512]) compressed_kv = compressed_kv_with_k_pe.permute(0,2,1,3) kvCache = compressed_kv[:,:,:,:self.kv_lora_rank].contiguous() kRopeCache = compressed_kv[:,:,:,self.kv_lora_rank:].contiguous() if get_use_npu_graph(): from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner npu_graph_runner = get_or_create_model_runner(device=get_current_device()) npu_graph_idx = bsz - 1 if npu_graph_runner.workspace[npu_graph_idx] is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, kvCache, kvCache, query_rope=q_pe, key_rope=kRopeCache, num_heads=self.num_heads, num_key_value_heads=1, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=block_table, block_size=page_size, actual_seq_lengths_kv=kv_len_raw, sparse_mode = self.sparse_mode) npu_graph_runner.workspace[npu_graph_idx] = workspace attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=get_current_device()) softmax_lse = torch.empty(1, dtype=torch.float16, device=get_current_device()) torch_npu.npu_fused_infer_attention_score.out( q_nope, kvCache, kvCache, workspace=npu_graph_runner.workspace[npu_graph_idx], query_rope = q_pe, key_rope = kRopeCache, num_heads = self.num_heads, num_key_value_heads = 1, input_layout = "BNSD", scale = self.softmax_scale, antiquant_mode = 0, antiquant_scale = None, block_table = block_table, block_size = page_size, actual_seq_lengths_kv = kv_len_raw, sparse_mode = self.sparse_mode, out=[attn_output, softmax_lse]) else: tp_group = get_tensor_parallel_group() torch.distributed.barrier(tp_group) attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( q_nope, kvCache, kvCache, query_rope = q_pe, key_rope = kRopeCache, num_heads = self.num_heads, num_key_value_heads = 1, input_layout = "BNSD", scale = self.softmax_scale, antiquant_mode = 0, antiquant_scale = None, block_table = block_table, block_size = page_size, actual_seq_lengths_kv = kv_len_raw, sparse_mode = self.sparse_mode ) attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb) attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads*self.v_head_dim) attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0] attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale])[0] return attn_output, None, past_key_value def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class KQwen3MoeAttentionW8A8A2Serve(BaseInjectedModule, Qwen3MoeAttention): attn_mask: Optional[torch.Tensor] = None def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "npu", generate_device: str = "npu", chunck_size: int = 1024, absorb_for_prefill: bool = False, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.absorb_for_prefill = absorb_for_prefill self.elewise_quant = DynamicQuantOps() self.matmulDequant_operation = MatMulOps() self.matmulDequant_operation_aclnn = MatMulOps() self.softmax_scale = self.scaling self.sparse_mode = 0 self._prefill_step = 0 self._cur_prefill_dir: Optional[str] = None if hasattr(self, "rotary_emb"): if hasattr(self.rotary_emb, "cos_cached"): self.rotary_emb.cos_cached = self.rotary_emb.cos_cached.to(torch.float16) self.rotary_emb.sin_cached = self.rotary_emb.sin_cached.to(torch.float16) if hasattr(self.rotary_emb, "inv_freq"): self.rotary_emb.inv_freq = self.rotary_emb.inv_freq.to(torch.float16) def _linear_w8a8a2(self, x: torch.Tensor, proj: nn.Module, name: str) -> torch.Tensor: if x.dtype == torch.bfloat16: x = x.to(torch.float16) B, Q, H_in = x.shape x_2d = x.view(-1, H_in) # [T, H_in], T = B * Q x_q = self.elewise_quant.execute([ x_2d, proj.input_scale, proj.input_offset ])[0] y_2d = self.matmulDequant_operation.execute([ x_q, proj.weight, proj.quant_bias, proj.deq_scale ])[0] return y_2d.view(B, Q, -1) # ------------------------------------------------------- # forward # ------------------------------------------------------- def forward(self, hidden_states: torch.Tensor, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None, is_prefill=None, page_idx=None, page_offset=None, block_table=None, q_len_raw=None, kv_len_raw=None, stream=None, **kwargs): if hidden_states.dim() == 2: hidden_states = hidden_states.unsqueeze(0) bsz, q_len, hidden = hidden_states.shape # -------- QKV -------- q_proj_out = self._linear_w8a8a2(hidden_states, self.q_proj, "Q") B, S, _ = q_proj_out.shape q = q_proj_out.view(B, S, self.num_heads, self.head_dim) # [B, S, H, Dh] q = self.q_norm(q) q_in = q.view(B, S, -1) k_proj_out = self._linear_w8a8a2(hidden_states, self.k_proj, "K") k = k_proj_out.view(B, S, self.num_key_value_heads, self.head_dim) k = self.k_norm(k) k_in = k.view(B, S, -1) v_in = self._linear_w8a8a2(hidden_states, self.v_proj, "V") q = q_in.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) k = k_in.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) v = v_in.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # -------- RoPE -------- cos, sin = self.rotary_emb(v, position_ids) q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) # -------- prefill / decode -------- if is_prefill: out = self._forward_prefill( q, k, v, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, page_idx=page_idx, page_offset=page_offset, block_table=block_table, ) return out else: return self.forward_paged( q=q, k=k, v=v, past_key_value=past_key_value, cache_position=cache_position, block_table=block_table, page_size=getattr(past_key_value, "page_size", None), q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, page_idx=page_idx, page_offset=page_offset, stream=stream ) # ------------------------------------------------------- # Prefill # ------------------------------------------------------- def _forward_prefill( self, q: torch.Tensor, # [B, H, Q, Dh] k: torch.Tensor, # [B, KvH, Q, Dh] v: torch.Tensor, # [B, KvH, Q, Dh] attention_mask=None, position_ids=None, past_key_value=None, q_len_raw=None, kv_len_raw=None, page_idx=None, page_offset=None, block_table=None, **kwargs, ) -> torch.Tensor: B, H, Q, Dh = q.shape KvH = k.shape[1] # ---------- 1) 写 KV cache ---------- if ( past_key_value is not None and page_idx is not None and page_offset is not None ): try: past_key_value.update( key_states=k, value_states=v, layer_idx=self.layer_idx, cache_kwargs={ "page_idx": page_idx, "page_offset": page_offset, }, ) except Exception as e: print(f"[PREFILL-QWEN3][WARN] KV cache update failed: {e}", flush=True) # ---------- 2) GQA:4 KV → 32 Q heads ---------- if KvH != self.num_key_value_heads: print( f"[PREFILL-QWEN3][WARN] KvH ({KvH}) != config.num_key_value_heads " f"({self.num_key_value_heads}), 使用 k.shape[1] 作为 KvH", flush=True, ) KvH = k.shape[1] if H % KvH != 0: raise ValueError( f"[PREFILL-QWEN3] num_heads={H} 不是 num_kv_heads={KvH} 的整数倍" ) group_size = H // KvH k_full = k.repeat_interleave(group_size, dim=1) v_full = v.repeat_interleave(group_size, dim=1) print("[PREFILL-QWEN3] k_full/v_full:", k_full.shape, v_full.shape, flush=True) # ---------- 3) BSND + causal mask ---------- q_bsnd = q.permute(0, 2, 1, 3).contiguous() # [B, Q, H, Dh] k_bsnd = k_full.permute(0, 2, 1, 3).contiguous() v_bsnd = v_full.permute(0, 2, 1, 3).contiguous() if q_len_raw is None: seq_len_data = [Q for _ in range(B)] kv_len_list = [Q for _ in range(B)] else: seq_len_data = [] kv_len_list = [] for b_idx in range(B): cur_q = int(q_len_raw[b_idx].item()) if kv_len_raw is not None: cur_kv = int(kv_len_raw[b_idx].item()) else: cur_kv = cur_q cur_q = max(1, cur_q) cur_kv = max(1, cur_kv) seq_len_data.append(cur_q) kv_len_list.append(cur_kv) def create_causal_mask(q_lens, kv_lens): q_lens_t = torch.tensor(q_lens, device=q_bsnd.device) kv_lens_t = torch.tensor(kv_lens, device=q_bsnd.device) bsz = q_lens_t.size(0) max_q = int(q_lens_t.max().item()) max_kv = int(kv_lens_t.max().item()) base_causal = torch.tril( torch.ones((max_q, max_kv), dtype=torch.bool, device=q_bsnd.device) ) mask = torch.zeros( (bsz, max_q, max_kv), dtype=torch.bool, device=q_bsnd.device ) for i in range(bsz): ql = int(q_lens_t[i].item()) kl = int(kv_lens_t[i].item()) mask[i, :ql, :kl] = base_causal[:ql, :kl] return mask max_q_len = max(seq_len_data) if len(seq_len_data) > 0 else Q max_kv_len = max(kv_len_list) if len(kv_len_list) > 0 else Q q_list, k_list, v_list = [], [], [] for b_idx in range(B): cur_q = seq_len_data[b_idx] cur_kv = kv_len_list[b_idx] q_sample = q_bsnd[b_idx, :cur_q, :, :].contiguous() k_sample = k_bsnd[b_idx, :cur_kv, :, :].contiguous() v_sample = v_bsnd[b_idx, :cur_kv, :, :].contiguous() q_list.append(q_sample) k_list.append(k_sample) v_list.append(v_sample) qTensor = torch.nn.utils.rnn.pad_sequence( q_list, batch_first=True, padding_value=0.0 ).contiguous() kTensor = torch.nn.utils.rnn.pad_sequence( k_list, batch_first=True, padding_value=0.0 ).contiguous() vTensor = torch.nn.utils.rnn.pad_sequence( v_list, batch_first=True, padding_value=0.0 ).contiguous() causal_mask = create_causal_mask(seq_len_data, kv_len_list) atten_mask = (~causal_mask).to(torch.int8) print("[PREFILL-QWEN3] qTensor/kTensor/vTensor:", qTensor.shape, kTensor.shape, vTensor.shape, flush=True) # ---------- 4) NPU fused attention ---------- infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score( qTensor, kTensor, vTensor, atten_mask=atten_mask, actual_seq_lengths=seq_len_data, scale=self.softmax_scale, num_heads=H, num_key_value_heads=H, input_layout="BSND", ) attn_output = infer_attention_output # ---------- 5) reshape + W8A8 o_proj ---------- attn_output = attn_output.contiguous().view(B, max_q_len, H * Dh) attn_output_q = self.elewise_quant.execute( [attn_output, self.o_proj.input_scale, self.o_proj.input_offset] )[0] attn_output = self.matmulDequant_operation_aclnn.execute( [attn_output_q, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale] )[0] print("[PREFILL-QWEN3] attn_output(after o_proj):", attn_output.shape, attn_output.dtype, flush=True) return attn_output def forward_paged( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, past_key_value, cache_position, block_table, page_size, q_len_raw, kv_len_raw, page_idx, page_offset, stream, **kwargs, ): B, H, Q, Dh = q.shape KvH = k.shape[1] # ========= 1) 更新 KV cache ========= past_key_value.update( key_states=k, value_states=v, layer_idx=self.layer_idx, cache_kwargs={ "page_idx": page_idx, "page_offset": page_offset, }, ) Kcache = past_key_value.get_k_cache(self.layer_idx) Vcache = past_key_value.get_v_cache(self.layer_idx) q_bnsd = q.contiguous() k_bnsd = Kcache.contiguous().to(torch.float16).transpose(1, 2) v_bnsd = Vcache.contiguous().to(torch.float16).transpose(1, 2) use_graph = get_use_npu_graph() device = get_current_device() if use_graph: from ktransformers.server.balance_serve.inference.model_runner import get_or_create_model_runner npu_graph_runner = get_or_create_model_runner(device=device) npu_graph_idx = B - 1 if npu_graph_runner.workspace[npu_graph_idx] is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_bnsd, k_bnsd, v_bnsd, num_heads=H, num_key_value_heads=KvH, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=block_table, block_size=page_size, actual_seq_lengths_kv=kv_len_raw, sparse_mode=self.sparse_mode, ) npu_graph_runner.workspace[npu_graph_idx] = workspace attn_output = torch.zeros_like(q_bnsd, dtype=torch.float16, device=device) softmax_lse = torch.empty(1, dtype=torch.float16, device=device) torch_npu.npu_fused_infer_attention_score.out( q_bnsd, k_bnsd, v_bnsd, workspace=npu_graph_runner.workspace[npu_graph_idx], num_heads=H, num_key_value_heads=KvH, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=block_table, block_size=page_size, actual_seq_lengths_kv=kv_len_raw, sparse_mode=self.sparse_mode, out=[attn_output, softmax_lse] ) else: attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score( q_bnsd, k_bnsd, v_bnsd, num_heads=H, num_key_value_heads=KvH, input_layout="BNSD", scale=self.softmax_scale, antiquant_mode=0, antiquant_scale=None, block_table=block_table, block_size=page_size, actual_seq_lengths_kv=kv_len_raw, sparse_mode=self.sparse_mode, ) attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(B, -1, H * Dh) attn_output_q = self.elewise_quant.execute( [attn_output, self.o_proj.input_scale, self.o_proj.input_offset] )[0] attn_output = self.matmulDequant_operation_aclnn.execute( [attn_output_q, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale] )[0] return attn_output ================================================ FILE: archive/ktransformers/operators/ascend/ascend_experts.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re import os from typing import Optional import bisect import torch import numpy as np from torch import nn import torch_npu from transformers import PretrainedConfig import torch.nn.functional as F from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group from ktransformers.operators.experts import cuda_graphs, KExpertsBase, KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState from ktransformers.operators.experts import cuda_graphs as npu_graphs from ktransformers.util import utils class KExpertsCPUW8A8(KExpertsCPU): def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=None, use_npu_graph=False): if use_npu_graph: seq_len = input_tensor.size(0) cuda_graph_idx = seq_len - 1 if cuda_graph_idx is None else cuda_graph_idx # input_tensor is seq & batch merged self.cpu_infer.submit(self.moe.forward(KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].size(0), KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].size(1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx][0].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx][0].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx][0].data_ptr() )) self.cpu_infer.sync() else: if bsz_tensor is None: bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) # if torch.cuda.is_current_stream_capturing(): org_type = input_tensor.dtype input_tensor = input_tensor.contiguous().cpu() input_tensor = input_tensor.to(torch.bfloat16) expert_ids = expert_ids.contiguous().cpu() weights = weights.contiguous().to(torch.float32).cpu() bsz_tensor = bsz_tensor.contiguous().cpu() output = torch.empty_like(input_tensor).contiguous() self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr())) self.cpu_infer.sync() return output.to(org_type).to(device=utils.get_current_device()) EXPERTS_MAP["KExpertsCPUW8A8"] = KExpertsCPUW8A8 class KTransformersExpertsW8A8(KTransformersExperts): def forward(self, input_tensor, expert_ids, weights, cuda_graph_idx=None, use_npu_graph=False): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" return self.generate_experts.forward(input_tensor, expert_ids, weights, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph) elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph) else: raise ValueError("load or set_inference_mode before forward") class KDeepseekV3MoEW8A8(KDeepseekV3MoE): def forward(self, hidden_states, stream=None, para_stream=None): tp_size = get_tensor_parallel_size() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() identity = hidden_states orig_shape = hidden_states.shape def share_experts_forward(): if self.config.n_shared_experts is not None: return self.shared_experts(identity).squeeze(0) if rank == 0: topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if get_use_npu_graph(): org_type = hidden_states.dtype if hasattr(self.config, "backend_type"): if self.config.backend_type == "ktransformers": from ktransformers.util.npu_graph_runner import get_or_create_runner npu_graph_runner = get_or_create_runner(utils.get_current_device()) stream = npu_graph_runner.main_stream para_stream = npu_graph_runner.share_experts_stream event = torch.npu.Event() event.record(stream) with torch.npu.stream(para_stream): event.wait(para_stream) y_ = share_experts_forward() if share_experts_forward is not None else None event.record(para_stream) input_tensor = hidden_states.to(torch.bfloat16) topk_weight = topk_weight.contiguous().to(torch.float32) cuda_graph_idx = orig_shape[0] - 1 self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight, cuda_graph_idx, True) if cuda_graph_idx < len(npu_graphs): expert_ids = topk_idx KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking = True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(expert_ids, non_blocking = True) KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight, non_blocking = True) torch_npu.npu._launch_host_func(stream, self.cpu_moe_kexperts, self.moe_kexperts_param) y = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to(utils.get_current_device(), non_blocking = True) y = y.view(*orig_shape).to(device=hidden_states.device) y = y.to(org_type) event.wait(stream) else: from ktransformers.util.npu_graph_runner import get_or_create_runner npu_graph_runner = get_or_create_runner(utils.get_current_device()) event = torch.npu.Event() event.record(npu_graph_runner.main_stream) with torch.npu.stream(npu_graph_runner.share_experts_stream): event.wait(npu_graph_runner.share_experts_stream) y_ = share_experts_forward() if share_experts_forward is not None else None event.record(npu_graph_runner.share_experts_stream) topk_weight = topk_weight.contiguous().to(torch.float32) self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight, None, True) org_type = hidden_states.dtype input_tensor = hidden_states.to(torch.bfloat16) cuda_graph_idx = bisect.bisect_left(npu_graphs, 1) if cuda_graph_idx < len(npu_graphs): immediate_expert_ids = topk_idx KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking = True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(immediate_expert_ids, non_blocking = True) KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight, non_blocking = True) npu_graph_runner.launch_callback( self.cpu_moe_kexperts, self.moe_kexperts_param, 1, npu_graph_runner.main_stream) y = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to(utils.get_current_device(), non_blocking = True) y = y.to(org_type) y = y.view(*orig_shape).to(device=hidden_states.device) event.wait(npu_graph_runner.main_stream) else: y = self.moe_kexperts(hidden_states, topk_idx, topk_weight) y_ = share_experts_forward() if share_experts_forward is not None else None y = y.view(*orig_shape).to(device=hidden_states.device) y_ = y_.view(*orig_shape) else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE) y_ = share_experts_forward() if share_experts_forward is not None else None if tp_size > 1 and world_size == tp_size: torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group()) if self.config.n_shared_experts is not None: y += y_ return y @torch.no_grad() def cpu_moe_kexperts(self, moe_kexperts_param) -> torch.Tensor: x, topk_ids, topk_weight, cuda_graph_idx, use_npu_graph = moe_kexperts_param _ = self.experts(x, topk_ids, topk_weight, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph) @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) return outs class KQwen3MoeSparseMoeBlockW8A8(BaseInjectedModule): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "npu", generate_device: str = "npu", **kwargs, ): super().__init__( key, gguf_loader, config, orig_module, prefill_device=prefill_device, generate_device=generate_device, **kwargs, ) self.gate = orig_module.gate self.top_k = orig_module.top_k self.norm_topk_prob = orig_module.norm_topk_prob self.output_router_logits = getattr(orig_module, "output_router_logits", False) experts_key = f"{key}.experts" print(f"[NPU-MOE][INIT] build experts at key={experts_key}", flush=True) self.experts = KTransformersExpertsW8A8( key=experts_key, gguf_loader=gguf_loader, config=config, orig_module=orig_module.experts, prefill_device=prefill_device, prefill_op="KExpertsTorch", generate_device="cpu", generate_op="KExpertsCPUW8A8", out_device=prefill_device, ) def set_inference_mode(self, mode: InferenceState): if isinstance(self.experts, KExpertsBase): self.experts.set_inference_mode(mode) @torch.no_grad() def cpu_moe_kexperts(self, moe_kexperts_param): x, topk_ids, topk_weight, cuda_graph_idx, use_npu_graph = moe_kexperts_param _ = self.experts( x, topk_ids, topk_weight, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph, ) @torch.no_grad() def moe_kexperts( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor: torch.Tensor = None, cuda_graph_idx: int = 0, use_npu_graph: bool = False, ) -> torch.Tensor: outs = self.experts( x, topk_ids, topk_weight, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph, ) return outs def forward( self, hidden_states: torch.Tensor, bsz_tensor: torch.Tensor = None, cuda_graph_idx: int = 0, *args, **kwargs, ): if hidden_states.dim() == 3: B, S, H = hidden_states.shape else: orig_shape = hidden_states.shape hidden_states = hidden_states.view(1, -1, orig_shape[-1]) B, S, H = hidden_states.shape orig_device = hidden_states.device orig_shape = (B, S, H) output_router_logits_flag = kwargs.pop("output_router_logits", False) need_router_logits = output_router_logits_flag or self.output_router_logits # ===== 1) flatten ===== hidden_states_flat = hidden_states.view(-1, H) T = hidden_states_flat.shape[0] # ===== 2) gate ===== router_logits = self.gate(hidden_states_flat) try: router_logits_bs = router_logits.view(B, S, -1) except Exception: router_logits_bs = router_logits routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk( routing_weights, self.top_k, dim=-1 ) if self.norm_topk_prob: rw_sum = routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights / rw_sum routing_weights = routing_weights.to(hidden_states_flat.dtype) # ===== 3) MoE experts ===== use_npu_graph = get_use_npu_graph() if torch.distributed.is_available() and torch.distributed.is_initialized(): tp_size = get_tensor_parallel_size() world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() else: tp_size = 1 world_size = 1 rank = 0 y = None if isinstance(self.experts, KExpertsBase): if getattr(self.experts, "mode", None) == InferenceState.UNLOAD: self.experts.set_inference_mode(InferenceState.GENERATE) if rank == 0: if use_npu_graph: org_type = hidden_states_flat.dtype input_tensor = hidden_states_flat.to(torch.bfloat16) topk_weight_f32 = routing_weights.contiguous().to(torch.float32) self.moe_kexperts_param = ( hidden_states_flat, selected_experts, topk_weight_f32, cuda_graph_idx, True, ) if cuda_graph_idx < len(npu_graphs): KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(selected_experts, non_blocking=True) KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight_f32, non_blocking=True) stream = torch.npu.current_stream() torch_npu.npu._launch_host_func( stream, self.cpu_moe_kexperts, self.moe_kexperts_param, ) y_flat = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to( utils.get_current_device(), non_blocking=True, ) y_flat = y_flat.to(org_type) y = y_flat.view(*orig_shape).to(device=orig_device) else: tmp_bsz_tensor = torch.tensor([B], dtype=torch.int32, device=orig_device) y_flat = self.moe_kexperts( hidden_states_flat, selected_experts, routing_weights, bsz_tensor=tmp_bsz_tensor, cuda_graph_idx=cuda_graph_idx, use_npu_graph=False, ) y = y_flat.view(*orig_shape).to(device=orig_device) else: if bsz_tensor is None: bsz_tensor = torch.tensor( [B], dtype=torch.int32, device=orig_device, ) y_flat = self.moe_kexperts( hidden_states_flat, selected_experts, routing_weights, bsz_tensor=bsz_tensor, cuda_graph_idx=cuda_graph_idx, use_npu_graph=False, ) y = y_flat.view(*orig_shape).to(device=orig_device) else: y = torch.zeros(orig_shape, dtype=hidden_states.dtype, device=orig_device) else: y = hidden_states if tp_size > 1 and world_size == tp_size: torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group()) # print("================ [NPU-MOE] EXIT MLP =======================\n") if need_router_logits: num_experts = router_logits.shape[-1] router_logits_bs = router_logits.view(B, S, num_experts) return y, router_logits_bs return y ================================================ FILE: archive/ktransformers/operators/ascend/ascend_gate.py ================================================ import torch import torch_npu import torch.nn as nn import torch.nn.functional as F from ktransformers.operators.gate import KMoEGate class KDeepseekV3GateA2(KMoEGate): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None): if device is None: device = self.device if w is None: w = self.load_weights(device=device) if isinstance(w, dict): self.weight_type = w["weight_type"] self.e_score_correction_bias_type = w["e_score_correction_bias_type"] self.orig_module.weight = nn.Parameter(w["weight"]) self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: raise ValueError("Invalid weight type") self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32)) self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device).to(torch.float32)) def forward(self, hidden_states) -> torch.Tensor: h = hidden_states.shape[-1] # compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear(hidden_states.type(torch.float32), self.weight, None) topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k( logits, k=self.top_k, bias=self.e_score_correction_bias, k_group=self.topk_group, group_count=self.n_group, group_select_mode=1, renorm=0, norm_type=1, routed_scaling_factor=self.routed_scaling_factor, eps=float(1e-20)) return topk_idx.type(torch.int64), topk_weight ================================================ FILE: archive/ktransformers/operators/ascend/ascend_layernorm.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import re from typing import Optional, Union, Tuple import torch import torch_npu from torch import nn from transformers import PretrainedConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm from ktransformers.util import utils from ktransformers.util.custom_loader import GGUFLoader class KDeepseekV3RMSNormW8A8(BaseInjectedModule): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "npu", generate_device: str = "npu", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.weight = nn.Parameter(torch.ones(self.orig_module.hidden_size)) self.bias = nn.Parameter(torch.ones(self.orig_module.hidden_size)) self.variance_epsilon = self.orig_module.variance_epsilon def forward(self, hidden_states): input_dtype = hidden_states.dtype out = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + self.bias return out.to(input_dtype) def load(self): self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(utils.get_current_device()) self.bias = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".bias").to(utils.get_current_device()) def unload(self): if self.weight is not None: self.weight = None if self.bias is not None: self.bias = None class KQwen3MoeRMSNormW8A8(BaseInjectedModule): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "npu", generate_device: str = "npu", **kwargs): super().__init__(key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.hidden_size = orig_module.hidden_size self.variance_epsilon = orig_module.variance_epsilon self.weight = nn.Parameter(orig_module.weight.data.clone()) def forward(self, x: torch.Tensor): x = x.to(torch.float16) gamma = self.weight.to(torch.float16) input_dtype = x.dtype out = torch_npu.npu_rms_norm( x, gamma, self.variance_epsilon )[0] return out.to(input_dtype) def load(self): device = utils.get_current_device() self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + ".weight").to(device) try: self.bias = ( self.gguf_loader.safetensor_loader .load_tensor(self.key + ".bias") .to(device) ) except KeyError: self.bias = None def unload(self): self.weight = None self.bias = None class KQwen3FinalRMSNormNPU(nn.Module): def __init__(self, orig_module: nn.Module): super().__init__() assert hasattr(orig_module, "weight"), "orig_module must have weight" self.variance_epsilon = getattr(orig_module, "variance_epsilon", 1e-6) w = orig_module.weight.detach() if w.dtype not in (torch.float16, torch.bfloat16, torch.float32): w = w.to(torch.float16) else: if w.dtype == torch.float32: w = w.to(torch.float16) self.weight = nn.Parameter(w) def forward(self, x: torch.Tensor): input_dtype = x.dtype x = x.contiguous() gamma = self.weight x_rms = x.to(dtype=gamma.dtype) out = torch_npu.npu_rms_norm( x_rms, gamma, self.variance_epsilon )[0] return out.to(input_dtype) ================================================ FILE: archive/ktransformers/operators/ascend/ascend_linear.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import abstractmethod import torch import torch_npu import torch.distributed as dist from torch import nn from transformers import PretrainedConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.linear import KLinearBase, LINEAR_MAP from ktransformers.util import utils from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import InferenceState from ktransformers.util.ascend.ascend_utils import get_safetensors_cut_weight, get_tensor_parallel_size, get_tensor_parallel_group from ktransformers.util.custom_gguf import translate_name_to_gguf class KLinearW8A8(KLinearBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) def load_weight(self, override_key: str | None = None, device: str | None = None): if override_key is not None: keys = override_key else: keys = [self.key] fake_tensor = torch.tensor([1]) for key in keys: if device is None: device = utils.get_current_device() key = translate_name_to_gguf(key) if key == "lm_head": key = "output" if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map: if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map: qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") deq_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.deq_scale") quant_bias = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.quant_bias") input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale") input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset") tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset) return tensors elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map: if key.endswith("ffn_gate_shexp"): parts = key.split(".") layer = parts[1] gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight") gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t() up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight") up_weight = get_safetensors_cut_weight(self.key, up_weight).t() gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale") gate_scale = get_safetensors_cut_weight(self.key, gate_scale) up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale") up_scale = get_safetensors_cut_weight(self.key, up_scale) gate_up_weight = torch.cat((gate_weight, up_weight), 1) gate_up_scale = torch.cat((gate_scale, up_scale), 0) gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset") up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset") gate_up_offset = torch.cat((gate_offset, up_offset), 0) tensors = (gate_up_weight, gate_up_scale, gate_up_offset) elif key.endswith("ffn_up_shexp"): return fake_tensor else: qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale") weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset") tensors = (qweight, weight_scale, weight_offset) return tensors else: weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight") return weight else: raise FileNotFoundError(f"Weight file not found for key {key}") @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = "cuda"): pass @abstractmethod def unload(self): pass class KLinearTorchW8A8A2(KLinearW8A8): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.has_bias = False self.dtype = torch.get_default_dtype() self.weight = None self.input_scale = None self.input_offset = None self.quant_bias = None self.deq_scale = None self.weight_scale = None self.weight_offset = None def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor: if x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) return torch.matmul(x, self.weight) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None): if device is None: device = utils.get_current_device() device = utils.CUR_DEVICE if w is None: w = self.load_weight() if isinstance(w, nn.Parameter): try: self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous() except: self.weight = w.to(dtype=self.dtype).T.contiguous() self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) elif isinstance(w, tuple): w_list = list(w) if len(w_list) == 3: self.weight = w_list[0] self.weight_scale = w_list[1].view(-1) self.weight_offset = w_list[2] self.weight = self.weight.to(utils.CUR_DEVICE) self.weight_scale = self.weight_scale.to(utils.CUR_DEVICE) if self.key.endswith("ffn_gate_shexp") is not True: self.weight = get_safetensors_cut_weight(self.key, self.weight).t() weight_scale = get_safetensors_cut_weight(self.key, self.weight_scale) self.weight_scale = weight_scale.clone() del weight_scale else: for i in range(len(w_list)): w_list[i] = get_safetensors_cut_weight(self.key, w_list[i]) w_list[i] = w_list[i].to(utils.CUR_DEVICE) self.weight = w_list[0] self.deq_scale = w_list[1] self.quant_bias = w_list[2] if "attn_output" in self.key or "ffn_down" in self.key: if torch.distributed.get_rank(get_tensor_parallel_group()) != 0: self.quant_bias = torch.zeros_like(self.quant_bias, dtype=self.quant_bias.dtype, device=self.quant_bias.device) self.input_scale = w_list[3] self.input_offset = w_list[4] elif isinstance(w, torch.Tensor): self.weight = w.T.contiguous() self.weight = self.weight.to(device) if "kv_b" not in self.key and ("output" in self.key or "eh_proj" in self.key): self.weight = torch_npu.npu_format_cast(self.weight, 29) else: raise ValueError(f"Invalid weight type {self.key=} {type(w)=}") def unload(self): if self.weight is not None: self.weight = None if self.has_bias: self.bias = None self.input_scale = None self.input_offset = None self.quant_bias = None self.deq_scale = None self.weight_scale = None self.weight_offset = None LINEAR_MAP["KLinearTorchW8A8A2"] = KLinearTorchW8A8A2 class KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, generate_device: str = "cuda", generate_op: str | None = "KLinearMarlin", prefill_device: str = "cuda", prefill_op: str | None = "KLinearTorch", **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) # build all the linear operators if prefill_op is not None: assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) else: self.prefill_linear = None if generate_op is not None: assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) else: self.generate_linear = None self.mode = InferenceState.UNLOAD def forward(self, x, bsz_tensor=None): if self.mode == InferenceState.PREFILL: assert self.prefill_linear is not None, "cpu linear is not initialized" y = self.prefill_linear.forward(x, bsz_tensor) else: assert self.generate_linear is not None, "gpu linear is not initialized" y = self.generate_linear.forward(x, bsz_tensor) return y def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): if not mode: mode = InferenceState.GENERATE # load to device if mode == InferenceState.PREFILL: self.generate_linear.unload() self.prefill_linear.load(w=w) self.device = self.prefill_linear.device self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight self.input_scale = self.prefill_linear.input_scale self.input_offset = self.prefill_linear.input_offset self.quant_bias = self.prefill_linear.quant_bias self.deq_scale = self.prefill_linear.deq_scale self.weight_scale = self.prefill_linear.weight_scale self.weight_offset = self.prefill_linear.weight_offset elif mode == InferenceState.GENERATE: self.prefill_linear.unload() self.generate_linear.load(w=w) self.device = self.generate_linear.device self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight self.input_scale = self.generate_linear.input_scale self.input_offset = self.generate_linear.input_offset self.quant_bias = self.generate_linear.quant_bias self.deq_scale = self.generate_linear.deq_scale self.weight_scale = self.generate_linear.weight_scale self.weight_offset = self.generate_linear.weight_offset elif mode == InferenceState.UNLOAD: self.prefill_linear.unload() self.generate_linear.unload() self.device = "cpu" else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") self.mode = mode def unload(self): if self.prefill_linear is not None: self.prefill_linear.unload() if self.generate_linear is not None: self.generate_linear.unload() self.device = self.generate_linear.device def set_inference_mode(self, mode: InferenceState): if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") ================================================ FILE: archive/ktransformers/operators/ascend/ascend_mlp.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch_npu from ktransformers.util.ascend.ascend_utils import allredeuce_warpper from ktransformers.util.utils import CUR_DEVICE from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP from ktransformers.models.modeling_qwen3_moe import Qwen3MoeMLP class KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP): @allredeuce_warpper def forward(self, x, is_prefill=None, use_cuda_graph=False): original_dtype = x.dtype quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) dynamic_scale = dynamic_scale.view(-1) quant_out = quant_out.view(-1, quant_out.shape[-1]) gate_x = torch_npu.npu_quant_matmul( quant_out, self.orig_module.gate_proj.weight, self.orig_module.gate_proj.weight_scale, pertoken_scale=dynamic_scale, bias=None, output_dtype=original_dtype, ) up_x = torch_npu.npu_quant_matmul( quant_out, self.orig_module.up_proj.weight, self.orig_module.up_proj.weight_scale, pertoken_scale=dynamic_scale, bias=None, output_dtype=original_dtype, ) down_x = self.act_fn(gate_x) * up_x down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x) down_dynamic_scale = down_dynamic_scale.view(-1) down_proj = torch_npu.npu_quant_matmul( down_quant_out, self.orig_module.down_proj.weight, self.orig_module.down_proj.weight_scale, pertoken_scale=down_dynamic_scale, bias=None, output_dtype=original_dtype, ) down_proj = down_proj.reshape(x.shape) return down_proj class KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP): @allredeuce_warpper def forward(self, x, is_prefill=None, use_cuda_graph=False): original_dtype = x.dtype quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) dynamic_scale = dynamic_scale.view(-1) quant_out = quant_out.view(-1, quant_out.shape[-1]) gate_up_x = torch_npu.npu_quant_matmul( quant_out, self.orig_module.gate_proj.weight, self.orig_module.gate_proj.weight_scale, pertoken_scale=dynamic_scale, bias=None, output_dtype=original_dtype, ) down_x = torch_npu.npu_swiglu(gate_up_x, -1) down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x) down_dynamic_scale = down_dynamic_scale.view(-1) down_proj = torch_npu.npu_quant_matmul( down_quant_out, self.orig_module.down_proj.weight, self.orig_module.down_proj.weight_scale, pertoken_scale=down_dynamic_scale, bias=None, output_dtype=original_dtype, ) down_proj = down_proj.reshape(x.shape) return down_proj class KQwen3MoeMLPW8A8A2(BaseInjectedModule, Qwen3MoeMLP): @allredeuce_warpper def forward(self, x, is_prefill=None, use_cuda_graph=False): original_dtype = x.dtype quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) dynamic_scale = dynamic_scale.view(-1) quant_out = quant_out.view(-1, quant_out.shape[-1]) gate_x = torch_npu.npu_quant_matmul( quant_out, self.orig_module.gate_proj.weight, self.orig_module.gate_proj.weight_scale, pertoken_scale=dynamic_scale, bias=None, output_dtype=original_dtype, ) up_x = torch_npu.npu_quant_matmul( quant_out, self.orig_module.up_proj.weight, self.orig_module.up_proj.weight_scale, pertoken_scale=dynamic_scale, bias=None, output_dtype=original_dtype, ) down_x = torch.nn.functional.silu(gate_x) * up_x down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x) down_dynamic_scale = down_dynamic_scale.view(-1) down_proj = torch_npu.npu_quant_matmul( down_quant_out, self.orig_module.down_proj.weight, self.orig_module.down_proj.weight_scale, pertoken_scale=down_dynamic_scale, bias=None, output_dtype=original_dtype, ) down_proj = down_proj.reshape(x.shape) return down_proj ================================================ FILE: archive/ktransformers/operators/attention.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import torch from torch import nn import warnings import torch.nn.functional as F from ktransformers.operators.models import KLlamaModel from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from ktransformers.util.utils import get_compute_capability import logging from transformers.configuration_utils import PretrainedConfig from transformers.cache_utils import Cache from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor try: from flash_attn import flash_attn_func except: pass from ktransformers.operators.triton_attention import decode_attention_fwd_grouped from ktransformers.operators.triton_attention_prefill import context_attention_fwd import os from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled if flashinfer_enabled: from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton from flashinfer.mla import BatchMLAPagedAttentionWrapper from ktransformers.models.custom_cache import KDeepSeekV3Cache logger = logging.getLogger("attention") # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # V3 MLA is same to V2 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" attn_mask: Optional[torch.Tensor] = None def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, absorb_for_prefill: bool = False, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. self.mla_wrapper = None self.absorb_for_prefill = absorb_for_prefill def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank) return self.q_absorb, self.out_absorb def forward_chunck( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv_seq_len = k_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models # compressed_kv [bsz, q_len, self.kv_lora_rank] # k_pe [bsz, 1, q_len, self.qk_rope_head_dim] k_pe = k_pe.transpose(1,2) compressed_kv = compressed_kv.unsqueeze(2) compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) compressed_kv, k_pe = torch.split( compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) # k_pe [pages, page_size, 1, self.qk_rope_head_dim] # compressed_kv [pages, page_size, 1, self.kv_lora_rank] q_absorb, out_absorb = self.get_absorbed() # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim] # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim] k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:] compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:] # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim] # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank] q_nope = torch.matmul(q_nope, q_absorb) #print(q_pe.shape) #print(k_pe.shape) #print(q_nope.shape) #print(compressed_kv.shape) attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale #attn_weights [bsz, self.num_heads, q_len, kv_seq_len] compressed_kv = compressed_kv.squeeze(1) """ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) assert attention_mask is not None """ if attention_mask is not None: """ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) """ #causal_mask = attention_mask[:, :, :, : kv_seq_len] attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(q_pe.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) attn_output = torch.matmul(attn_output, out_absorb.mT) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def forward_linux_triton( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank) kv_seq_len = q_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] # decode if q_len == 1: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim] # compressed_kv [bsz, q_len, 1, self.kv_lora_rank] # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim] # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank] q_absorb, out_absorb = self.get_absorbed() q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(1, 2) #assert q_nope.is_contiguous() # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] query_states = torch.cat([q_nope, q_pe], dim=-1) query_states = query_states.squeeze(1) attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank] attn_logits = torch.empty( ( bsz, self.num_heads, 4, #num_kv_splits # follow vLLM, fix it TODO self.kv_lora_rank + 1, ), dtype=torch.float32, device = attn_output.device ) """ print("query_states", torch.isnan(query_states).any()) print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any()) print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any()) print("position_ids", torch.isnan(position_ids).any()) """ # flash attn doesn't support head_dim bigger than 256 # use triton attention kernel adapted from vLLM and SGLang for MQA decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output, page_table, position_ids.squeeze(0).to(torch.int32)+1, attn_logits, 4, #num_kv_splits # follow vLLM, fix it TODO self.softmax_scale, past_key_value.page_size) # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) attn_output = torch.matmul(attn_output, out_absorb.mT) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) #print("attn_output", torch.isnan(attn_output).any()) return attn_output, None, past_key_value else: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models k_pe.squeeze(0) compressed_kv.squeeze(0) compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) compressed_kv, k_pe = torch.split( compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim) k_pe = k_pe[:, :kv_seq_len] compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank) compressed_kv = compressed_kv[:, :kv_seq_len] kv = ( self.kv_b_proj(compressed_kv) .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim) key_states[:, :, :, :self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) attn_output = flash_attn_func( query_states, key_states, value_states_padded, softmax_scale=self.softmax_scale, causal=True, ) if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def forward_linux_flashinfer( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim) compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank) kv_seq_len = q_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2) # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim] # decode if q_len == 1 or self.absorb_for_prefill: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank) k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim) # k_pe [max_pages, page_size, self.qk_rope_head_dim] # compressed_kv [max_pages, page_size, self.kv_lora_rank] # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim] # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank] q_absorb, out_absorb = self.get_absorbed() q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(1, 2) q_nope = q_nope.contiguous() #assert q_nope.is_contiguous() # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank] # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] q_nope.squeeze_(0) q_pe.squeeze_(0) # flash attn doesn't support head_dim bigger than 256, use flashinfer if self.mla_wrapper is None: self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True) if self.mla_wrapper.need_plan: self.mla_wrapper.need_plan = False if q_len == 1: self.mla_wrapper.plan(None,None,None, position_ids.squeeze(1)+1, None, self.num_heads, self.kv_lora_rank, self.qk_rope_head_dim, past_key_value.page_size, self.softmax_scale, q_nope.dtype, compressed_kv.dtype) else: qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device) kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device) self.mla_wrapper.plan(qo_indptr,None,None, kv_len_arr, None, self.num_heads, self.kv_lora_rank, self.qk_rope_head_dim, past_key_value.page_size, self.softmax_scale, q_nope.dtype, compressed_kv.dtype) attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank) """ k = ( torch.cat([compressed_kv, k_pe], dim=-1) .view(-1, 1, 512 + 64) .repeat_interleave(self.num_heads, dim=1) ) v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1) lens = position_ids.item() + 1 #print("lens", lens) attn_ref, lse_ref = attention_ref( 1, torch.cat([q_nope, q_pe], dim=-1), k[:lens], v[:lens], False, self.softmax_scale ) attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank) """ # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank] # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank] # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank] attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank] attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim] attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank] attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim] attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value else: if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models k_pe.squeeze(0) compressed_kv.squeeze(0) compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs) compressed_kv, k_pe = torch.split( compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim) k_pe = k_pe[:, :kv_seq_len] compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank) compressed_kv = compressed_kv[:, :kv_seq_len] kv = ( self.kv_b_proj(compressed_kv) .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim) key_states[:, :, :, :self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1) value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim) value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0) attn_output = flash_attn_func( query_states, key_states, value_states_padded, softmax_scale=self.softmax_scale, causal=True, ) if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value def forward_windows( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if q_len <= self.chunck_size: return self.forward_chunck( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs ) assert output_attentions == False, "output_attentions is not supported when using chunked attention" attn_output = None cur_idx = 0 while cur_idx < q_len: if attention_mask is not None: chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...] else: # generate chunk_mask automatically. self.attn_mask = \ torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \ if self.attn_mask is None \ else self.attn_mask self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \ -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\ [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))] self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38 self.attn_mask[:, :, :, :cur_idx] = 0 chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx)) cur_output, _, _ = self.forward_chunck( hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...], chunk_mask, position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)], past_key_value, output_attentions, use_cache, cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)], **kwargs ) cur_idx += self.chunck_size if attn_output is None: attn_output = cur_output else: attn_output = torch.cat((attn_output, cur_output), dim=-2) return attn_output, None, past_key_value def forward_xpu( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) position_embeddings = kwargs.get("position_embeddings", None) if position_embeddings is not None: cos, sin = position_embeddings key_states = torch.cat( [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])], dim=-1 ) from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :], key_states[:, :, :, self.qk_nope_head_dim:], cos, sin, True) else: q_nope, q_pe = torch.split( query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) cos, sin = self.rotary_emb(q_pe, position_ids) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states.half(), value_states.half(), self.layer_idx, cache_kwargs ) attn_weights = None from ipex_llm.transformers.models.common import scaled_dot_product_attention attn_output = scaled_dot_product_attention( query_states.half(), key_states, value_states, attention_mask.half(), q_len == kv_seq_len, self.softmax_scale ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output).to(hidden_states.dtype) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if torch.xpu.is_available(): return self.forward_xpu( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs, ) elif (os.name == 'nt' or get_compute_capability() < 8 or hidden_states.device.type == 'cpu' or device_manager.gpu_vendor != GPUVendor.NVIDIA): return self.forward_windows( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs, ) else: if flashinfer_enabled: return self.forward_linux_flashinfer( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs, ) else: return self.forward_linux_triton( hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs, ) class KLlamaAttention(BaseInjectedModule): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 ) key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] query_states = torch.cat(query_states, dim=-1) key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] key_states = torch.cat(key_states, dim=-1) value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] value_states = torch.cat(value_states, dim=-1) else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin) if q_len == 1: position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0) query_states = query_states[:, :, -1:] key_states = key_states[:, :, -1:] attn_output = KLlamaModel.dynamic_sdpa.apply( self.layer_idx, bsz, position_ids[0][0], query_states.transpose(1, 2).to(torch.float16), key_states.transpose(1, 2).to(torch.float16), value_states.transpose(1, 2).to(torch.float16), mode="prefill" if q_len > 1 else "generate", ) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "xpu", generate_device: str = "xpu", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. assert prefill_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device" assert generate_device.lower()[:3] == "xpu", "KQwen3MoeAttentionIPEXLLM only supports XPU device" def forward( self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] bsz, q_len, _ = hidden_states.size() input_dtype = hidden_states.dtype hidden_shape = (*input_shape, -1, self.head_dim) if not hasattr(self, 'qkv_proj'): from ipex_llm.transformers.models.common import merge_quantized_qkv merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module) qkv = self.qkv_proj(hidden_states) qkv = qkv.view(bsz, q_len, -1, self.head_dim) qkv = qkv.transpose(1, 2) query_states, key_states, value_states = qkv.split([self.config.num_attention_heads, self.config.num_key_value_heads, self.config.num_key_value_heads], dim=1) query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) if position_embeddings is None: position_embeddings = self.rotary_emb(hidden_states, position_ids) cos, sin = position_embeddings from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced rotary_half_with_cache_inplaced(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states.half(), value_states.half(), self.layer_idx, cache_kwargs) attn_weights = None from ipex_llm.transformers.models.common import scaled_dot_product_attention attn_output = scaled_dot_product_attention( query_states.half(), key_states, value_states, attention_mask.half(), q_len == key_states.size(2), self.scaling ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output).to(input_dtype) return attn_output, attn_weights ================================================ FILE: archive/ktransformers/operators/balance_serve_attention.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.2.5 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import torch from torch import nn from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention from ktransformers.models.modeling_smallthinker import SmallthinkerAttention from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention from ktransformers.models.modeling_qwen3_next import Qwen3NextGatedDeltaNet from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader import logging from transformers.configuration_utils import PretrainedConfig from flashinfer import BatchMLAPagedAttentionWrapper from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache logger = logging.getLogger("attention") # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False, dtype=q_absorb.dtype, device=q_absorb.device) self.q_absorb.weight.data = q_absorb self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False, dtype=out_absorb.dtype, device=out_absorb.device) self.out_absorb.weight.data = out_absorb #del self.orig_module.kv_b_proj q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) return q_absorb, out_absorb def forward(self, hidden_states: torch.Tensor, kv_cache: KDeepSeekV3Cache, position_ids: torch.Tensor, wrapper: BatchMLAPagedAttentionWrapper, num_tokens_tensors: torch.Tensor, page_idx: torch.Tensor, page_offset: torch.Tensor, ): q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states, num_tokens_tensors) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors) q = q.view(q_len, self.num_heads, self.q_head_dim) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = compressed_kv.contiguous() compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors) k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0)) q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) q_pe = q_pe.squeeze(0) if kv_cache is not None: # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices) cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs) compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) q_absorb, out_absorb = self.get_absorbed() q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below q_nope = torch.matmul(q_nope, q_absorb) # batched MM q_nope = q_nope.transpose(0, 1) # q_nope.squeeze_(1) # q_pe.squeeze_(1) attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank) attn_output = attn_output.transpose(0, 1) attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim] attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output, num_tokens_tensors) return attn_output class KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def forward(self, hidden_states: torch.Tensor, kv_cache: KGQACache, position_ids: torch.Tensor, wrapper: flashInferAttn, bsz_tensors: torch.Tensor, page_idx: torch.Tensor, page_offset: torch.Tensor, ): q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states, bsz_tensors) key_states = self.k_proj(hidden_states, bsz_tensors) value_states = self.v_proj(hidden_states, bsz_tensors) query_states = query_states.view(q_len, self.num_heads, self.head_dim) key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0)) query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) query_states = query_states.view(q_len, self.num_heads, self.head_dim) key_states = key_states.view( q_len, self.num_key_value_heads, self.head_dim ) value_states = value_states.view( q_len, self.num_key_value_heads, self.head_dim ) k_cache = kv_cache.get_k_cache(self.layer_idx) v_cache = kv_cache.get_v_cache(self.layer_idx) attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors) return attn_output class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def forward(self, hidden_states: torch.Tensor, kv_cache: KGQACache, position_ids: torch.Tensor, wrapper: flashInferAttn, bsz_tensors: torch.Tensor, page_idx: torch.Tensor, page_offset: torch.Tensor, ): q_len, _ = hidden_states.size() bsz_tensors_q = bsz_tensors * self.num_heads bsz_tensors_kv = bsz_tensors * self.num_key_value_heads query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q) key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv) value_states = self.v_proj(hidden_states, bsz_tensors) query_states = query_states.view(q_len, self.num_heads, self.head_dim) key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0)) query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) query_states = query_states.view(q_len, self.num_heads, self.head_dim) key_states = key_states.view( q_len, self.num_key_value_heads, self.head_dim ) value_states = value_states.view( q_len, self.num_key_value_heads, self.head_dim ) k_cache = kv_cache.get_k_cache(self.layer_idx) v_cache = kv_cache.get_v_cache(self.layer_idx) attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors) return attn_output class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False, dtype=q_absorb.dtype, device=q_absorb.device) self.q_absorb.weight.data = q_absorb self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, bias=False, dtype=out_absorb.dtype, device=out_absorb.device) self.out_absorb.weight.data = out_absorb #del self.orig_module.kv_b_proj q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) return q_absorb, out_absorb def forward(self, hidden_states: torch.Tensor, kv_cache: KDeepSeekV3Cache, position_ids: torch.Tensor, wrapper: None, num_tokens_tensors: torch.Tensor, page_idx: torch.Tensor, page_offset: torch.Tensor, attention_masks: Optional[list[torch.Tensor]] = None, q_indptr: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, kv_indptr: Optional[torch.Tensor] = None, bsz_tensors: Optional[torch.Tensor] = None, last_page_len: Optional[torch.Tensor] = None, ): # range bsz_tensors final_attention_output = torch.tensor([], device=hidden_states.device) for i in range(bsz_tensors[0]): batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i] batch_last_page_len = last_page_len[i] # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]] batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]] # kv_page_nums is the number of pages for the current batch kv_page_nums = kv_indptr[i+1] - kv_indptr[i] # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm) kv_total_len = kv_page_nums * kv_cache.page_size if batch_last_page_len is not None: kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len) # print(f"kv_total_len's shape {kv_total_len.shape}") # kv_index is the index of the kv cache pages for the current batch kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]] # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch # from q_indptr[i] to q_indptr[i+1] is the range of the current batch batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]] batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]] q_len, _ = batch_hidden_states.size() # print("q_len -> ", q_len) if self.q_lora_rank is None: q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors), batch_num_tokens_tensors), batch_num_tokens_tensors) # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope)) q = q.view(q_len, self.num_heads, self.q_head_dim) # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)] # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)] q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)] compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors) # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)] compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) compressed_kv = compressed_kv.contiguous() compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors) # k_pe is [q_len, 1, qk_rope_head_dim(64)] k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) # compressed_kv is [q_len, 1, kv_lora_rank(512)] compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0)) # print(f"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}") q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) q_pe = q_pe.squeeze(0) # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)] q_pe.transpose_(0, 1) if kv_cache is not None: cache_kwargs = {"sin": sin, "cos": cos, "page_idx": batch_page_idx, "page_offset": batch_page_offset} # Specific to RoPE models compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, batch_page_idx, batch_page_offset, cache_kwargs) compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)] # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim q_absorb, out_absorb = self.get_absorbed() # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)] q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below # q_nope is [num_heads(128), q_len, kv_lora_rank(512)] q_nope = torch.matmul(q_nope, q_absorb) # batched MM # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)] # q_nope = q_nope.transpose(0, 1) # we need to index out the compressed_kv and k_pe for the current batch batch_compressed_kv = None batch_k_pe = None for page_index in kv_index: if kv_total_len > kv_cache.page_size: tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :] tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :] if batch_compressed_kv is None or batch_k_pe is None: batch_compressed_kv = tmp_compressed_kv batch_k_pe = tmp_k_pe else: batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0) batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0) kv_total_len -= kv_cache.page_size else: tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :] tmp_k_pe = k_pe[page_index, 0:kv_total_len, :] if batch_compressed_kv is None or batch_k_pe is None: batch_compressed_kv = tmp_compressed_kv batch_k_pe = tmp_k_pe else: batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0) batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0) break # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)] # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)] attention_weights = (torch.matmul(q_pe,batch_k_pe.mT) + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale # attention_weights is [num_heads(128), q_len, k_len] # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1) # attention_masks[i] is [q_len, k_len] attention_weights = (attention_weights + attention_masks[i]) # attention_weights shape is [num_heads(128), q_len, k_len] attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float32).to(q_pe.dtype) attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)] # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] out_absorb = out_absorb.transpose(1,2) # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), q_len, v_head_dim(128)] attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)] attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output, batch_num_tokens_tensors) final_attention_output = torch.cat((final_attention_output, attn_output), dim=0) return final_attention_output class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def forward(self, hidden_states: torch.Tensor, kv_cache: KGQACache, freqs_cis: torch.Tensor, wrapper: flashInferAttn, bsz_tensors: torch.Tensor, position_ids: torch.Tensor = None, ): if self.use_qk_norm: raise NotImplementedError("use_qk_norm is not implemented yet") q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states, bsz_tensors) key_states = self.k_proj(hidden_states, bsz_tensors) value_states = self.v_proj(hidden_states, bsz_tensors) query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim) key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) # cos, sin = freqs_cis """ print(query_states.shape) print(key_states.shape) print(cos.shape) print(sin.shape) """ if freqs_cis: cos, sin = freqs_cis query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim) key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) k_cache = kv_cache.get_k_cache(self.layer_idx) v_cache = kv_cache.get_v_cache(self.layer_idx) attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors) return attn_output class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. def apply_rotary_pos_emb( self, q: torch.Tensor, k: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], unsqueeze_dim=2 ) -> Tuple[torch.Tensor, torch.Tensor]: # Keep half or full tensor for later concatenation cos = freqs_cis[0] sin = freqs_cis[1] rotary_dim = cos.shape[-1] cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] # Apply rotary embeddings on the first half or full tensor q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed def forward(self, hidden_states: torch.Tensor, kv_cache: KGQACache, freqs_cis: torch.Tensor, wrapper: flashInferAttn, bsz_tensors: torch.Tensor, position_ids: torch.Tensor = None, ): q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states, bsz_tensors) key_states = self.k_proj(hidden_states, bsz_tensors) value_states = self.v_proj(hidden_states, bsz_tensors) if self.use_qk_norm: query_states = self.q_norm(query_states, bsz_tensors) key_states = self.k_norm(key_states, bsz_tensors) # cos, sin = freqs_cis """ print(query_states.shape) print(key_states.shape) print(cos.shape) print(sin.shape) """ query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim) key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim) if freqs_cis is not None: query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis) query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim) key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim) k_cache = kv_cache.get_k_cache(self.layer_idx) v_cache = kv_cache.get_v_cache(self.layer_idx) print(f"{k_cache.shape=}, {v_cache.shape=}, {query_states.shape=}, {key_states.shape=}, {value_states.shape=}") attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors) return attn_output from ktransformers.models.modeling_qwen3_next import apply_mask_to_padding_states import torch.nn.functional as F from ktransformers.models.modeling_qwen3_next import Qwen3NextAttention class KQwen3NextAttention(BaseInjectedModule, Qwen3NextAttention): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Removes the interleaving of cos and sin from GLM Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] # Apply rotary embeddings on the first half or full tensor q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) return q_embed, k_embed def forward(self, hidden_states: torch.Tensor, kv_cache: KGQACache, freqs_cis: torch.Tensor, wrapper: flashInferAttn, bsz_tensors: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ): q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states, bsz_tensors) query_states, gate = torch.chunk( self.q_proj(hidden_states).view(q_len, -1, self.head_dim * 2), 2, dim=-1 ) gate = gate.reshape(q_len, -1) key_states = self.k_proj(hidden_states, bsz_tensors) query_states = query_states.reshape(q_len, -1) query_states = self.q_norm(query_states, bsz_tensors) key_states = self.k_norm(key_states, bsz_tensors) value_states = self.v_proj(hidden_states, bsz_tensors) query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim) key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) if freqs_cis: cos, sin = freqs_cis query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) query_states, key_states = query_states.squeeze(0), key_states.squeeze(0) k_cache = kv_cache.get_k_cache(self.layer_idx) v_cache = kv_cache.get_v_cache(self.layer_idx) attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) attn_output = attn_output.reshape(q_len, -1).contiguous() attn_output = attn_output * torch.sigmoid(gate) attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors) return attn_output class KQwen3NextGatedDeltaNet(BaseInjectedModule, Qwen3NextGatedDeltaNet): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", chunck_size: int = 1000, **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.layer_idx) self.chunck_size = chunck_size # TODO, generate chunck_size automatically. def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads, 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads, ) new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads) mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) mixed_ba = mixed_ba.view(*new_tensor_shape_ba) split_arg_list_qkvz = [ self.head_k_dim, self.head_k_dim, (self.num_v_heads // self.num_k_heads * self.head_v_dim), (self.num_v_heads // self.num_k_heads * self.head_v_dim), ] split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads] query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim) b = b.reshape(b.size(0), b.size(1), self.num_v_heads) a = a.reshape(a.size(0), a.size(1), self.num_v_heads) return query, key, value, z, b, a def forward( self, hidden_states: torch.Tensor, conv_states: Optional[list[torch.Tensor]] = None, recurrent_states: Optional[list[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, bsz_tensors: Optional[torch.Tensor] = None, ): hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape conv_state = conv_states[self.layer_idx] if conv_states is not None else None recurrent_state = ( recurrent_states[self.layer_idx] if recurrent_states is not None else None ) use_precomputed_states = ( conv_state is not None and recurrent_state is not None and seq_len == 1 ) projected_states_qkvz = self.in_proj_qkvz(hidden_states, bsz_tensors) projected_states_ba = self.in_proj_ba(hidden_states, bsz_tensors) query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) if use_precomputed_states: # 2. Convolution sequence transformation # NOTE: the conv state is updated in `causal_conv1d_update` mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation, ) else: conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, seq_idx=None, ) else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, [ self.key_dim, self.key_dim, self.value_dim, ], dim=-1, ) query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) beta = b.sigmoid() # If the model is loaded in fp16, without the .float() here, A might be -inf g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) if self.num_v_heads // self.num_k_heads > 1: query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) if not use_precomputed_states: core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, initial_state=None, output_final_state=conv_state is not None, use_qk_l2norm_in_kernel=True, ) else: core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, initial_state=recurrent_state, output_final_state=conv_state is not None, use_qk_l2norm_in_kernel=True, ) # Update cache recurrent_state = last_recurrent_state z_shape_og = z.shape # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) output = self.out_proj(core_attn_out, bsz_tensors) if conv_state is not None: conv_states[self.layer_idx] = conv_state if recurrent_state is not None: recurrent_states[self.layer_idx] = recurrent_state return output ================================================ FILE: archive/ktransformers/operators/base_operator.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Any from torch import nn, Tensor from ktransformers.util.custom_loader import GGUFLoader from transformers.configuration_utils import PretrainedConfig import ktransformers.util.utils as utils class BaseInjectedModule(nn.Module): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): nn.Module.__init__(self) nn.Module.__setattr__(self, "orig_module", orig_module) object.__setattr__(self, "key", key) object.__setattr__(self, "gguf_loader", gguf_loader) object.__setattr__(self, "config", config) object.__setattr__(self, "prefill_device", prefill_device) object.__setattr__(self, "generate_device", generate_device) object.__setattr__(self, "device", generate_device) def __getattr__(self, name: str) -> Any: # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__, # but __setattr__ in nn.Module call super().__setattr__ in that case, there may be some attribute set # but can't get using __getattr__, typically these attr is build in attr of the class, so class.attr does not # call __getattr__. # Example: # ...import torch # ...l=torch.nn.Linear(100,200) # ...l.out_features # 200 # ...l.__getattr__("out_features") # AttributeError: 'Linear' object has no attribute 'out_features' try: return object.__getattribute__(self, name) # if this attr belongs to BaseInjectedModule except: if name == "orig_module": return nn.Module.__getattr__(self, "orig_module") try: return nn.Module.__getattr__(self, "orig_module").__getattr__(name) # if this attr belongs to orig_module except: return super(nn.Module, nn.Module.__getattr__(self, "orig_module")).__getattribute__(name) # if this attr belongs to orig_module but not in nn.Module.__dict__ def __setattr__(self, name: str, value: Tensor | nn.Module) -> None: if name == "orig_module": return nn.Module.__setattr__(self, "orig_module", value) elif hasattr(self, name): return object.__setattr__(self, name, value) return nn.Module.__getattr__(self, "orig_module").__setattr__(name, value) def forward(self, *args, **kwargs): return self.orig_module.forward(*args, **kwargs) def load(self): for name, child in self._modules.items(): utils.load_weights(child, self.gguf_loader, self.key+".") ================================================ FILE: archive/ktransformers/operators/cpuinfer.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring and managing key-value caches, updating and retrieving cache data, and handling attention operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration. These classes facilitate efficient caching and memory management for deep learning models that leverage key-value attention mechanisms, particularly on CPU-based systems. Author : djw Date : 2024-08-26 23:25:24 Version : 1.0.0 LastEditors : djw LastEditTime : 2024-08-26 23:25:24 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import sys, os from typing import Any import torch sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) import cpuinfer_ext from ktransformers.server.config.config import Config class CPUInferKVCache: def __init__( self, layer_num: int = 32, kv_head_num: int = 8, q_head_num: int = 32, head_dim: int = 128, block_len: int = 256, anchor_num: int = 4, anchor_type: str = "FIXED", kv_type: str = "Q4_0", retrieval_type: str = "SHARED", layer_step: int = 1, token_step: int = 1, layer_offset: int = 0, max_thread_num: int = 32, max_batch_size: int = 4, max_block_num: int = 512, ): if anchor_type == "FIXED": anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED elif anchor_type == "QUEST": anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST elif anchor_type == "DYNAMIC": anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC elif anchor_type == "BLOCK_MEAN": anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN elif anchor_type == "BLOCK_MAX": anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX else: raise ValueError(f"Unknown anchor type: {anchor_type}") if kv_type == "FP16": kv_type = cpuinfer_ext.kvcache.ggml_type.FP16 elif kv_type == "FP32": assert False, "FP32 is not supported yet." kv_type = cpuinfer_ext.kvcache.ggml_type.FP32 elif kv_type == "Q4_0": kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0 elif kv_type == "Q8_0": kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0 else: raise ValueError(f"Unknown kv type: {kv_type}") if retrieval_type == "SHARED": retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER elif retrieval_type == "INDIVIDUAL": retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD elif retrieval_type == "SEPARATE": retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD self.config = cpuinfer_ext.kvcache.KVCacheConfig( layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num, anchor_type, kv_type, retrieval_type, layer_step, token_step, layer_offset, max_block_num, max_batch_size, max_thread_num, ) self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config) def load_kvcache(self, tensor_file_path: str): if not os.path.exists(tensor_file_path): raise FileNotFoundError(f"The file {tensor_file_path} does not exist.") return self.kvcache.load_kvcache(tensor_file_path,) def dump_kvcache( self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str ): assert ( block_table.dim() == 1 and block_table.dtype == torch.int and block_table.is_contiguous() and block_table.device == torch.device("cpu") ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( block_table.dim(), block_table.size(), block_table.dtype, block_table.is_contiguous(), block_table.device, ) assert ( cache_total_len > 0 and cache_total_len <= self.config.block_len * block_table.size(0) ), "cache_total_len: {}".format(cache_total_len) if not os.path.exists(os.path.dirname(tensor_file_path)): os.makedirs(os.path.dirname(tensor_file_path)) return self.kvcache.dump_kvcache( block_table.data_ptr(), cache_total_len, tensor_file_path, ) def update_cache_total_len(self, cache_total_len: int): assert cache_total_len > 0, "cache_total_len: {}".format(cache_total_len) self.kvcache.update_cache_total_len(cache_total_len) # q_in: (bsz, q_len, q_head_num, head_dim) # output: (bsz, q_len, q_head_num, head_dim) # attn_lse: (bsz, q_len, q_head_num) # block_table: (bsz, max_block_num) def attn( self, q_in: torch.Tensor, output: torch.Tensor, attn_lse: torch.Tensor, layer_idx: int, generate_token_idx: int, block_table: torch.Tensor | None = None, cache_seqlens: torch.Tensor | None = None, pick_block_num: int | None = None, init_block_num: int | None = None, local_block_num: int | None = None, ): assert ( q_in.dim() == 4 and q_in.size(2) == self.config.q_head_num and q_in.size(3) == self.config.head_dim and q_in.dtype == torch.float16 and q_in.is_contiguous() and q_in.device == torch.device("cpu") ), "q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device ) batch_size = q_in.size(0) q_len = q_in.size(1) assert (block_table is None) or ( block_table.dim() == 2 and block_table.size(0) == batch_size and block_table.dtype == torch.int and block_table.is_contiguous() and block_table.device == torch.device("cpu") ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( block_table.dim(), block_table.size(), block_table.dtype, block_table.is_contiguous(), block_table.device, ) max_block_num = block_table.size(1) if block_table is not None else 0 assert ( output.dim() == 4 and output.size(0) == batch_size and output.size(2) == self.config.q_head_num and output.size(1) == q_len and output.size(3) == self.config.head_dim and output.dtype == torch.float16 and output.is_contiguous() and output.device == torch.device("cpu") ), "output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( output.dim(), output.size(), output.dtype, output.is_contiguous(), output.device, ) assert ( attn_lse.dim() == 3 and attn_lse.size(0) == batch_size and attn_lse.size(1) == q_len and attn_lse.size(2) == self.config.q_head_num and attn_lse.dtype == torch.float32 and attn_lse.is_contiguous() and attn_lse.device == torch.device("cpu") ), "attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( attn_lse.dim(), attn_lse.size(), attn_lse.dtype, attn_lse.is_contiguous(), attn_lse.device, ) assert ( layer_idx >= 0 and layer_idx < self.config.layer_num ), "layer_idx: {}".format(layer_idx) assert (cache_seqlens is None) or ( cache_seqlens.dim() == 1 and cache_seqlens.size(0) == batch_size and cache_seqlens.dtype == torch.int and cache_seqlens.is_contiguous() and cache_seqlens.device == torch.device("cpu") ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype, cache_seqlens.is_contiguous(), cache_seqlens.device, ) return self.kvcache.attn( q_in.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table.data_ptr() if block_table is not None else 0, cache_seqlens.data_ptr() if cache_seqlens is not None else 0, pick_block_num, init_block_num, local_block_num, ) # k_in: (block_len, kv_head_num, head_dim) # v_in: (block_len, kv_head_num, head_dim) def update_kvcache_one_block_fp16( self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int ): assert ( k_in.dim() == 3 and k_in.size(1) == self.config.block_len and k_in.size(0) == self.config.kv_head_num and k_in.size(2) == self.config.head_dim and k_in.dtype == torch.float16 and k_in.is_contiguous() and k_in.device == torch.device("cpu") ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device ) assert ( v_in.dim() == 3 and v_in.size(1) == self.config.block_len and v_in.size(0) == self.config.kv_head_num and v_in.size(2) == self.config.head_dim and v_in.dtype == torch.float16 and v_in.is_contiguous() and v_in.device == torch.device("cpu") ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.update_one_block_fp16( k_in.data_ptr(), v_in.data_ptr(), layer_id, block_idx, ) def get_kvcache_one_block_fp16( self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int ): assert ( k_in.dim() == 3 and k_in.size(1) == self.config.block_len and k_in.size(0) == self.config.kv_head_num and k_in.size(2) == self.config.head_dim and k_in.dtype == torch.float16 and k_in.is_contiguous() and k_in.device == torch.device("cpu") ), "k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device ) assert ( v_in.dim() == 3 and v_in.size(1) == self.config.block_len and v_in.size(0) == self.config.kv_head_num and v_in.size(2) == self.config.head_dim and v_in.dtype == torch.float16 and v_in.is_contiguous() and v_in.device == torch.device("cpu") ), "v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.get_one_block_fp16( k_in.data_ptr(), v_in.data_ptr(), layer_id, block_idx, ) def update_importance_one_block( self, importance: torch.Tensor, layer_id: int, block_idx: int ): assert ( importance.dim() == 1 and importance.size(0) == self.config.block_len and importance.dtype == torch.float16 and importance.is_contiguous() and importance.device == torch.device("cpu") ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( importance.dim(), importance.size(), importance.dtype, importance.is_contiguous(), importance.device, ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.update_importance_one_block( importance.data_ptr(), layer_id, block_idx, ) def get_importance_one_block( self, importance: torch.Tensor, layer_id: int, block_idx: int ): assert ( importance.dim() == 1 and importance.size(0) == self.config.block_len and importance.dtype == torch.float16 and importance.is_contiguous() and importance.device == torch.device("cpu") ), "importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( importance.dim(), importance.size(), importance.dtype, importance.is_contiguous(), importance.device, ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.get_importance_one_block( importance.data_ptr(), layer_id, block_idx, ) def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int): assert ( anchor.dim() == 3 and anchor.size(0) == self.config.kv_head_num and anchor.size(1) == self.config.anchor_num and anchor.size(2) == self.config.head_dim and anchor.dtype == torch.float16 and anchor.is_contiguous() and anchor.device == torch.device("cpu") ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( anchor.dim(), anchor.size(), anchor.dtype, anchor.is_contiguous(), anchor.device, ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.get_anchor_one_block( anchor.data_ptr(), layer_id, block_idx, ) def update_anchor_one_block( self, anchor: torch.Tensor, layer_id: int, block_idx: int ): assert ( anchor.dim() == 3 and anchor.size(0) == self.config.kv_head_num and anchor.size(1) == self.config.anchor_num and anchor.size(2) == self.config.head_dim and anchor.dtype == torch.float16 and anchor.is_contiguous() and anchor.device == torch.device("cpu") ), "anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( anchor.dim(), anchor.size(), anchor.dtype, anchor.is_contiguous(), anchor.device, ) assert ( layer_id >= 0 and layer_id < self.config.layer_num ), "layer_id: {}".format(layer_id) assert block_idx >= 0, "block_idx: {}".format(block_idx) return self.kvcache.update_anchor_one_block( anchor.data_ptr(), layer_id, block_idx, ) def calc_anchor_all_layers( self, block_table: torch.Tensor, cache_seqlens: torch.Tensor, ): assert ( block_table.dim() == 2 and block_table.size(0) == cache_seqlens.size(0) and block_table.dtype == torch.int and block_table.is_contiguous() and block_table.device == torch.device("cpu") ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( block_table.dim(), block_table.size(), block_table.dtype, block_table.is_contiguous(), block_table.device, ) assert ( cache_seqlens.dim() == 1 and cache_seqlens.dtype == torch.int and cache_seqlens.is_contiguous() and cache_seqlens.device == torch.device("cpu") ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype, cache_seqlens.is_contiguous(), cache_seqlens.device, ) batch_size = block_table.size(0) max_block_num = block_table.size(1) return self.kvcache.calc_anchor_all_layers( block_table.data_ptr(), cache_seqlens.data_ptr(), batch_size, max_block_num, ) def clear_importance_all_layers( self, block_table: torch.Tensor, cache_seqlens: torch.Tensor, ): assert ( block_table.dim() == 2 and block_table.size(0) == cache_seqlens.size(0) and block_table.dtype == torch.int and block_table.is_contiguous() and block_table.device == torch.device("cpu") ), "block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( block_table.dim(), block_table.size(), block_table.dtype, block_table.is_contiguous(), block_table.device, ) assert ( cache_seqlens.dim() == 1 and cache_seqlens.dtype == torch.int and cache_seqlens.is_contiguous() and cache_seqlens.device == torch.device("cpu") ), "cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}".format( cache_seqlens.dim(), cache_seqlens.size(), cache_seqlens.dtype, cache_seqlens.is_contiguous(), cache_seqlens.device, ) batch_size = block_table.size(0) max_block_num = block_table.size(1) return self.kvcache.clear_importance_all_layers( block_table.data_ptr(), cache_seqlens.data_ptr(), batch_size, max_block_num, ) def get_cache_total_len(self): return self.kvcache.get_cache_total_len() def update_kvcache_q4( self, k_in: torch.Tensor, k_scales: torch.Tensor, v_in: torch.Tensor, v_scales: torch.Tensor, layer_id: int, seq_offset: int | None = None, seq_len: int | None = None, block_table: torch.Tensor | None = None, ): raise NotImplementedError def update_kvcache_fp16( self, k_in: torch.Tensor, v_in: torch.Tensor, layer_idx, block_table: torch.Tensor, max_block_num, past_len: torch.Tensor, q_len, ): batch_size = block_table.size(0) return self.kvcache.get_kvcache_fp16( k_in.data_ptr(), v_in.data_ptr(), layer_idx, block_table.data_ptr(), batch_size, max_block_num, past_len.data_ptr(), q_len ) def get_kvcache_q4( self, k_in: torch.Tensor, k_scales: torch.Tensor, v_in: torch.Tensor, v_scales: torch.Tensor, layer_id: int, seq_offset: int | None = None, seq_len: int | None = None, block_table: torch.Tensor | None = None, ): raise NotImplementedError def get_kvcache_fp16( self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, layer_idx, block_table: torch.Tensor, max_block_num, past_len: torch.Tensor, ): batch_size = block_table.size(0) return self.kvcache.get_kvcache_fp16( k_in.data_ptr(), v_in.data_ptr(), layer_idx, block_table.data_ptr(), batch_size, max_block_num, past_len.data_ptr(), ) def get_and_update_kvcache_fp16( self, k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, layer_idx, block_table: torch.Tensor, max_block_num, past_len: torch.Tensor, q_len, ): batch_size = block_table.size(0) return self.kvcache.get_and_update_kvcache_fp16( k_cache_cpu.data_ptr(), v_cache_cpu.data_ptr(), layer_idx, block_table.data_ptr(), batch_size, max_block_num, past_len.data_ptr(), q_len, ) def update_importance( self, importance_cache: torch.Tensor, layer_idx, block_table: torch.Tensor, max_block_num, offset: torch.Tensor, width, ): batch_size = block_table.size(0) return self.kvcache.update_importance( importance_cache.data_ptr(), layer_idx, block_table.data_ptr(), batch_size, max_block_num, offset.data_ptr(), width, ) # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32) def get_attn_sparsity( self, q_in: torch.Tensor, attn_sparsity: torch.Tensor, layer_idx: int, block_table: torch.Tensor, cache_seqlens: torch.Tensor, block_table_origin: torch.Tensor, cache_seqlens_origin: torch.Tensor, generate_token_idx: int = 0, topk: int | None = None, local: int | None = None, ): batch_size = block_table.size(0) max_block_num = block_table.size(1) max_block_num_origin = block_table_origin.size(1) q_len = q_in.size(1) if topk is None or local is None or topk + local >= max_block_num: topk = -1 local = -1 return self.kvcache.get_attn_sparsity( q_in.data_ptr(), attn_sparsity.data_ptr(), layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), block_table_origin.data_ptr(), cache_seqlens_origin.data_ptr(), max_block_num_origin, topk, local, ) def attn_with_kvcache( self, q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, output: torch.Tensor, attn_lse: torch.Tensor, layer_idx: int, block_table: torch.Tensor, cache_seqlens: torch.Tensor, generate_token_idx: int = 0, topk: int | None = None, local: int | None = None, ): batch_size = block_table.size(0) max_block_num = block_table.size(1) q_len = q_in.size(1) if topk is None or local is None or topk + local >= max_block_num: topk = -1 local = -1 return self.kvcache.attn_with_kvcache( q_in.data_ptr(), k_in.data_ptr(), v_in.data_ptr(), output.data_ptr(), attn_lse.data_ptr(), layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table.data_ptr(), cache_seqlens.data_ptr(), topk, local, ) def get_all_kvcache_one_layer( self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int ): return self.kvcache.get_all_kvcache_one_layer( k_in.data_ptr(), v_in.data_ptr(), layer_id, ) def get_importance( self, importance: torch.Tensor, block_table: torch.Tensor, ): raise NotImplementedError def get_anchor( self, anchor: torch.Tensor, block_table: torch.Tensor, ): raise NotImplementedError class CPUInfer: cpuinfer = None cur_backend_thread_num = 0 def __init__(self, thread_num): if thread_num > CPUInfer.cur_backend_thread_num: CPUInfer.cur_backend_thread_num = thread_num del CPUInfer.cpuinfer CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num) def submit(self, task): CPUInfer.cpuinfer.submit(task) def submit_with_cuda_stream(self, current_cuda_stream, task): CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task) def sync(self): CPUInfer.cpuinfer.sync() def sync_with_cuda_stream(self, current_cuda_stream): CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream) ================================================ FILE: archive/ktransformers/operators/dynamic_attention.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : Jianwei Dong Date : 2024-08-26 23:25:24 Version : 1.0.0 LastEditors : Jianwei Dong LastEditTime : 2024-08-26 23:25:24 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import torch from transformers import AutoConfig import sys, os import logging logger = logging.getLogger("dynamic_attention") sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend") from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache try: from flash_attn import flash_attn_func, flash_attn_with_kvcache except: print("falsh attn not found") import math import json class DynamicScaledDotProductAttention: remaining_length: int cpu_infer = None def __init__( self, max_seq_len: int, block_size: int, config: AutoConfig, device: torch.device, local_windows_len: int, topk: int, threads_num: int, anchor_type: str = "DYNAMIC", kv_type: str = "FP16", dense_layer_num: int = 0, anchor_num: int = 1, block_selection_mode: str = "SHARED", layer_step: int = 1, token_step: int = 1, preselect_block: bool = False, preselect_block_count: int = 96, prefill_chunk_size: int = 20480, use_attn_sparsity: bool = False, ): # assert anchor_num == 1 # assert anchor_type == "DYNAMIC" self.remaining_length = 0 valid_anchor_types = ["DYNAMIC", "FIXED", "BLOCK_MEAN", "BLOCK_MAX", "QUEST"] assert anchor_type in valid_anchor_types if anchor_type == "QUEST": assert anchor_num == 2 elif anchor_type != "FIXED" and anchor_type != "DYNAMIC": assert anchor_num == 1 valid_kv_types = ["FP16", "FP32", "Q4_0", "Q8_0"] assert kv_type in valid_kv_types if kv_type != "FP16" and kv_type != "FP32": assert block_size % 32 == 0 valid_block_selection_modes = ["SHARED", "SEPARATE"] # individual assert block_selection_mode in valid_block_selection_modes self.max_seq_len = max_seq_len self.block_num = max_seq_len // block_size self.block_size = block_size self.anchor_type = anchor_type self.kv_type = kv_type self.anchor_num = anchor_num self.threads_num = threads_num self.layer_step = layer_step self.token_step = token_step self.preselect_block = preselect_block self.preselect_block_count = preselect_block_count self.block_selection_mode = block_selection_mode self.use_attn_sparsity = use_attn_sparsity # model config self.kv_head_num = config.num_key_value_heads self.q_head_num = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.layer_num = config.num_hidden_layers self.device = device self.local_windows_len = local_windows_len self.local_block_num = self.local_windows_len // self.block_size + 1 self.prefill_chunk_size = prefill_chunk_size self.topk = topk self.dense_layer_num = dense_layer_num # self.dense_layer_num = 32 self.cache_key_states = torch.zeros( (self.block_num, block_size, self.kv_head_num, self.head_dim), device=device, dtype=torch.float16, ) self.cache_value_states = torch.zeros( (self.block_num, block_size, self.kv_head_num, self.head_dim), device=device, dtype=torch.float16, ) # [max_num_block, block_size, head_num] self.cache_importance = torch.zeros( (self.block_num, block_size, self.q_head_num), device=device, dtype=torch.float16, ) # key_states: [bsz, q_len, kv_head_num, head_dim] # value_states: [bsz, q_len, kv_head_num, head_dim] # query_states: [bsz, q_len, q_head_num, head_dim] self.q_in_cpu = torch.zeros( (1, 1, self.q_head_num, self.head_dim), device="cpu", dtype=torch.float16, pin_memory=True, ) self.k_in_cpu = torch.zeros( (1, 1, self.kv_head_num, self.head_dim), device="cpu", dtype=torch.float16, pin_memory=True, ) self.v_in_cpu = torch.zeros( (1, 1, self.kv_head_num, self.head_dim), device="cpu", dtype=torch.float16, pin_memory=True, ) self.cache_seqlens_cpu = torch.empty( (1,), device="cpu", dtype=torch.int32, pin_memory=True ) self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32) self.prefix_block_table = torch.arange( self.block_num, device="cpu", dtype=torch.int32, pin_memory=True ).view(1, -1) self.block_table_cpu = torch.arange( self.block_num, device="cpu", dtype=torch.int32, pin_memory=True ).view(1, -1) # assert ( # self.local_windows_len // self.block_size + 1 + self.preselect_block_count # <= self.block_num # ) self.output_cpu = torch.empty( (1, 1, self.q_head_num, self.head_dim), device="cpu", dtype=torch.float16, pin_memory=True, ) self.lse_cpu = torch.empty( (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True ) self.output_cuda = torch.empty( (1, 1, self.q_head_num, self.head_dim), device=device, dtype=torch.float16 ) self.attn_sparsity = torch.zeros( (1, 1, self.q_head_num), device="cpu", dtype=torch.float32, pin_memory=True ) if preselect_block == True: self.preselect_block_table = torch.zeros( self.layer_num, self.preselect_block_count, device=device, dtype=torch.int32, ) self.preselect_block_num = 0 # block_num before preselect self.evict_tokens = 0 if DynamicScaledDotProductAttention.cpu_infer is None: DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num) self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer self.local_thread = CPUInferKVCache( self.layer_num, self.kv_head_num, self.q_head_num, self.head_dim, self.block_size, anchor_num=self.anchor_num, anchor_type=anchor_type, kv_type=self.kv_type, retrieval_type=self.block_selection_mode, layer_step=self.layer_step, token_step=self.token_step, layer_offset=self.dense_layer_num % self.layer_step, max_batch_size=1, max_block_num=self.block_num, max_thread_num=self.threads_num, ) print( f"local_windows_len: {local_windows_len}, topk: {topk}, dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, preselect_block_count: {self.preselect_block_count}, token_step: {self.token_step}, layer_step: {self.layer_step}" ) self.shape_mask = ( self.q_head_num, self.block_size, self.block_size, ) mask = torch.zeros( self.shape_mask, dtype=torch.uint8, device=device ).contiguous() elm_idx = torch.arange(self.block_size, device=device) for i in range(mask.size(-2)): idx = i + mask.size(-1) - mask.size(-2) - elm_idx idx = idx[idx >= 0] mask[..., i, idx] = 1 self.tril_mask = mask self.triu_mask = mask ^ 1 self.generate_token_idx = 0 def get_attn_score_one_block( self, batch_idx: int, max_block_num: int, query: torch.Tensor, key: torch.Tensor, offset: int, width: int, mask_mode: str | None = None, use_softmax: bool = True, ): n_rep = self.q_head_num // self.kv_head_num importance = self.cache_importance.view(-1, self.q_head_num) importance = importance.narrow(0, batch_idx * max_block_num + offset, width) n_gqa_ = self.q_head_num // self.kv_head_num for head_idx in range(self.q_head_num): key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1) qk = torch.einsum( "qd,kd->qk", query[:,head_idx,:], key_item ) # (num_attention_heads, len_q, len_k) if mask_mode == "tril": mask = self.tril_mask mask = mask[0, -qk.size(-2) :, -qk.size(-1) :] qk = qk * mask elif mask_mode == "triu": mask = self.triu_mask mask = mask[0, -qk.size(-2) :, -qk.size(-1) :] qk = qk * mask if use_softmax: qk = torch.nn.functional.softmax( qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32 ).to(torch.float16) qk = torch.sum(qk, dim=-2) importance[...,head_idx] += qk def get_preselect_block_table_and_attn_score( self, layer_idx: int, batch_size: int, offset: torch.Tensor, width: int, query: torch.Tensor, key: torch.Tensor, union_with_last_layer: bool = True, ): max_seqs_len = offset.max().item() + width max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size for batch_idx in range(batch_size): query_cur = query[batch_idx][-128:] self.get_attn_score_one_block( batch_idx, max_block_num, query_cur, key[batch_idx][: offset[batch_idx].item() + width], 0, offset[batch_idx].item() + width, mask_mode=None, ) if self.preselect_block: self.prefill_block_num = max( 0, max_block_num - self.local_windows_len // self.block_size ) self.evict_tokens = ( max(self.prefill_block_num - self.preselect_block_count, 0) * self.block_size ) if self.prefill_block_num != 0: importance_cache = self.cache_importance.narrow( 0, 0, self.prefill_block_num * batch_size ).view( batch_size, self.prefill_block_num, self.block_size, self.q_head_num ) importance_r = importance_cache[:, 1:, : self.block_size // 4] pad_r = torch.zeros_like(importance_r[:, :1]) importance_r = torch.cat((importance_r, pad_r), dim=1) importance_l = importance_cache[:, :-1, -self.block_size // 4 :] pad_l = torch.zeros_like(importance_l[:, :1]) importance_l = torch.cat((pad_l, importance_l), dim=1) importance = torch.cat( (importance_l, importance_cache, importance_r), dim=2 ) importance = importance.mean(dim=-1) importance = importance.mean(dim=-1) # importance: (batch_size, max_block_num) topk = min(self.preselect_block_count, self.prefill_block_num) values, indices = torch.topk( importance, k=topk, dim=1, ) self.preselect_block_table[ layer_idx : layer_idx + 1, :topk, ].copy_(indices) if union_with_last_layer and layer_idx == 31: for tmp_layer_idx in range(self.layer_num - 1): for i in range(1, min(topk, 6)): x = self.preselect_block_table[-1, i] if x not in self.preselect_block_table[tmp_layer_idx]: self.preselect_block_table[tmp_layer_idx, topk - i] = x if self.anchor_type == "DYNAMIC": importance_cache = self.cache_importance.narrow( 0, 0, max_block_num * batch_size ).view(batch_size, max_block_num * self.block_size, self.q_head_num) importance_cache_cpu = torch.empty_like( importance_cache, device="cpu", pin_memory=True ) importance_cache_cpu.copy_(importance_cache) block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu") offset_cpu = offset.contiguous().to("cpu") self.cpu_infer.submit( self.local_thread.update_importance( importance_cache_cpu, layer_idx, block_table_cpu, max_block_num, offset_cpu, width, ) ) self.cpu_infer.sync() importance_cache = self.cache_importance.narrow( 0, 0, max_block_num * batch_size ).view(batch_size, max_block_num * self.block_size, self.q_head_num) importance_cache.zero_() # key: [bsz, past_len, head_num, head_dim] float16 # query: [bsz, q_len, q_head_num, head_dim] float16 def get_attn_score( self, layer_idx: int, batch_size: int, offset: torch.Tensor, width: int, query: torch.Tensor, key: torch.Tensor, ): max_seqs_len = offset.max().item() + width max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size for batch_idx in range(batch_size): for idx in range(width // self.block_size): offset_cur = idx * self.block_size query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size] self.get_attn_score_one_block( batch_idx, max_block_num, query_cur, key[ batch_idx, offset[batch_idx] + offset_cur : offset[batch_idx] + offset_cur + self.block_size, ], offset[batch_idx].item() + offset_cur, self.block_size, mask_mode="tril", use_softmax=False, ) offset_key = ( offset[batch_idx].item() + idx * self.block_size - self.local_windows_len ) if offset_key >= 0: self.get_attn_score_one_block( batch_idx, max_block_num, query_cur, key[batch_idx, offset_key : offset_key + self.block_size], offset_key, self.block_size, mask_mode="triu", use_softmax=False, ) offset_key = max(0, offset_key + self.block_size) width_key = ( offset[batch_idx].item() + idx * self.block_size - offset_key ) if width_key > 0: self.get_attn_score_one_block( batch_idx, max_block_num, query_cur, key[batch_idx, offset_key : offset_key + width_key], offset_key, width_key, mask_mode=None, use_softmax=False, ) importance_cache = self.cache_importance.narrow( 0, 0, max_block_num * batch_size ).view(batch_size, max_block_num * self.block_size, self.q_head_num) importance_cache_cpu = torch.empty_like( importance_cache, device="cpu", pin_memory=True ) importance_cache_cpu.copy_(importance_cache) block_table_cpu = self.prefix_block_table[:, :max_block_num].to("cpu") offset_cpu = offset.contiguous().to("cpu") self.cpu_infer.submit( self.local_thread.update_importance( importance_cache_cpu, layer_idx, block_table_cpu, max_block_num, offset_cpu, width, ) ) self.cpu_infer.sync() importance_cache.zero_() # key: [bsz, q_len, head_num, head_dim] float16 # value: [bsz, q_len, head_num, head_dim] float16 def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value): batch_size = 1 max_seqs_len = past_len.max().item() + q_len max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size k_cache = self.cache_key_states.narrow(0, 0, max_block_num * batch_size).view( batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim ) v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view( batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim ) for batch_idx in range(batch_size): offset = past_len[batch_idx] width = q_len k_cache[batch_idx][offset : offset + width].copy_( key[batch_idx].view(-1, self.kv_head_num, self.head_dim) ) v_cache[batch_idx][offset : offset + width].copy_( value[batch_idx].view(-1, self.kv_head_num, self.head_dim) ) k_cache_cpu = torch.empty_like(k_cache, device="cpu", pin_memory=True) v_cache_cpu = torch.empty_like(v_cache, device="cpu", pin_memory=True) k_cache_cpu.copy_(k_cache) v_cache_cpu.copy_(v_cache) cur_block_num = ( q_len + past_len[0].item() + self.block_size - 1 ) // self.block_size block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu") past_len_cpu = past_len.contiguous().to("cpu") self.cpu_infer.submit( self.local_thread.get_and_update_kvcache_fp16( k_cache_cpu, v_cache_cpu, layer_idx, block_table_cpu, max_block_num, past_len_cpu, q_len, ) ) self.cpu_infer.sync() k_cache.copy_(k_cache_cpu) v_cache.copy_(v_cache_cpu) return k_cache, v_cache def calc_anchor(self, cache_seqlens: int): cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu") cache_seqlens_cpu = torch.tensor( [cache_seqlens], device="cpu", dtype=torch.int32 ) self.cpu_infer.submit( self.local_thread.calc_anchor_all_layers( block_table_cpu, cache_seqlens_cpu, ) ) self.cpu_infer.sync() def clear_importance(self, cache_seqlens: int): print(f"clear importance: {cache_seqlens}") cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu") cache_seqlens_cpu = torch.tensor( [cache_seqlens], device="cpu", dtype=torch.int32 ) self.cpu_infer.submit( self.local_thread.clear_importance_all_layers( block_table_cpu, cache_seqlens_cpu, ) ) self.cpu_infer.sync() def clear_kvcache(self, cache_seqlens: int): cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size block_table_cpu = self.prefix_block_table[:, :cur_block_num].to("cpu") cache_seqlens_cpu = torch.tensor( [cache_seqlens], device="cpu", dtype=torch.int32 ) self.cpu_infer.submit( self.local_thread.clear_kvcache_all_layers( block_table_cpu, cache_seqlens_cpu, ) ) self.cpu_infer.sync() def get_attn_sparsity( self, q_in: torch.Tensor, layer_idx: int, block_table: torch.Tensor, cache_seqlens: torch.Tensor, block_table_origin: torch.Tensor, cache_seqlens_origin: torch.Tensor, generate_token_idx: int = 0, topk: int | None = None, local: int | None = None, output_path: str = "./attn_sparsity.json", ): self.attn_sparsity.zero_() self.pcinfer.submit( self.local_thread.get_attn_sparsity( q_in, self.attn_sparsity, layer_idx, block_table, cache_seqlens, block_table_origin, cache_seqlens_origin, generate_token_idx, topk, local, ) ) self.cpu_infer.sync() with open(output_path, "a") as file: for head_idx in range(self.q_head_num): sparsity = self.attn_sparsity[0][0][head_idx].item() json_obj = { "token_idx": generate_token_idx, "layer_idx": layer_idx, "head_idx": head_idx, "sparsity": sparsity, } json.dump(json_obj, file) file.write("\n") def apply( self, layer_idx: int, bsz: int, past_len: int, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, mode: str = "prefill", generate_token_idx: int = -1, ): # key_states: [bsz, q_len, kv_head_num, head_dim] # value_states: [bsz, q_len, kv_head_num, head_dim] # query_states: [bsz, q_len, q_head_num, head_dim] assert query_states.dtype == torch.float16 assert key_states.dtype == torch.float16 assert value_states.dtype == torch.float16 assert key_states.size(2) == self.kv_head_num assert value_states.size(2) == self.kv_head_num assert query_states.size(2) == self.q_head_num q_len = query_states.size(1) batch_size = query_states.size(0) self.cache_seqlens_cuda.fill_(past_len) last_chunk = False if self.remaining_length <= self.prefill_chunk_size and q_len != 1: last_chunk = True device = query_states.device if layer_idx == 0: if q_len == 1: self.generate_token_idx += 1 elif last_chunk: self.generate_token_idx = -1 if mode == "prefill": key, value = self.swap_in_and_swap_out( layer_idx, self.cache_seqlens_cuda, q_len, key_states, value_states, ) if last_chunk and (self.anchor_type == "DYNAMIC" or self.preselect_block): self.get_preselect_block_table_and_attn_score( layer_idx, bsz, self.cache_seqlens_cuda, q_len, query_states, key, ) output = flash_attn_with_kvcache( q=query_states, k_cache=key, v_cache=value, cache_seqlens=self.cache_seqlens_cuda + q_len, causal=True, ) return output.transpose(1, 2) elif mode == "generate": assert self.generate_token_idx >= 0 self.q_in_cpu.copy_(query_states, non_blocking=True) self.k_in_cpu.copy_(key_states, non_blocking=True) self.v_in_cpu.copy_(value_states, non_blocking=True) self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True) # print(layer_idx) if layer_idx < self.dense_layer_num: self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True) self.cpu_infer.submit_with_cuda_stream( torch.cuda.current_stream("cuda").cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu, output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx, block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu, ), ) else: if self.preselect_block: self.cache_seqlens_cpu.copy_( self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True ) if self.preselect_block_count < self.prefill_block_num: self.block_table_cpu[:, : self.preselect_block_count].copy_( self.preselect_block_table[layer_idx : layer_idx + 1], non_blocking=True, ) self.block_table_cpu[ :, self.preselect_block_count : self.preselect_block_count + self.local_block_num, ].copy_( self.prefix_block_table[ :, self.prefill_block_num : self.prefill_block_num + self.local_block_num, ], non_blocking=True, ) # print("submit_with_cuda_stream") self.cpu_infer.submit_with_cuda_stream( torch.cuda.current_stream("cuda").cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu, output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx, generate_token_idx=self.generate_token_idx, block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu, topk=( self.topk if self.topk <= self.preselect_block_count else None ), local=self.local_windows_len // self.block_size, ), ) # print("submit_with_cuda_stream enqueue\n") else: self.block_table_cpu.copy_( self.prefix_block_table, non_blocking=True ) self.cpu_infer.submit_with_cuda_stream( torch.cuda.current_stream("cuda").cuda_stream, self.local_thread.attn_with_kvcache( q_in=self.q_in_cpu, k_in=self.k_in_cpu, v_in=self.v_in_cpu, output=self.output_cpu, attn_lse=self.lse_cpu, layer_idx=layer_idx, generate_token_idx=self.generate_token_idx, block_table=self.block_table_cpu, cache_seqlens=self.cache_seqlens_cpu, topk=self.topk, local=self.local_windows_len // self.block_size, ), ) self.cpu_infer.sync_with_cuda_stream( torch.cuda.current_stream("cuda").cuda_stream ) # print("submit_with_cuda_stream finished\n") self.output_cuda.copy_(self.output_cpu, non_blocking=True) return self.output_cuda.transpose(1, 2) def save(self, path: str, length: int): cur_block_num = (length + self.block_size - 1) // self.block_size block_table_cpu = self.prefix_block_table[0, :cur_block_num].to("cpu") cache_seqlens_cpu = torch.tensor([length], device="cpu", dtype=torch.int32) self.cpu_infer.submit( self.local_thread.dump_kvcache( block_table_cpu, cache_seqlens_cpu, path, ) ) self.cpu_infer.sync() def load(self, path: str, length: int): self.cpu_infer.submit( self.local_thread.load_kvcache( path, ) ) self.cpu_infer.sync() ================================================ FILE: archive/ktransformers/operators/experts.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-25 11:25:24 Version : 0.1.0 LastEditors : Azure LastEditTime : 2024-08-29 09:41:10 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Any, Union import numpy as np import numpy.typing as npt from torch import Tensor, nn import torch.nn.functional as F import torch import sys, os from ktransformers.operators.base_operator import BaseInjectedModule from tqdm import tqdm sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) import cpuinfer_ext from cpuinfer_ext.moe import MOEConfig, MOE import ctypes from ktransformers.util.custom_gguf import GGMLQuantizationType, translate_name_to_gguf from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader, ModelLoader from ktransformers.util.utils import InferenceState from ktransformers.server.config.config import Config from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod from ktransformers.operators.linear import KLinearMarlin, KLinearTorch, KTransformersLinear import time from ktransformers.operators.cpuinfer import CPUInfer try: import torch_npu from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False def deduplicate_and_sort(lst): return sorted(set(lst)) def generate_cuda_graphs(chunk_size: int) -> list: assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024" base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size] if chunk_size <= 1024: return deduplicate_and_sort(base_list) multiples = [i for i in range(1024, chunk_size + 1, 1024)] return deduplicate_and_sort(base_list + multiples) #cuda_graphs = [Config().chunk_size] if torch.cuda.is_available(): cuda_graphs = generate_cuda_graphs(Config().chunk_size) elif use_torch_npu: cuda_graphs = deduplicate_and_sort([1, 2, 3, 4]) else: cuda_graphs = 1 # class Base(BaseInjectedModule, ABC): class KExpertsBase(ABC): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.key = key self.gguf_loader = gguf_loader self.config = config self.device = device @abstractmethod def forward(self, input_tensor, expert_ids, weights): pass @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): pass @abstractmethod def unload(): pass def load_weights(self, override_key: str | None = None, device: str = "cpu"): res = {} if override_key is not None: keys = override_key else: keys = [self.key] gate = None up = None down = None gate_type = None up_type = None down_type = None for key in keys: if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): targets = [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight" ] tensors = self.load_multi(key, targets, device=device) gate = tensors[".ffn_gate_exps.weight"] up = tensors[".ffn_up_exps.weight"] down = tensors[".ffn_down_exps.weight"] gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] elif self.gguf_loader.has_tensor(key + ".ffn_down.0.weight"): # for supporting Mixtral-8x7B-Instuct gate = [] up = [] down = [] for i in range(8): gatei, upi, downi = f".ffn_gate.{i}.weight", f".ffn_up.{i}.weight", f".ffn_down.{i}.weight" targets = [gatei, upi, downi] tensors = self.load_multi(key, targets, device=device) gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi] gate.append(gate_it) up.append(up_it) down.append(down_it) gate = torch.stack(gate) up = torch.stack(up) down = torch.stack(down) gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate.0.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up.0.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down.0.weight"]["ggml_type"] else: raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res def load_multi(self, key: str, keys: list[str], device: str = "cpu"): tensors = {} for k in keys: tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device) return tensors class KExpertsCPU(KExpertsBase): input_tensor_cpu:Tensor = None expert_ids_cpu:Tensor = None weights_cpu:Tensor = None output_cpu:Tensor = None output_gpu_map:dict = {} # Manage output tensor buffer on different gpu #stream_map:dict = {} # Manage cuda stream on different gpu # @TODO add yaml CPU_INFER = CPUInfer(Config().cpu_infer) def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, n_routed_experts: int, orig_module: nn.Module = None, device: str = "cpu", out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU" self.n_routed_experts = n_routed_experts self.out_device = out_device self.backend = kwargs.get("backend", "llamafile") def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False): if use_torch_npu and get_tensor_parallel_size() != 1 and ( not torch.distributed.is_initialized() or torch.distributed.get_rank() != 0): return if device: assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU, Parameter \"device\" can be cpu or None." if w is None: w = self.load_weights()[self.key] self.gate = w["gate"] self.up = w["up"] self.down = w["down"] self.gate_type = w["gate_type"] self.up_type = w["up_type"] self.down_type = w["down_type"] gate_ptr = ctypes.addressof( ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) up_ptr = ctypes.addressof( ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) down_ptr = ctypes.addressof( ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) # print(self.gate_qtype, self.up_qtype, self.down_qtype) n_routed_experts = self.n_routed_experts self.cpu_infer = KExpertsCPU.CPU_INFER # n_routed_experts = len(self.orig_module) model_dtype = torch.get_default_dtype() if torch.xpu.is_available() and model_dtype == torch.float16: hidden_type = 1 # fp16 else: hidden_type = 30 # bf16 if self.backend == "llamafile": moe_config = MOEConfig( n_routed_experts, self.config.num_experts_per_tok, self.config.hidden_size, self.config.moe_intermediate_size, 64, 10, 1024, self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, self.gate_type, self.up_type, self.down_type, hidden_type, # TODO: get from model.dtype ) self.moe = MOE(moe_config) elif self.backend == "AMXBF16": from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE assert self.gate_type == GGMLQuantizationType.BF16 assert self.up_type == GGMLQuantizationType.BF16 assert self.down_type == GGMLQuantizationType.BF16 moe_config = AMX_MOEConfig( n_routed_experts, self.config.num_experts_per_tok, self.config.hidden_size, self.config.moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, ) self.moe = AMXBF16_MOE(moe_config) self.cpu_infer.submit(self.moe.load_weights()) self.cpu_infer.sync() elif self.backend == "AMXInt8": from cpuinfer_ext.moe import AMX_MOEConfig, AMXInt8_MOE assert self.gate_type == GGMLQuantizationType.BF16 assert self.up_type == GGMLQuantizationType.BF16 assert self.down_type == GGMLQuantizationType.BF16 moe_config = AMX_MOEConfig( n_routed_experts, self.config.num_experts_per_tok, self.config.hidden_size, self.config.moe_intermediate_size, max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size, self.config.hidden_act == 'silu', gate_ptr, up_ptr, down_ptr, ) self.moe = AMXInt8_MOE(moe_config) self.cpu_infer.submit(self.moe.load_weights()) self.cpu_infer.sync() # print(n_routed_experts, hidden_size, moe_intermediate_size) num_experts_per_tok = self.config.num_experts_per_tok if warmup: self.cpu_infer.submit(self.moe.warm_up()) self.cpu_infer.sync() if self.out_device not in KExpertsCPU.output_gpu_map: if isinstance(cuda_graphs, list): KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))] else: KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device) if KExpertsCPU.input_tensor_cpu == None: if isinstance(cuda_graphs, list): if use_torch_npu: KExpertsCPU.input_tensor_cpu = [[torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)] for i in range(len(cuda_graphs))] KExpertsCPU.expert_ids_cpu = [[torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)] for i in range(len(cuda_graphs))] KExpertsCPU.weights_cpu = [[torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)] for i in range(len(cuda_graphs))] KExpertsCPU.output_cpu = [[torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)] for i in range(len(cuda_graphs))] KExpertsCPU.bsz_tensor_cpu = [[torch.tensor([cuda_graphs[i]], device="cpu", dtype=torch.int32, pin_memory=True)] for i in range(len(cuda_graphs))] else: KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))] KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))] KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))] KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))] KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))] else: KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True) KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) if torch.xpu.is_available(): KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=model_dtype) KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device="cpu", dtype=torch.int32, pin_memory=True) else: KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): if bsz_tensor is None: bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32) if cuda_graph_idx != -1: KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True) KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr())) else: KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) def sync_for_one_decode(self, cuda_graph_idx=0): if cuda_graph_idx != -1: self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx] else: self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0): # generate, capture and run cuda graph # print(expert_ids) if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1): bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32) if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): if cuda_graph_idx != -1: KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True) KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr())) self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx] else: KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device] elif input_tensor.size(0)==1 and torch.xpu.is_available(): KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True) # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True) self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr())) self.cpu_infer.sync() KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1) else: input_tensor = input_tensor.contiguous().cpu() expert_ids = expert_ids.contiguous().cpu() weights = weights.contiguous().to(torch.float32).cpu() bsz_tensor = bsz_tensor.contiguous().cpu() output = torch.empty_like(input_tensor).contiguous() self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr())) self.cpu_infer.sync() return output.to(device=object.__getattribute__(self, "out_device")) def unload(self): return def load_weights(self, override_key: str | None = None, device: str = "cpu"): # TODO: support Bias res = {} if override_key is not None: keys = override_key else: keys = [self.key] gate = None up = None down = None gate_type = None up_type = None down_type = None for key in keys: if isinstance(self.gguf_loader, SafeTensorLoader): res = self.gguf_loader.load_experts(key) return {key: res} elif self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") # gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] # up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] # down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate_exps.weight") up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up_exps.weight") down_type = self.gguf_loader.get_ggml_type(key + ".ffn_down_exps.weight") elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"] up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"] down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"] elif key + ".ffn_down.0.weight" in self.gguf_loader.tensor_info: # for supporting Mixtral-8x7B-Instuct gate = [] up = [] down = [] for i in range(8): gate_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_gate.{i}.weight") up_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_up.{i}.weight") down_it = self.gguf_loader.get_mmap_tensor(f"{key}.ffn_down.{i}.weight") gate.append(gate_it) up.append(up_it) down.append(down_it) gate = np.stack(gate) up = np.stack(up) down = np.stack(down) gate_type = self.gguf_loader.get_ggml_type(key + ".ffn_gate.0.weight") up_type = self.gguf_loader.get_ggml_type(key + ".ffn_up.0.weight") down_type = self.gguf_loader.get_ggml_type(key + ".ffn_down.0.weight") elif self.gguf_loader.safetensor_loader is not None: # for npu # using a temp ugly way to temprary load the tensor translate_key = translate_name_to_gguf(key) gate = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_gate_exps.weight").numpy() up = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_up_exps.weight").numpy() down = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_down_exps.weight").numpy() gate_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_gate_exps.ggml_type").item() up_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_up_exps.ggml_type").item() down_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_down_exps.ggml_type").item() else: raise ValueError(f"Experts {key} not found in gguf_loader") res = {key:{"gate": gate, "up": up, "down": down, "gate_type": gate_type, "up_type": up_type, "down_type": down_type}} return res class KExpertsMarlin(KExpertsBase): expert_num: int loaded_experts_idx: list[int] def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, n_routed_experts: int, orig_module: nn.Module = None, device: str = "cuda", **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.expert_num = n_routed_experts self.loaded_experts_idx = [] self.act_fn = ACT2FN[config.hidden_act] assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU" self.device = device self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size # create empty marlin experts according to the number of experts per token # up self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)] # gate self.gate_projs = [KLinearMarlin(key+ "." + "ffn_gate_exps", gguf_loader, config, device=device) for i in range(self.expert_num)] # down self.down_projs = [KLinearMarlin(key+ "." + "ffn_down_exps", gguf_loader, config, device=device) for i in range(self.expert_num)] def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False): if device is None: device = self.device assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU" if w is None: w = self.load_weights() load_by_experts = True if load_by_experts: if isinstance(w, dict): self.gate = w["gate"] self.up = (w["up"]) self.down = (w["down"]) for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"): up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device) gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device) down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device) self.up_projs[i].load(nn.Parameter(up_weights), device=device) self.gate_projs[i].load(nn.Parameter(gate_weights), device=device) self.down_projs[i].load(nn.Parameter(down_weights), device=device) self.loaded_experts_idx.append(i) else: if isinstance(w, dict): self.gate = w["gate"] self.up = (w["up"]) self.down = (w["down"]) for i in range(self.expert_num): self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device) self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device) self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device) self.loaded_experts_idx.append(i) return def unload(self): for i in self.loaded_experts_idx: self.up_projs[i].unload() self.gate_projs[i].unload() self.down_projs[i].unload() self.loaded_experts_idx = [] def load_weights(self, override_key: str | None = None): res = {} if override_key is not None: keys = override_key else: keys = [self.key] gate = None up = None down = None for key in keys: if self.gguf_loader.has_tensor(key + ".ffn_gate_exps.weight"): gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") res = {"gate": gate, "up": up, "down": down} return res def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: org_dtype = hidden_states_cpu.dtype org_device = hidden_states_cpu.device hidden_states_cpu = hidden_states_cpu.to(self.device) selected_experts_cpu = selected_experts_cpu.to(self.device) routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype) batch_sequence_length, hidden_dim = hidden_states_cpu.size() final_hidden_states = torch.zeros( (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.expert_num): if not expert_mask[expert_idx].any(): continue idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) G = self.gate_projs[expert_idx].forward(current_state) A = self.act_fn(G) U = self.up_projs[expert_idx].forward(current_state) H = A * U # Element-wise multiplication current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states) return final_hidden_states.to(dtype=org_dtype, device=org_device) # untested, CUDA OOM class KExpertsTorch(KExpertsBase): expert_num: int loaded_experts_idx: list[int] gate: torch.Tensor up: torch.Tensor down: torch.Tensor def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, n_routed_experts: int, orig_module: nn.Module = None, device: str = "cpu", **kwargs ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.expert_num = n_routed_experts # self.loaded_experts_idx = [] self.act_fn = ACT2FN[config.hidden_act] self.device = device self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size self.gate = [None for _ in range(self.expert_num)] self.up = [None for _ in range(self.expert_num)] self.down = [None for _ in range(self.expert_num)] self.dtype = torch.get_default_dtype() def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False): if device is None: device = self.device if w is None: w = self.load_weights() load_by_experts = True if load_by_experts: if isinstance(w, dict): for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"): up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device) gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device) down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device) self.up[i] = up_weights self.gate[i] = gate_weights self.down[i] = down_weights else: if isinstance(w, dict): for i in range(self.expert_num): self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype) self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype) self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype) self.up = torch.stack(self.up, dim=0) self.gate = torch.stack(self.gate, dim=0) self.down = torch.stack(self.down, dim=0) return def unload(self): if self.gate is not None: self.gate = None self.up = None self.down = None def load_weights(self, override_key: str | None = None): res = {} if override_key is not None: keys = override_key else: keys = [self.key] gate = None up = None down = None for key in keys: if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info: gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight") up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight") down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight") res = {"gate": gate, "up": up, "down": down} return res def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: org_device = hidden_states_cpu.device hidden_states_cpu = hidden_states_cpu.to(self.device) selected_experts_cpu = selected_experts_cpu.to(self.device) routing_weights_cpu = routing_weights_cpu.to(self.device) batch_sequence_length, hidden_dim = hidden_states_cpu.size() final_hidden_states = torch.zeros( (batch_sequence_length, hidden_dim), dtype=self.gate.dtype, device=hidden_states_cpu.device ) org_dtype = hidden_states_cpu.dtype hidden_states_cpu = hidden_states_cpu.to(self.gate.dtype) routing_weights_cpu = routing_weights_cpu.to(self.gate.dtype) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.expert_num): idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) G = current_state @ self.gate[expert_idx,...].T A = self.act_fn(G) U = current_state @ self.up[expert_idx,...].T H = A * U # Element-wise multiplication current_hidden_states = H @ self.down[expert_idx,...].T * routing_weights_cpu[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states) return final_hidden_states.to(dtype=org_dtype, device=org_device) EXPERTS_MAP = { "KExpertsCPU": KExpertsCPU, "KExpertsTorch": KExpertsTorch, "KExpertsMarlin": KExpertsMarlin, } class KTransformersExperts(BaseInjectedModule, KExpertsBase): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", prefill_device:str = "cuda", prefill_op: str | None = "KExpertsTorch", generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: self.generate_experts = None if prefill_op is not None: self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs) else: self.prefill_experts = None self.gpu_mlp_type = prefill_op self.cpu_mlp_type = generate_op self.mode = InferenceState.UNLOAD def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): # TODO support w as input if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: self.prefill_experts.unload() self.generate_experts.load(w, warmup=warmup) self.device = self.generate_experts.device self.mode = mode elif mode == InferenceState.PREFILL: self.generate_experts.unload() self.prefill_experts.load(w, warmup=warmup) self.device = self.prefill_experts.device self.mode = mode elif mode == InferenceState.UNLOAD: self.unload() self.mode = mode self.device = self.generate_experts.device else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") def unload(self): if self.generate_experts is not None: self.generate_experts.unload() if self.prefill_experts is not None: self.prefill_experts.unload() self.device = self.generate_experts.device def forward(self, input_tensor, expert_ids, weights): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" return self.generate_experts.forward(input_tensor, expert_ids, weights) elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights) else: raise ValueError("load or set_inference_mode before forward") def set_inference_mode(self, mode: InferenceState): if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE, warmup=False) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL, warmup=False) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE from ktransformers.models.modeling_qwen3_next import Qwen3NextSparseMoeBlock class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ orig_shape = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += shared_expert_output y.resize_(*orig_shape) return y, router_logits hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states.cpu() selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu() routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu() shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = ( F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output ) if isinstance(self.experts, KExpertsBase): y = ( self.moe_kexperts( hidden_states_expert, selected_experts_expert, routing_weights_expert ) .view(*orig_shape) .to(device=hidden_states.device) ) elif hidden_states_expert.size(0) > 10: y = self.moe_infer( hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape ).to(device=hidden_states.device) else: y = self.moe_infer_simple( hidden_states_expert, selected_experts_expert, routing_weights_expert ).to(device=hidden_states.device) y += shared_expert_output y.resize_(*orig_shape) return y, router_logits @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: ''' hidden_states_cpu: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] ''' outs = torch.zeros_like(hidden_states_cpu) for token_idx in range(selected_experts_cpu.size(0)): for expert_idx in range(selected_experts_cpu.size(1)): expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: batch_size, sequence_length, hidden_dim = orig_shape final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype)) return final_hidden_states class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE): def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += y_ y.resize_(*orig_shape) return y if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) if isinstance(self.experts, KExpertsBase): y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) if self.config.n_shared_experts is not None: y += y_ return y @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += y_ y.resize_(*orig_shape) return y if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) if isinstance(self.experts, KExpertsBase): y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) if self.config.n_shared_experts is not None: y += y_ return y @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ orig_shape = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y.resize_(*orig_shape) return y, router_logits hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states_expert.cpu() selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts_expert.cpu() routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights_expert.cpu() if isinstance(self.experts, KExpertsBase): y = ( self.moe_kexperts( hidden_states_expert, selected_experts_expert, routing_weights_expert ) .view(*orig_shape) .to(device=hidden_states.device) ) elif hidden_states_expert.size(0) > 10: y = self.moe_infer( hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape ).to(device=hidden_states.device) else: y = self.moe_infer_simple( hidden_states_expert, selected_experts_expert, routing_weights_expert ).to(device=hidden_states.device) y.resize_(*orig_shape) return y, router_logits @torch.no_grad() def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: outs = self.experts(x, topk_ids, topk_weight) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor: ''' hidden_states_cpu: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] ''' outs = torch.zeros_like(hidden_states_cpu) for token_idx in range(selected_experts_cpu.size(0)): for expert_idx in range(selected_experts_cpu.size(1)): expert = self.experts[selected_experts_cpu[token_idx, expert_idx]] outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx] return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor: batch_size, sequence_length, hidden_dim = orig_shape final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim) current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype)) return final_hidden_states class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE): def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) y += y_ y.resize_(*orig_shape) return y if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity, bsz_tensor).squeeze(0) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) if self.config.n_shared_experts is not None: y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", prefill_device:str = "cuda", prefill_op: str | None = "KExpertsTorch", generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if prefill_op == 'None': prefill_op = None if generate_op == 'None': generate_op = None if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: self.generate_experts = None if prefill_op is not None: self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs) else: self.prefill_experts = None self.gpu_mlp_type = prefill_op self.cpu_mlp_type = generate_op self.mode = InferenceState.UNLOAD def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): # TODO support w as input if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: self.prefill_experts.unload() self.generate_experts.load(w, warmup=warmup) self.device = self.generate_experts.device self.mode = mode elif mode == InferenceState.PREFILL: self.generate_experts.unload() self.prefill_experts.load(w, warmup=warmup) self.device = self.prefill_experts.device self.mode = mode elif mode == InferenceState.UNLOAD: self.unload() self.mode = mode self.device = self.generate_experts.device else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") def unload(self): if self.generate_experts is not None: self.generate_experts.unload() if self.prefill_experts is not None: self.prefill_experts.unload() self.device = self.generate_experts.device def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) else: raise ValueError("load or set_inference_mode before forward") def set_inference_mode(self, mode: InferenceState): if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE, warmup=False) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL, warmup=False) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", prefill_device:str = "cuda", prefill_op: str | None = "KExpertsTorch", generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: self.generate_experts = None if prefill_op is not None: self.prefill_experts = None self.gpu_mlp_type = prefill_op self.cpu_mlp_type = generate_op self.mode = InferenceState.UNLOAD def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): # TODO support w as input if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: # self.prefill_experts.unload() self.generate_experts.load(w, warmup=warmup) self.device = self.generate_experts.device self.mode = mode elif mode == InferenceState.PREFILL: self.generate_experts.unload() self.prefill_experts.load(w, warmup=warmup) self.device = self.prefill_experts.device self.mode = mode elif mode == InferenceState.UNLOAD: self.unload() self.mode = mode self.device = self.generate_experts.device else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") def unload(self): if self.generate_experts is not None: self.generate_experts.unload() if self.prefill_experts is not None: self.prefill_experts.unload() self.device = self.generate_experts.device def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) else: raise ValueError("load or set_inference_mode before forward") def set_inference_mode(self, mode: InferenceState): if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE, warmup=False) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL, warmup=False) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") class KGlm4Experts(BaseInjectedModule, KExpertsBase): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, # device: str = "cuda", prefill_device:str = "cuda", prefill_op: str | None = "KExpertsTorch", generate_device: str = "cpu", generate_op: str | None = "KExpertsCPU", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) if generate_op is not None: self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs) else: self.generate_experts = None if prefill_op is not None: self.prefill_experts = None self.gpu_mlp_type = prefill_op self.cpu_mlp_type = generate_op self.mode = InferenceState.UNLOAD def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True): # TODO support w as input if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: # self.prefill_experts.unload() self.generate_experts.load(w, warmup=warmup) self.device = self.generate_experts.device self.mode = mode elif mode == InferenceState.PREFILL: self.generate_experts.unload() self.prefill_experts.load(w, warmup=warmup) self.device = self.prefill_experts.device self.mode = mode elif mode == InferenceState.UNLOAD: self.unload() self.mode = mode self.device = self.generate_experts.device else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") def unload(self): if self.generate_experts is not None: self.generate_experts.unload() if self.prefill_experts is not None: self.prefill_experts.unload() self.device = self.generate_experts.device def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0): if self.mode == InferenceState.GENERATE: assert self.generate_experts is not None, "generate_experts is None" return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) elif self.mode == InferenceState.PREFILL: assert self.prefill_experts is not None, "prefill_experts is None" return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx) else: raise ValueError("load or set_inference_mode before forward") def set_inference_mode(self, mode: InferenceState): if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE, warmup=False) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL, warmup=False) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) router_logits = self.gate(hidden_states, bsz_tensor) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) y += y_ y.resize_(*orig_shape) return y y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = ( F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ ) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock): def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if bsz_tensor is None: router_logits = self.gate(hidden_states) else: router_logits = self.gate(hidden_states, bsz_tensor) if router_logits.device.type == "xpu": from ipex_llm.transformers.models.common import moe_softmax_topk selected_experts, routing_weights = moe_softmax_topk( router_logits.half(), self.top_k, self.norm_topk_prob ) else: routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) # y += y_ y.resize_(*orig_shape) return y # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = ( # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ # ) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) # y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock): def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if bsz_tensor is None: if self.enable_early_router: router_logits = self.primary_router(router_input) else: router_logits = self.primary_router(hidden_states) else: if self.enable_early_router: router_logits = self.primary_router(router_input, bsz_tensor) else: router_logits = self.primary_router(hidden_states, bsz_tensor) router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1) if router_logits.device.type == "xpu": # TODO: support self.moe_primary_router_apply_softmax False case from ipex_llm.transformers.models.common import moe_softmax_topk selected_experts, routing_weights = moe_softmax_topk( router_logits.half(), self.top_k, self.norm_topk_prob ) else: if self.moe_primary_router_apply_softmax: routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) else: routing_weights = F.sigmoid(router_logits) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) # y += y_ y.resize_(*orig_shape) return y # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) # y_ = ( # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ # ) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) # y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE): def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx) y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0) # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) y += y_ y.resize_(*orig_shape) return y y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0) # y_ = ( # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ # ) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, topk_idx, topk_weight) .view(*orig_shape) .to(device=hidden_states.device) ) y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out class KQwen3NextSparseMoeBlockV2(BaseInjectedModule, Qwen3NextSparseMoeBlock): def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0): orig_shape = hidden_states.shape sequence_length = orig_shape[1] hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if bsz_tensor is None: router_logits = self.gate(hidden_states) else: router_logits = self.gate(hidden_states, bsz_tensor) if router_logits.device.type == "xpu": from ipex_llm.transformers.models.common import moe_softmax_topk selected_experts, routing_weights = moe_softmax_topk( router_logits.half(), self.top_k, self.norm_topk_prob ) else: routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) if self.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) if self.norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # only for generate phase if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) y += y_ y.resize_(*orig_shape) return y y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) y_ = ( F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ ) if isinstance(self.experts, KExpertsBase): y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) elif hidden_states.size(0) > 10: # TODO may bugs here y = ( self.moe_infer(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) else: # TODO may bugs here y = ( self.moe_infer_simple(hidden_states, selected_experts, routing_weights) .view(*orig_shape) .to(device=hidden_states.device) ) y += y_ return y @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: outs = torch.empty_like(x) outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) return outs @torch.no_grad() # TODO may bugs here def moe_infer_simple( self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor ) -> torch.Tensor: """ x: [num_tokens, hidden_size] topk_ids, topk_weight: [num_tokens, num_selected_experts] """ outs = torch.zeros_like(x) for token_idx in range(topk_ids.size(0)): for expert_idx in range(topk_ids.size(1)): expert = self.experts[topk_ids[token_idx, expert_idx]] outs[token_idx] += ( expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] ) return outs @torch.no_grad() # TODO may bugs here def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert.forward(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out ================================================ FILE: archive/ktransformers/operators/flashinfer_batch_prefill_wrapper.py ================================================ import torch import flashinfer import gc try: from flash_attn import flash_attn_with_kvcache print("found flash_attn") except ImportError: print("flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.") from typing import Union, Optional def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) setup_seed(998244353) try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False if not use_torch_npu: torch.set_grad_enabled(False) torch.set_default_dtype(torch.bfloat16) global_dtype=torch.bfloat16 global_device=torch.device("cuda",0) torch.cuda.set_device(0) torch.backends.cudnn.enabled =True torch.backends.cudnn.benchmark = True class flashInferAttn(): float_workspace_buffer = None def __init__(self, max_batch_token, max_batch_size, max_pages, device = "cuda:0", kv_layout: str = "NHD", use_cuda_graph: bool = False, ) -> None: self.device = device self.max_batch_token = max_batch_token self.kv_layout = kv_layout self.use_cuda_graph = use_cuda_graph if flashInferAttn.float_workspace_buffer is None: flashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device) self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) self.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device) self.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device) self.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device) # TODO: custom mask self.custom_mask_buf = None self.qk_indptr_buf = None self.warpper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( flashInferAttn.float_workspace_buffer, self.kv_layout, use_cuda_graph=self.use_cuda_graph, qo_indptr_buf=self.qo_indptr_buf, paged_kv_indptr_buf=self.paged_kv_indptr_buf, paged_kv_indices_buf=self.paged_kv_indices_buf, paged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf, backend = "fa2", ) def plan(self, qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor, paged_kv_last_page_len: torch.Tensor, batch_size_tensor: torch.Tensor, num_tokens_tensor: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, page_size: int, causal: bool = True, pos_encoding_mode: str = "NONE", q_data_type: Union[str, torch.dtype] = torch.bfloat16, kv_data_type: Optional[Union[str, torch.dtype]] = None): self.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True) self.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True) self.page_size = page_size self.warpper.plan( qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, num_qo_heads, num_kv_heads, head_dim, page_size, causal = causal, pos_encoding_mode = pos_encoding_mode, q_data_type = q_data_type, kv_data_type = kv_data_type ) def calc_batch_indices(self, ragged_size = None): if self.use_cuda_graph: self.batch_indices, self.positions = flashinfer.get_batch_indices_positions( self.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token) else: self.batch_indices, self.positions = flashinfer.get_batch_indices_positions( self.warpper._qo_indptr_buf, flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size) def forward(self, q, k_cache, v_cache, k, v): if self.use_cuda_graph: flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf) return self.warpper.run(q, (k_cache, v_cache)) else: flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf) return self.warpper.run(q, (k_cache, v_cache)) def testCudaGraph(): # use max batch to create buffer batch_decode = 8 prefill_chunk = 48 past_kv_0 = 4090 past_kv_1 = 4096 raged_size = prefill_chunk + batch_decode num_key_value_heads = 8 head_dim = 128 num_attention_heads = 64 page_size = 256 num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size attn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True) batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32) k_caches = [] v_caches = [] ks = [] vs = [] qs = [] for layer_idx in range(3): k_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) v_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) ks.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) vs.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16)) # warmup and capture small batch past_kv_0 = 250 past_kv_1 = 256 num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device) q_indptr[0] = 0 q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32) kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32) kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device) kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1) kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1) print(q_indptr) print(kv_indptr) print(kv_indices) print(kv_last_page_len) attn.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_len, batch_size_tensor, num_attention_heads, num_key_value_heads, head_dim, page_size, causal = True, pos_encoding_mode="NONE", q_data_type=torch.bfloat16) attn.calc_batch_indices(raged_size) for layer_idx in range(3): attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]) torch.cuda.synchronize() outs = [] g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for layer_idx in range(3): outs.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])) g.replay() kv_last_page_len[:1+batch_decode//2] = int(past_kv_0) kv_last_page_len[1+batch_decode//2:] = int(past_kv_1) for layer_idx in range(3): for i in range(batch_decode + 1): qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]] o_ref_i = flash_attn_with_kvcache( qi.unsqueeze(0), k_caches[layer_idx], v_caches[layer_idx], causal=True, block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0), cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32) ) o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]] print(layer_idx, i) torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3) # run another batch size use capture cuda graph past_kv_0 = 4090 past_kv_1 = 4096 prefill_chunk = 24 batch_decode = 4 num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32) num_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32) q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device) q_indptr[0] = 0 q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32) kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32) kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device) kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1) kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1) attn.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_len, batch_size_tensor, num_attention_heads, num_key_value_heads, head_dim, page_size, causal = True, pos_encoding_mode="NONE", q_data_type=torch.bfloat16) attn.calc_batch_indices(raged_size) g.replay() kv_last_page_len[:1+batch_decode//2] = int(past_kv_0) kv_last_page_len[1+batch_decode//2:] = int(past_kv_1) for layer_idx in range(3): for i in range(batch_decode + 1): qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]] o_ref_i = flash_attn_with_kvcache( qi.unsqueeze(0), k_caches[layer_idx], v_caches[layer_idx], causal=True, block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0), cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32) ) o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]] print(layer_idx, i) torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3) def testAttentionFlashInfer( ): batch_decode = 32 prefill_chunk = 64 past_kv_0 = 510 past_kv_1 = 512 raged_size = prefill_chunk + batch_decode num_key_value_heads = 8 head_dim = 128 num_attention_heads = 64 cases = 1 page_size = 32 num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") qs = [] kvs = [] q_indptrs = [] kv_indptrs = [] kv_indicess = [] kv_last_page_lens = [] wrappers = [] for case_id in range(cases): kvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16)) q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device) q_indptr[0] = 0 q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32) q_indptrs.append(q_indptr) kv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq) kv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)) kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device) kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1) kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1) kv_last_page_lens.append(kv_last_page_len) wrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD", use_cuda_graph=True, qo_indptr_buf=q_indptrs[case_id], paged_kv_indptr_buf=kv_indptrs[case_id], paged_kv_indices_buf=kv_indicess[case_id], paged_kv_last_page_len_buf=kv_last_page_lens[case_id], )) wrappers[case_id].plan( q_indptrs[case_id], kv_indptrs[case_id], kv_indicess[case_id], kv_last_page_lens[case_id], num_attention_heads, num_key_value_heads, head_dim, page_size, causal = True, pos_encoding_mode="ROPE_LLAMA", q_data_type=torch.bfloat16 ) def custom_forward(case_id): out = wrappers[case_id].run(qs[case_id], kvs[case_id]) custom_forward(0) # testCudaGraph() # pass ================================================ FILE: archive/ktransformers/operators/flashinfer_wrapper.py ================================================ ''' Description : flashinfer MLA wrapper Author : Boxin Zhang Version : 0.2.3 ''' import torch import os flashinfer_enabled = False try: import flashinfer flashinfer_enabled = True print("found flashinfer") except ImportError: print("flashinfer not found, use triton for linux") try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False if not use_torch_npu: from ktransformers.operators.triton_attention import decode_attention_fwd_grouped import math def attention_ref_torch( batch_size, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float, ) -> torch.Tensor: qo_len = q.shape[0] // batch_size kv_len = k.shape[0] // batch_size num_qo_heads = q.shape[1] head_dim_qk = q.shape[2] head_dim_vo = v.shape[2] logits = ( torch.einsum( "bmhd,bnhd->bhmn", q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), ) * sm_scale ) #print("attn weights", logits) if causal: mask = ( torch.arange(kv_len - qo_len, kv_len).unsqueeze(1) >= torch.arange(0, kv_len).unsqueeze(0) ).to(q.device) else: mask = torch.ones(qo_len, kv_len).to(q.device) logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) p = torch.softmax(logits, dim=-1) o_ref = ( torch.einsum( "bhmn,bnhd->bmhd", p, v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), ) .contiguous() .view(batch_size * qo_len, num_qo_heads, head_dim_vo) .to(q) ) return o_ref, lse_ref * math.log2(math.e) class MLAWrapper(): def __init__(self, max_batch_size, max_pages, use_cuda_graph = True, device = "cuda", ): self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device) self.max_batch_size = max_batch_size self.max_pages = max_pages if use_cuda_graph: if self.max_batch_size == 1: self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device) self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device) self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device) else: self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device) self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device) self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device) self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device) else: self.qo_indptr_buf = None self.kv_indptr_buf = None self.kv_indices_buf = None self.kv_len_arr_buf = None self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.float_workspace_buffer, use_cuda_graph=use_cuda_graph, qo_indptr=self.qo_indptr_buf, kv_indptr=self.kv_indptr_buf, kv_indices=self.kv_indices_buf, kv_len_arr=self.kv_len_arr_buf, bsz_tensor=self.batch_size_tensor_buf, backend = "fa2", ) self.need_plan = True def plan(self, qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor, num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale, q_data_type, kv_data_type, ): if qo_indptr is None: assert self.max_batch_size == 1 qo_indptr = self.qo_indptr_buf if kv_indptr is None: assert self.max_batch_size == 1 kv_indptr = self.kv_indptr_buf if kv_indices is None: assert self.max_batch_size == 1 kv_indices = self.kv_indices_buf if bsz_tensor is None: assert self.max_batch_size == 1 bsz_tensor = self.batch_size_tensor_buf self.wrapper.plan( qo_indptr, kv_indptr, kv_indices, kv_len_arr, num_heads, head_dim_ckv, head_dim_kpe, page_size, True, # causal sm_scale, q_data_type, kv_data_type, bsz_tensor ) def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False): return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse) class MLAWrapperSingleton(): wrappers:dict = {} @classmethod def get_instance(cls, device, *args, **kwargs)->MLAWrapper: if device not in cls.wrappers: cls.make_instance(device, *args, **kwargs) return cls.wrappers[device] @classmethod def make_instance(cls, device, *args, **kwargs): cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device) @classmethod def plan_all(cls, qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor, num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale, q_data_type, kv_data_type,): for device, wrapper in cls.wrappers.items(): kv_len_arr_cur_device = kv_len_arr.to(device) wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_len_arr_cur_device, bsz_tensor, num_heads, head_dim_ckv, head_dim_kpe, page_size, sm_scale, q_data_type, kv_data_type,) wrapper.need_plan = False @classmethod def need_plan_all(cls): for device, wrapper in cls.wrappers.items(): wrapper.need_plan = True @classmethod def reset_buffer(cls): for device, wrapper in cls.wrappers.items(): wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here. @classmethod def update_buffer(cls, max_pages): for device, wrapper in cls.wrappers.items(): wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here. wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device) wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf def checksame(): flashinfer_folder = "./flashinfer_output" flashinfer_folder = "./kv_cache_flashinfer" triton_folder = "./triton_output" triton_folder = "./kv_cache_triton" max_layer_id = 1 max_forward_id = 2 for forward_id in range(0, 19): print("forward_id", forward_id) for layer_id in range(max_layer_id): print(layer_id) #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt" #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt" file_name = f"layer_{layer_id}.pt" flashinfer_path = os.path.join(flashinfer_folder, file_name) triton_path = os.path.join(triton_folder, file_name) if not os.path.exists(triton_path): print(f"{file_name} not exist in {triton_folder}") continue if not os.path.exists(flashinfer_path): print(f"{file_name} not exist in {flashinfer_folder}") continue flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]# triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)# try: torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9) except AssertionError as e: print(e) if __name__ == "__main__": #checksame() #exit(0) max_batch_size = 2 max_batch_tokens = 256 max_pages = 128 page_size = 64 num_heads = 128 # warm-up kv_len = 4023 q_len = 1 q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda") q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda") kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda") ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1) wrapper = MLAWrapperSingleton.get_instance( "cuda", max_batch_size, max_pages, ) used_pages = (kv_len + page_size - 1)// page_size kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda") qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda") kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda") kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda") bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda") wrapper.plan( qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor, 128, 512, 64, page_size, 192 ** (-0.5), torch.bfloat16, torch.bfloat16, ) attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe) print(attn_output.shape) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe) graph.replay() q = torch.cat([q_nope_buf, q_pe_buf], dim=-1) k = ( torch.cat([ckv, k_pe], dim=-1) .view(-1, 1, 512 + 64) .repeat_interleave(num_heads, dim=1) ) v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) attn_ref, lse_ref = attention_ref_torch( 1, q[:q_len], k[:kv_len], v[:kv_len], True, 192 ** (-0.5) ) torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3) # warm-up finished kv_len = 512 q_len = 128 pages = max_pages used_pages = (kv_len + page_size - 1)// page_size q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda") q_nope[q_len:] = q_nope[:q_len] q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda") q_pe[q_len:] = q_pe[:q_len] kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda") kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages] ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1) kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda") qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda") kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda") kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda") kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda") bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda") wrapper.plan( qo_indptr, kv_indptr, kv_indices, kv_len_arr, bsz_tensor, 128, 512, 64, page_size, 192 ** (-0.5), torch.bfloat16, torch.bfloat16, ) q_nope_buf.copy_(q_nope) q_pe_buf.copy_(q_pe) kv_buf[:pages].copy_(kv_cache) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() # ref_torch q = torch.cat([q_nope, q_pe], dim=-1) k = ( torch.cat([ckv, k_pe], dim=-1) .view(-1, 1, 512 + 64) .repeat_interleave(num_heads, dim=1) ) v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) attn_ref, lse_ref = attention_ref_torch( max_batch_size, q, k[:2*kv_len], v[:2*kv_len], True, 192 ** (-0.5) ) torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9) torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9) torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3) torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3) #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9) #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3) exit(0) for forward_id in range(0, 1): print("forward_id", forward_id) for layer_id in range(1): print(layer_id) flashinfer_folder = "./kv_cache_flashinfer" forward_id = 17 layer_id = 0 file_name = f"layer_{layer_id}.pt" kv_cache_path = os.path.join(flashinfer_folder, file_name) flashinfer_folder = "./flashinfer_output" q_len = 1 kv_len = 126 file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt" q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device="cuda") file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt" q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device="cuda") q = torch.cat([q_nope, q_pe], dim=-1) kv_cache = torch.load(kv_cache_path).to(device="cuda") pages, page_size, _, head_dim = kv_cache.shape kv_cache = kv_cache.view(pages, page_size, head_dim) ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1) kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda") qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda") wrapper.plan( None, None, None, kv_len_arr, 128, 512, 64, page_size, 192 ** (-0.5), torch.bfloat16, torch.bfloat16, ) q_nope_buf.copy_(q_nope) q_pe_buf.copy_(q_pe) kv_buf[:pages].copy_(kv_cache) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() # ref_torch k = ( torch.cat([ckv, k_pe], dim=-1) .view(-1, 1, 512 + 64) .repeat_interleave(num_heads, dim=1) ) v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1) attn_ref, lse_ref = attention_ref_torch( max_batch_size, q, k[:kv_len], v[:kv_len], False, 192 ** (-0.5) ) torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3) # ref_triton attn_logits = torch.empty( ( max_batch_size, num_heads, 4, #num_kv_splits # follow vLLM, fix it TODO 512 + 1, ), dtype=torch.float32, device = "cuda" ) triton_ref = torch.zeros_like(q_nope) page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda") ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576) ckv = ckv.view(pages, page_size, 1, 512) decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref, page_table, kv_len_arr, attn_logits, 4, #num_kv_splits # follow vLLM, fix it TODO 192 ** (-0.5), page_size) torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3) #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt" #ktrans_output = torch.load(file_name) #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3) print("test past") ================================================ FILE: archive/ktransformers/operators/gate.py ================================================ from typing import Optional from torch import nn import torch import torch.nn.functional as F import os from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.linear import KTransformersLinear from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader, translate_name_to_gguf from transformers.configuration_utils import PretrainedConfig from abc import ABC, abstractmethod # class Base(BaseInjectedModule, ABC): class KMoEGateBase(ABC): def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) super().__init__() self.key = key self.gguf_loader = gguf_loader self.config = config self.device = device self.orig_module = orig_module @abstractmethod def forward(self, input_tensor, expert_ids, weights): pass @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False): pass @abstractmethod def unload(): pass def load_weights(self, override_key: str | None = None, device: str = "cpu"): res = {} if override_key is not None: keys = override_key else: keys = [self.key] gate = None up = None down = None gate_type = None up_type = None down_type = None for key in keys: if self.gguf_loader.safetensor_loader is not None: # for npu translate_key = translate_name_to_gguf(key) translate_key = ".".join(translate_key.split(".")[:2]) targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] weight = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".ffn_gate_inp.weight") e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(translate_key + ".exp_probs_b.bias") weight_type = weight.dtype e_score_correction_bias_type = e_score_correction_bias.dtype res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type} # key = ".".join(key.split(".")[:-1]) elif isinstance(self.gguf_loader, SafeTensorLoader): res = self.gguf_loader.load_gate(key, device=device) elif self.gguf_loader.has_tensor(key+".weight"): # targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"] targets = [".weight", ".e_score_correction_bias"] tensors = self.load_multi(key, targets, device=device) weight = tensors[".weight"] e_score_correction_bias = tensors[".e_score_correction_bias"] # weight_type = self.gguf_loader.tensor_info[key + ".weight"]["ggml_type"] res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias} else: raise ValueError(f"Experts {key} not found in gguf_loader") return res def load_multi(self, key: str, keys: list[str], device: str = "cpu"): tensors = {} for k in keys: tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device) return tensors class KMoEGate(BaseInjectedModule, KMoEGateBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, generate_device: str = "cuda", prefill_device: str = "cuda", **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.generate_device = generate_device self.prefill_device = prefill_device def forward(self, hidden_states) -> torch.Tensor: return self.orig_module.forward(hidden_states) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weights(device=device) if isinstance(w, dict): self.orig_module.weight = nn.Parameter(w["weight"]) self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: raise ValueError("Invalid weight type") self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device)) self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device)) def unload(self): if self.weight is not None: self.weight = None if self.e_score_correction_bias is not None: self.e_score_correction_bias = None class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, generate_device: str = "cuda", generate_op: str| None = "KLinearMarlin", prefill_device: str = "cuda", prefill_op: str| None = "KLinearMarlin", use_quant: bool = False, **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.generate_device = generate_device self.prefill_device = prefill_device self.generate_op = generate_op self.prefill_op = prefill_op self.is_windows = os.name == 'nt' self.use_quant = use_quant if not self.is_windows and use_quant: self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device) self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp", gguf_loader, config, self.gate_linear, #orig_module generate_device, generate_op, prefill_device, prefill_op) else: self.gate_linear = None def forward(self, hidden_states) -> torch.Tensor: if self.is_windows: return self.orig_module.forward(hidden_states) bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) if self.use_quant: logits = self.gate_linear.forward(logits) else: logits = F.linear( hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group, self.topk_group) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weights(device=device) if isinstance(w, dict): self.orig_module.weight = nn.Parameter(w["weight"]) self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"]) else: raise ValueError("Invalid weight type") self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device)) self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device)) if not self.is_windows and self.use_quant: self.gate_linear.load(self.orig_module.weight) def unload(self): if self.weight is not None: self.weight = None if self.e_score_correction_bias is not None: self.e_score_correction_bias = None class KMoEGateIPEXLLM(KMoEGate): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, generate_device: str = "xpu", prefill_device: str = "xpu", **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) self.generate_device = generate_device self.prefill_device = prefill_device def forward(self, hidden_states) -> torch.Tensor: x = hidden_states.view(-1, hidden_states.size(-1)) logits = torch.nn.functional.linear( x.type(torch.float32), self.orig_module.weight.type(torch.float32), None ) scores = logits.sigmoid() from ipex_llm.transformers.models.common import moe_group_topk topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias, self.n_group, self.topk_group, self.top_k, self.norm_topk_prob, self.routed_scaling_factor) return topk_idx, topk_weight.to(x.dtype) ================================================ FILE: archive/ktransformers/operators/layernorm.py ================================================ ''' Date: 2024-11-13 15:05:52 LastEditors: Xie Weiyu ervinxie@qq.com LastEditTime: 2024-11-25 08:59:19 ''' """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """Fused operators for normalization layers.""" import logging from typing import Optional, Tuple, Union from transformers import PretrainedConfig import torch import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm from ktransformers.models.modeling_qwen3_next import Qwen3NextRMSNorm from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader if not torch.xpu.is_available(): from flashinfer.norm import ( fused_add_rmsnorm, rmsnorm, ) logger = logging.getLogger(__name__) class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def forward( self, x: torch.Tensor, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: #return self.forward_native(x, residual) if batch_size_tensor is None: return self.forward_native(x) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) #residual = x + residual #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) return x, residual # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) return out def forward_native( self, hidden_states ): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(config.hidden_size, orig_module.variance_epsilon) def forward( self, x: torch.Tensor, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: #return self.forward_native(x, residual) if batch_size_tensor is None: return self.forward_native(x) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) #residual = x + residual #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) return x, residual # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) return out def forward_native( self, hidden_states ): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def forward( self, x: torch.Tensor, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: #return self.forward_native(x, residual) bsz, hidden_size = x.shape x = x.view(-1, self.orig_module.hidden_size) if batch_size_tensor is None: return self.forward_native(x) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) #residual = x + residual #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) return x, residual # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) out = out.view(bsz, hidden_size) return out def forward_native( self, hidden_states ): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class KQwen3NextRMSNorm(Qwen3NextRMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x, num_tokens_tensors, residual = None): if residual is not None: x = x + residual residual = x x = x.view(-1, self.orig_module.hidden_size) output = self._norm(x.float()) # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 output = output * (1.0 + self.weight.float()) if residual is None: return output.type_as(x) return output.type_as(x), residual def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def forward( self, x: torch.Tensor, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: #return self.forward_native(x, residual) bsz, hidden_size = x.shape x = x.view(-1, self.orig_module.hidden_size) if batch_size_tensor is None: return self.forward_native(x) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) #residual = x + residual #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) return x, residual # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) out = out.view(bsz, hidden_size) return out def forward_native( self, hidden_states ): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def forward( self, x: torch.Tensor, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: #return self.forward_native(x, residual) bsz, hidden_size = x.shape x = x.view(-1, self.orig_module.hidden_size) if batch_size_tensor is None: return self.forward_native(x) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) #residual = x + residual #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) return x, residual # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) out = out.view(bsz, hidden_size) return out def forward_native( self, hidden_states ): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon) def forward( self, x, batch_size_tensor: torch.Tensor = None, residual: Optional[torch.Tensor] = None, )-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: x = x + residual residual = x # range batch_size_tensor for x input_dtype = x.dtype x = x.to(torch.float32) variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) if residual is not None: return self.weight * x.to(input_dtype), residual return self.weight * x.to(input_dtype) class KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "xpu", generate_device: str = "xpu", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.weight.shape[0], orig_module.variance_epsilon) self.eps = orig_module.variance_epsilon def forward(self, x: torch.Tensor) -> torch.Tensor: from ipex_llm.transformers.models.common import rms_norm_forward if x.dtype not in [torch.float32, torch.float16]: output = rms_norm_forward(self, x.float()) else: output = rms_norm_forward(self, x) return output.to(x.dtype) def load(self): BaseInjectedModule.load(self) if self.weight.dtype not in [torch.float32, torch.float16]: self.weight = self.weight.float() ================================================ FILE: archive/ktransformers/operators/linear.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : Azure-Tang, Boxin Zhang Date : 2024-07-25 11:25:24 Version : 0.1.0 LastEditors : Azure LastEditTime : 2024-08-29 09:11:16 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import ctypes import torch from torch import Tensor, nn try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False if not torch.xpu.is_available() and not use_torch_npu: import KTransformersOps import vLLMMarlin from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader from ktransformers.util.utils import InferenceState if not torch.xpu.is_available(): from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import ( MarlinWorkspace, marlin_quantize, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MIN_THREAD_K, GPTQ_MARLIN_MAX_PARALLEL, vllm_marlin_quantize ) from ktransformers.operators.base_operator import BaseInjectedModule from transformers.configuration_utils import PretrainedConfig try: from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant except: print("no triton") from abc import ABC, abstractmethod import sys, os sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release")) sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug")) import cpuinfer_ext from ktransformers.operators.cpuinfer import CPUInfer from ktransformers.server.config.config import Config from typing import Dict, Tuple, Optional, Union import numpy as np #class KLinearBase(BaseInjectedModule, ABC): class KLinearBase(ABC): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", **kwargs, ): # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) super().__init__() self.key = key self.gguf_loader = gguf_loader self.device = device self.config = config self.has_bias = False self.dtype = torch.get_default_dtype() if orig_module is not None: self.in_features = orig_module.in_features self.out_features = orig_module.out_features else: shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"] if len(shape) == 1: print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF") self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill. @abstractmethod def forward(self, x: torch.Tensor) -> torch.Tensor: pass def load_weight(self, override_key: str | None = None, device: str | None = None): if override_key is not None: keys = override_key else: keys = [self.key] for key in keys: if isinstance(self.gguf_loader, SafeTensorLoader): # using safetensor_loader tensor = self.gguf_loader.load_tensor(key+'.weight') try: bias = self.gguf_loader.load_tensor(key+'.bias') except: bias = None if self.gguf_loader.has_tensor(key+'.weight_scale_inv'): weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv') return nn.Parameter(tensor), nn.Parameter(weight_scale_inv) if bias is not None: return nn.Parameter(tensor), nn.Parameter(bias) else: return nn.Parameter(tensor) elif self.gguf_loader.has_tensor(key + ".weight") or "kv_b_proj" in key: if key + ".bias" in self.gguf_loader.tensor_file_map: tensors = self.load_multi(key, ["weight", "bias"], device=device) tensor = tensors["weight"] bias = tensors["bias"] # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + ".weight"]["ggml_type"]] # print(torch.isinf(tensor).any(), torch.isinf(bias).any()) return nn.Parameter(tensor), nn.Parameter(bias) elif "kv_b_proj" in key and not self.gguf_loader.has_tensor(key + ".weight"): attn_k_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_k_b"), ["weight"], device=device) attn_k_b = attn_k_b_tensors["weight"] del attn_k_b_tensors attn_k_b = attn_k_b.transpose(1, 2).contiguous() attn_v_b_tensors = self.load_multi(key.replace("self_attn.kv_b_proj", "attn_v_b"), ["weight"], device=device) attn_v_b = attn_v_b_tensors["weight"] del attn_v_b_tensors kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1) kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous() del attn_k_b del attn_v_b return nn.Parameter(kv_b_proj) else: tensors = self.load_multi(key, ["weight"], device=device) tensor = tensors["weight"] # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + ".weight"]["ggml_type"]] return nn.Parameter(tensor) else: raise FileNotFoundError(f"Weight file not found for key {key}") def load_multi(self, key: str, keys: list[str], device: str = "cpu"): tensors = {} for k in keys: tensors[k] = self.gguf_loader.load_gguf_tensor(key + "." + k, device=device) return tensors @abstractmethod def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = "cuda"): pass @abstractmethod def unload(self): pass class KLinearTorch(KLinearBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.has_bias = False self.dtype = torch.get_default_dtype() self.weight = None self.has_bias = False def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor: dtype = x.dtype out_device = x.device # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended. x = x.to(device=self.device, dtype=self.dtype) x = x @ self.weight if self.has_bias: x = x + self.bias x = x.to(dtype=dtype, device=out_device) return x def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if self.loaded: return if device is None: device = self.device if w is None: w = self.load_weight(device=device) # else: self.out_features = w.shape[0], self.in_features = w.shape[1] if isinstance(w, nn.Parameter): try: self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T except: self.weight = w.to(dtype=self.dtype).T self.has_bias = False elif isinstance(w, tuple): try: self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T except: self.weight = w[0].to(dtype=self.dtype).T self.bias = w[1].to(dtype=self.dtype) self.has_bias = True else: raise ValueError("Invalid weight type") # self.linear = self.linear.to(device) self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) self.loaded = True def unload(self): if self.weight is not None: self.weight = None if self.has_bias: self.bias = None class KLinearQ8(KLinearBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.has_bias = False self.compute_dtype = torch.float32 self.weight = None self.weight_scale = None self.weight_zero_point = None self.bias = None self.loaded = False def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None) -> torch.Tensor: orig_dtype = x.dtype out_device = x.device x = x.to(device=self.device, dtype=self.compute_dtype) # 使用原始权重做矩阵乘法,模拟原始行为 # 反量化权重进行矩阵乘法 weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8) out = x @ weight_dequant.T if self.has_bias: out = out + self.bias return out.to(dtype=orig_dtype, device=out_device) def _dequantize_weight(self, q_matrix, scales, bits=8): """ Dequantize a low-precision matrix back to floating-point Args: q_matrix (torch.Tensor): Quantized int matrix scales (torch.Tensor): Scale factors for each column bits (int): Quantization bits used (8 or 4) Returns: torch.Tensor: Dequantized floating-point matrix """ # Ensure inputs are torch tensors if not isinstance(q_matrix, torch.Tensor): q_matrix = torch.tensor(q_matrix, dtype=torch.int8) if not isinstance(scales, torch.Tensor): scales = torch.tensor(scales, dtype=torch.float32) # Convert to correct dtype if needed if q_matrix.dtype != torch.int8: q_matrix = q_matrix.to(torch.int8) if scales.dtype != torch.float32: scales = scales.to(torch.float32) # For Q4, ensure the values stay within 4-bit range if bits == 4: q_matrix = torch.clamp(q_matrix, -7, 7) rows, cols = q_matrix.shape dequant_matrix = q_matrix.to(torch.float32) scales_broadcast = scales.view(1, cols) # Apply dequantization to all columns at once using matrix multiplication dequant_matrix = dequant_matrix * scales_broadcast return dequant_matrix def _quantize_weight(self, matrix, bits=8): """ Quantize a floating-point matrix to lower precision (Q8 or Q4) Args: matrix (torch.Tensor): Input matrix in floating-point format bits (int): Quantization bits, either 8 or 4 Returns: tuple: (quantized int matrix, scale factors for each column) """ if not isinstance(matrix, torch.Tensor): matrix = torch.tensor(matrix, dtype=torch.float32) # Convert to float32 if needed if matrix.dtype != torch.float32: matrix = matrix.to(torch.float32) # Get matrix shape rows, cols = matrix.shape # Determine quantization parameters based on bits if bits == 8: max_int = 127 qtype = torch.int8 elif bits == 4: max_int = 7 qtype = torch.int8 # We'll still use int8 storage but limit to 4-bit range, wait for native support else: raise ValueError("Quantization bits must be either 8 or 4") scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device) # Calculate max absolute value for each column max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0) # Handle zero columns (avoid division by zero) zero_cols = max_abs_vals == 0 max_abs_vals[zero_cols] = 1.0 # Calculate scale factors for all columns at once scales = max_abs_vals / max_int # Prepare the scales for broadcasting [1, cols] scales_broadcast = scales.view(1, cols) # Apply quantization to the entire matrix at once q_matrix = torch.round(matrix / scales_broadcast).to(qtype) # For Q4, clamp values to ensure they stay within 4-bit range if bits == 4: q_matrix = torch.clamp(q_matrix, -max_int, max_int) return q_matrix, scales def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None): if self.loaded: return if device is None: device = self.device if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): try: weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features) except: weight = w.to(dtype=self.compute_dtype) self.has_bias = False elif isinstance(w, tuple): try: weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features) except: weight = w[0].to(dtype=self.compute_dtype) self.bias = w[1].to(dtype=self.compute_dtype).to(device) self.has_bias = True else: raise ValueError("Invalid weight type") self.weight, self.weight_scale = self._quantize_weight(weight, bits=8) self.weight = self.weight.to(device) self.weight_scale = self.weight_scale.to(device) if self.has_bias: self.bias = self.bias.to(device) self.loaded = True def unload(self): self.weight = None self.weight_scale = None self.weight_zero_point = None self._orig_weight = None if self.has_bias: self.bias = None self.loaded = False class KLinearFP8(KLinearBase): # this kernel requires special handling for weight # Please load the weight file downloaded from KVCache.AI has_bias: bool weight: torch.Tensor bias: torch.Tensor def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", block_size: int = 128, **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.has_bias = False self.dtype = torch.get_default_dtype() self.block_size = block_size def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor: x = x.to(self.device) orig_dtype = x.dtype x_quantized, scale_x = act_quant(x, self.block_size) y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv) return y.to(dtype=orig_dtype) def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if device is None: device = self.device if w is None: w = self.load_weight(device=device) ### TODO fit weight_inv format if isinstance(w, tuple): self.weight = w[0].to(device) self.weight_scale_inv = w[1].to(device) self.has_bias = False else: raise ValueError("Invalid weight type") self.weight = self.weight.to(device) if self.has_bias: self.bias = self.bias.to(device) def unload(self): if self.weight is not None: self.weight = None if self.has_bias: self.bias = None # TODO: merge two marlin class class VLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor g_idx: torch.Tensor sort_indices: torch.Tensor has_bias: bool def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", num_bits: int = 4, # 4-bit/8-bit is supported group_size: int = 64, # -1, 32, 64, 128 act_order: bool = False, is_k_full=True, **kwargs, ): assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.num_bits = num_bits self.group_size = group_size self.act_order = act_order self.is_k_full = is_k_full self.padding = False self.orin_in_features = self.in_features self.orin_out_features = self.out_features if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") self.padding = True self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N #print(f"After padding: in_features={in_features}, out_features={out_features}") self.k = self.in_features self.n = self.out_features def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if self.loaded: return if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" #if self.in_features * self.out_features: if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): # pad weight weight = w.view(self.orin_out_features, self.orin_in_features).T self.has_bias = False elif isinstance(w, tuple): w = list(w) weight = w[0].view(self.orin_out_features, self.orin_in_features).T self.bias = w[1].view(self.orin_out_features) self.bias = w[1] self.has_bias = True else: raise ValueError("Invalid weight type") weight = weight.to(device) if self.has_bias: self.bias = self.bias.to(device) if self.padding: padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) padded_weight[:self.orin_in_features, :self.orin_out_features] = weight weight = padded_weight # Pack Marlin linear marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( weight, self.num_bits, self.group_size, self.act_order ) self.workspace = MarlinWorkspace( self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device ) self.weight = marlin_q_w self.marlin_q_w = marlin_q_w self.marlin_s = marlin_s self.g_idx = g_idx self.sort_indices = sort_indices self.k = weight.shape[0] self.n = weight.shape[1] # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device) self.loaded = True def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor: if bsz_tensor is None: bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device) # Only support input x as BF16 and FP16 x = x.to(self.device) orig_shape = list(x.shape) orig_dtype = x.dtype x = x.reshape(-1, orig_shape[-1]) marlin_s = self.marlin_s.to(x.dtype) sms = -1 # padding x.shape[0] to avoid CUDA illegal memory access error x, orig_size_m = self._pad_input(x) x = vLLMMarlin.gptq_marlin_gemm( x, self.marlin_q_w, marlin_s, self.g_idx, self.sort_indices, self.workspace.scratch, self.num_bits, bsz_tensor, x.shape[0], self.n, x.shape[-1], sms, self.is_k_full, ) x = x[:orig_size_m] if self.has_bias: x = x + self.bias orig_shape[-1] = self.n return x.reshape(orig_shape).to(orig_dtype) def unload(self): if self.has_bias: self.bias = None self.marlin_q_w = None self.marlin_s = None self.g_idx = None self.sort_indices = None self.workspace = None def _pad_input(self, x): size_m = x.shape[0] size_k = x.shape[1] # size_m and align value depends on VLinearMarlin implementation if size_m > 1024: align = 1024 elif size_m > 64: align = 64 else: align = 1 padded_size_m = ((size_m + align - 1) // align) * align if padded_size_m > size_m: pad_len = padded_size_m - size_m pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device) x = torch.cat([x, pad_tensor], dim = 0).contiguous() return x, size_m class KLinearMarlin(KLinearBase): marlin_q_w: torch.Tensor marlin_s: torch.Tensor g_idx: torch.Tensor sort_indices: torch.Tensor has_bias: bool def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cuda", num_bits: int = 4, # 4-bit/8-bit is supported group_size: int = 64, # -1, 32, 64, 128 act_order: bool = False, is_k_full=True, **kwargs, ): assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.num_bits = num_bits self.group_size = group_size self.act_order = act_order self.is_k_full = is_k_full self.padding = False self.orin_in_features = self.in_features self.orin_out_features = self.out_features if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0: #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding") self.padding = True self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N #print(f"After padding: in_features={in_features}, out_features={out_features}") self.k = self.in_features self.n = self.out_features def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if self.loaded: return if device is None: device = self.device assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" #if self.in_features * self.out_features: if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): # pad weight weight = w.view(self.orin_out_features, self.orin_in_features).T self.has_bias = False elif isinstance(w, tuple): w = list(w) weight = w[0].view(self.orin_out_features, self.orin_in_features).T self.bias = w[1].view(self.orin_out_features) self.bias = w[1] self.has_bias = True else: raise ValueError("Invalid weight type") weight = weight.to(device) if self.has_bias: self.bias = self.bias.to(device) if self.padding: padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device) padded_weight[:self.orin_in_features, :self.orin_out_features] = weight weight = padded_weight # Pack Marlin linear marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( weight, self.num_bits, self.group_size, self.act_order ) self.workspace = MarlinWorkspace( self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device ) self.weight = marlin_q_w # modeling_xxx.py may use linear.weight self.marlin_q_w = marlin_q_w self.marlin_s = marlin_s self.g_idx = g_idx self.sort_indices = sort_indices self.k = weight.shape[0] self.n = weight.shape[1] self.loaded = True def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor: # Only support input x as BF16 and FP16 x = x.to(self.device) orig_shape = list(x.shape) orig_dtype = x.dtype x = x.reshape(-1, orig_shape[-1]) x = x.reshape(-1, x.shape[-1]) if self.padding: padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype) padding_input[:,:self.orin_in_features] = x x = padding_input marlin_s = self.marlin_s.to(x.dtype) x = KTransformersOps.gptq_marlin_gemm( x, self.marlin_q_w, marlin_s, self.g_idx, self.sort_indices, self.workspace.scratch, self.num_bits, x.shape[0], self.n, x.shape[-1], self.is_k_full, ) if self.padding: x = x[:,:self.orin_out_features] orig_shape[-1] = self.orin_out_features else: orig_shape[-1] = self.out_features if self.has_bias: x = x + self.bias return x.reshape(orig_shape).to(orig_dtype) def unload(self): if self.has_bias: self.bias = None self.marlin_q_w = None self.marlin_s = None self.g_idx = None self.sort_indices = None self.workspace = None class KLinearCPUInfer(KLinearBase): CPU_INFER = None def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "cpu", out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu. stride = 16, group_max_len = 1024, **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) if KLinearCPUInfer.CPU_INFER is None: KLinearCPUInfer.CPU_INFER = CPUInfer(Config().cpu_infer) self.has_bias = False self.dtype = torch.get_default_dtype() self.w = None self.has_bias = False self.stride = stride self.group_max_len = group_max_len self.out_device = out_device def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor: origin_shape = x.shape # [batch_size, q_len, hidden_size] if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing(): out_device = x.device self.input_tensor_cpu.copy_(x, non_blocking=True) qlen = origin_shape[1] KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream( torch.cuda.current_stream().cuda_stream, self.linear.forward( qlen, self.input_tensor_cpu.data_ptr(), self.output_cpu.data_ptr() ) ) KLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) self.output_gpu.copy_(self.output_cpu, non_blocking=True) if self.has_bias: self.output_gpu += self.bias return self.output_gpu else: dtype = x.dtype out_device = x.device x = x.to(device=self.device) qlen = origin_shape[1] output_shape = (*origin_shape[:-1], self.out_features) output = torch.empty(output_shape, device=x.device, dtype=x.dtype) KLinearCPUInfer.CPU_INFER.submit( self.linear.forward( qlen, x.data_ptr(), output.data_ptr() ) ) KLinearCPUInfer.CPU_INFER.sync() if self.has_bias: output = output + self.bias output = output.to(dtype=dtype, device=out_device) return output def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True): print(f"loading {self.key} to {self.device} using CPUInfer") if device is None: device = self.device self.load_weights(w=w, device=device) if self.bias is not None: self.has_bias = True self.bias = self.bias.to(device) weight_ptr = ctypes.addressof( ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents ) config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30) self.linear = cpuinfer_ext.linear.Linear(config) if warmup: KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up()) KLinearCPUInfer.CPU_INFER.sync() self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True) self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16) self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device) def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"): if self.gguf_loader.has_tensor(self.key + ".weight"): if self.key + ".bias" in self.gguf_loader.tensor_file_map: self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight") self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"] self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device) else: self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight") self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"] self.bias = None else: raise ValueError(f"Linear {self.key} not found in gguf_loader") def unload(self): if self.w is not None: self.w = None if self.has_bias: self.bias = None class KLinearIPEXLLM(KLinearBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module = None, device: str = "xpu", precision: str = "sym_int4", **kwargs, ): super().__init__(key, gguf_loader, config, orig_module, device, **kwargs) self.has_bias = False self.dtype = torch.get_default_dtype() self.weight = None self.has_bias = False self.precision = precision self.qtype = None def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor: dtype = x.dtype out_device = x.device from ipex_llm.transformers.models.common import linear_forward x = linear_forward(x.half(), self.weight, self.qtype, self.out_features) if self.has_bias: x = x + self.bias x = x.to(dtype=dtype, device=out_device) return x def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): if self.loaded: return if device is None: device = self.device assert device.lower()[:3] == "xpu", "IPEX-LLM quantized linear only supports XPU device" if w is None: w = self.load_weight(device=device) if isinstance(w, nn.Parameter): try: weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T except: weight = w.to(dtype=self.dtype).T self.has_bias = False elif isinstance(w, tuple): try: weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T except: weight = w[0].to(dtype=self.dtype).T self.bias = w[1].to(dtype=self.dtype) self.has_bias = True else: raise ValueError("Invalid weight type") weight = weight.to("cpu").float().transpose(0, 1).contiguous() if self.has_bias: self.bias = self.bias.to(device) # quantize linear weight from ipex_llm.transformers.models.common import quantize_linear paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision) self.weight = paramsLowBit.to(device) self.qtype = qtype self.loaded = True def unload(self): if self.weight is not None: self.weight = None if self.has_bias: self.bias = None LINEAR_MAP = { "KLinearMarlin": KLinearMarlin, "KLinearTorch": KLinearTorch, "KLinearCPUInfer": KLinearCPUInfer, "VLinearMarlin": VLinearMarlin, "KLinearFP8": KLinearFP8, "KLinearQ8": KLinearQ8, "KLinearIPEXLLM": KLinearIPEXLLM, } class KTransformersLinear(BaseInjectedModule, KLinearBase): def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, generate_device: str = "cuda", generate_op: str| None = "KLinearMarlin", prefill_device: str = "cuda", prefill_op: str| None = "KLinearTorch", **kwargs, ): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) # build all the linear operators if prefill_op is not None: assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs) else: self.prefill_linear = None if generate_op is not None: assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs) else: self.generate_linear = None self.mode = InferenceState.UNLOAD def forward(self, x, bsz_tensor=None): if self.mode == InferenceState.PREFILL: assert self.prefill_linear is not None, "cpu linear is not initialized" y = self.prefill_linear.forward(x, bsz_tensor) else: assert self.generate_linear is not None, "gpu linear is not initialized" y = self.generate_linear.forward(x, bsz_tensor) return y def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE): if not mode: mode = InferenceState.GENERATE # load to device if mode == InferenceState.PREFILL: self.generate_linear.unload() self.prefill_linear.load(w=w) self.device = self.prefill_linear.device self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight elif mode == InferenceState.GENERATE: self.prefill_linear.unload() self.generate_linear.load(w=w) self.device = self.generate_linear.device self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight elif mode == InferenceState.UNLOAD: self.prefill_linear.unload() self.generate_linear.unload() self.device = "cpu" else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") self.mode = mode def unload(self): if self.prefill_linear is not None: self.prefill_linear.unload() if self.generate_linear is not None: self.generate_linear.unload() self.device = self.generate_linear.device def set_inference_mode(self, mode: InferenceState): if not mode: mode = InferenceState.GENERATE if mode == InferenceState.GENERATE: self.load(mode=InferenceState.GENERATE) elif mode == InferenceState.PREFILL: self.load(mode=InferenceState.PREFILL) elif mode == InferenceState.UNLOAD: self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") ================================================ FILE: archive/ktransformers/operators/mlp.py ================================================ from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader from transformers import PretrainedConfig import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) return down_proj class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) return down_proj class KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config) def forward(self, x, bsz_tensor): down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor) return down_proj class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule): def __init__(self, key: str, gguf_loader : GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, prefill_device: str = "cuda", generate_device: str = "cuda", **kwargs): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) return down_proj ================================================ FILE: archive/ktransformers/operators/models.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : Azure-Tang Date : 2024-07-25 11:25:24 Version : 1.0.0 LastEditors : Azure LastEditTime : 2024-08-27 07:29:04 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. """ import inspect import math from typing import List, Optional, Tuple, Union import time import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention from ktransformers.server.config.config import Config import os import yaml from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, ) from transformers.modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from ktransformers.models.modeling_qwen2_moe import ( Qwen2MoeSparseMoeBlock, Qwen2MoeMLP, Qwen2MoeDecoderLayer, ) from ktransformers.models.modeling_deepseek import ( BaseModelOutputWithPast, DeepseekV2DecoderLayer, DeepseekV2MoE, ) from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.utils import InferenceState, get_compute_capability from ktransformers.util.custom_loader import GGUFLoader from transformers.configuration_utils import PretrainedConfig from ktransformers.models.modeling_llama import ( LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, ) if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa _flash_supports_window_size = "window_size" in list( inspect.signature(flash_attn_func).parameters ) try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" _CONFIG_FOR_DOC = "Qwen2MoeConfig" QWEN2MOE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`Qwen2MoeConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ QWEN2MOE_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the logits of all the routers. They are useful for computing the router loss, and should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.", QWEN2MOE_START_DOCSTRING, ) class KQwen2MoeModel(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`] Args: config: Qwen2MoeConfig """ def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, per_layer_prefill_intput_threshold: ( int | None ) = None, # if None or 0, close per-layer prefill ) -> Union[Tuple, MoeModelOutputWithPast]: # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}') if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold per_layer_prefill_flag = False seq_lenth = ( inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) ) if ( per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth ): per_layer_prefill_flag = True for layer in self.layers: self.load_layer_to(layer, InferenceState.UNLOAD) else: pass output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False use_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if inputs_embeds is None: input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds.to("cuda") if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers if torch.xpu.is_available() and inputs_embeds.device.type == "xpu": position_embeddings = self.rotary_emb(hidden_states, position_ids) else: position_embeddings = None # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = None for i, decoder_layer in enumerate(self.layers): if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] if cur_device not in self.stream_device_map: self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) torch.cuda.set_device(cur_device) self.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.transfer_map[i], non_blocking=True ) causal_mask = ( causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None ) position_ids = ( position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None ) cache_position = ( cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None ) if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, cache_position, ) else: if per_layer_prefill_flag: # print(f"to gpu") self.load_layer_to(decoder_layer, InferenceState.PREFILL) torch.cuda.empty_cache() layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) if per_layer_prefill_flag: # print(f"to cpu") self.load_layer_to(decoder_layer, InferenceState.UNLOAD) torch.cuda.empty_cache() hidden_states = layer_outputs[0] if use_cache and len(layer_outputs) > 1: next_decoder_cache = layer_outputs[2 if output_attentions else 1] else: next_decoder_cache = None if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits and layer_outputs[-1] is not None: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) if per_layer_prefill_flag: per_layer_prefill_flag = False for layer in self.layers: self.load_layer_to(layer, InferenceState.GENERATE) if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: if next_decoder_cache is not None: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) else: next_cache = past_key_values if not return_dict: return tuple( v for v in [ hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits, ] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState): assert isinstance( layer, Qwen2MoeDecoderLayer ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda device = "cpu" if target == InferenceState.UNLOAD else "cuda" # attn layer.self_attn.q_proj.set_inference_mode(target) layer.self_attn.k_proj.set_inference_mode(target) layer.self_attn.v_proj.set_inference_mode(target) layer.self_attn.o_proj.set_inference_mode(target) layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) # mlp if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock): layer.mlp.gate.set_inference_mode(target) layer.mlp.experts.set_inference_mode(target) layer.mlp.shared_expert.gate_proj.set_inference_mode(target) layer.mlp.shared_expert.up_proj.set_inference_mode(target) layer.mlp.shared_expert.down_proj.set_inference_mode(target) layer.mlp.shared_expert.act_fn.to(device) layer.mlp.shared_expert_gate.to(device) else: layer.mlp.gate_proj.set_inference_mode(target) layer.mlp.up_proj.set_inference_mode(target) layer.mlp.down_proj.set_inference_mode(target) layer.mlp.act_fn.to(device) # layer norm layer.input_layernorm.to(device) layer.post_attention_layernorm.to(device) DeepseekV2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class KDeepseekV2Model(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] Args: config: DeepseekV2Config """ def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, per_layer_prefill_intput_threshold: ( int | None ) = None, # if None, no per-layer prefill is_prefill: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: if per_layer_prefill_intput_threshold is None: per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold per_layer_prefill_flag = False seq_lenth = ( inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1) ) if ( per_layer_prefill_intput_threshold and per_layer_prefill_intput_threshold < seq_lenth ): per_layer_prefill_flag = True for layer in self.layers: self.load_layer_to(layer, InferenceState.UNLOAD) torch.cuda.empty_cache() else: pass output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." ) use_cache = False past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) if inputs_embeds is None: org_device = input_ids.device # TODO move to embed_tokens's device, not hard code to cpu input_ids = input_ids.to("cpu") inputs_embeds = self.embed_tokens(input_ids).to(org_device) input_ids = input_ids.to(org_device) if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) if inputs_embeds.device.type == "xpu" and position_ids is not None: cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds, position_ids) position_embeddings = (cos, sin) else: position_embeddings = None if per_layer_prefill_flag: causal_mask = None elif use_torch_npu and not is_prefill: causal_mask = None else: if (use_torch_npu or os.name == 'nt' or get_compute_capability() < 8 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values()) or device_manager.gpu_vendor != GPUVendor.NVIDIA): # print("for Windows or GPU before ampere, use forward_windows") # only use mask in forward windows or can't flash attn causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) else: causal_mask = None # embed positions hidden_states = inputs_embeds if per_layer_prefill_flag: print(f"Total length of input_ids: {hidden_states.size(1)}") # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None t_gpu = 0 t_cpu = 0 t_f = 0 for i, decoder_layer in enumerate(self.layers): # print(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") if self.transfer_map is not None and i in self.transfer_map: prev_stream = torch.cuda.current_stream() cur_device = self.transfer_map[i] if cur_device not in self.stream_device_map and cur_device.lower() != "cpu": self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) if cur_device.lower() != "cpu": torch.cuda.set_device(cur_device) self.stream_device_map[cur_device].wait_stream(prev_stream) torch.cuda.set_stream(self.stream_device_map[cur_device]) hidden_states = hidden_states.to( self.transfer_map[i], non_blocking=True ) causal_mask = ( causal_mask.to(self.transfer_map[i], non_blocking=True) if causal_mask is not None else None ) position_ids = ( position_ids.to(self.transfer_map[i], non_blocking=True) if position_ids is not None else None ) cache_position = ( cache_position.to(self.transfer_map[i], non_blocking=True) if cache_position is not None else None ) if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, ) else: t3 = time.time() if per_layer_prefill_flag: # print(f"to gpu") self.load_layer_to(decoder_layer, InferenceState.PREFILL) torch.cuda.empty_cache() t4 = time.time() # with open("log.txt", "a") as f: # f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n") layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, is_prefill = is_prefill, ) t5 = time.time() if per_layer_prefill_flag: # print(f"to cpu") self.load_layer_to(decoder_layer, InferenceState.UNLOAD) torch.cuda.empty_cache() t6 = time.time() t_gpu += t4 - t3 t_cpu += t6 - t5 t_f += t5 - t4 hidden_states = layer_outputs[0] # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3 if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) if use_torch_npu: hidden_states_without_norm = hidden_states.clone() hidden_states = self.norm(hidden_states) # with open("log.txt", "a") as f: # f.write(f"@@@After layers\n") # f.write(f"hidden_states={hidden_states}\n") # f.write(f"hidden_states.shape={hidden_states.shape}\n") if per_layer_prefill_flag: t6 = time.time() # print(f"restore") per_layer_prefill_flag = False for layer in self.layers: self.load_layer_to(layer, InferenceState.GENERATE) torch.cuda.empty_cache() t7 = time.time() print( f"total time: {t7-t3}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}" ) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache ) if not return_dict: if use_torch_npu: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, hidden_states_without_norm] if v is not None ) else: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState): assert isinstance( layer, DeepseekV2DecoderLayer ), "module should be nn.ModuleList of decoder layers" # TODO Support restore to original device, not only cuda device = "cpu" if target == InferenceState.UNLOAD else "cuda" # TODO Support DFS to auto use {to, set_inference_mode} according to the module type # attn layer.self_attn.to(device) # # mlp if isinstance(layer.mlp, DeepseekV2MoE): layer.mlp.gate.to(device) layer.mlp.experts.set_inference_mode(target) layer.mlp.shared_experts.gate_proj.set_inference_mode(target) layer.mlp.shared_experts.up_proj.set_inference_mode(target) layer.mlp.shared_experts.down_proj.set_inference_mode(target) layer.mlp.shared_experts.act_fn.to(device) # layer.mlp.shared_expert_gate.to(device) else: layer.mlp.gate_proj.set_inference_mode(target) layer.mlp.up_proj.set_inference_mode(target) layer.mlp.down_proj.set_inference_mode(target) layer.mlp.act_fn.to(device) # layer norm layer.input_layernorm.to(device) layer.post_attention_layernorm.to(device) LLAMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ LLAMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @add_start_docstrings( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class KLlamaModel(BaseInjectedModule): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] Args: config: LlamaConfig """ dynamic_sdpa = None def __init__( self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", per_layer_prefill_intput_threshold: int = 30000, # if None, no per-layer prefill transfer_map: dict = None, **kwargs, ): BaseInjectedModule.__init__( self, key, gguf_loader, config, orig_module, device, **kwargs ) self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold self.transfer_map = transfer_map self.stream_device_map = dict() user_path: str = os.path.expanduser('~') localstore_path: str = os.path.join(user_path,'.ktransformers') config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME) with open(config_path,"r") as file: config_yaml = yaml.safe_load(file.read()) self.long_context_config = config_yaml.get("long_context") self.ext_config = config_yaml.get("ext") KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention( max_seq_len=self.long_context_config["max_seq_len"], block_size=self.long_context_config["block_size"], config=config, device=torch.device("cuda"), local_windows_len=self.long_context_config["local_windows_len"], topk=self.long_context_config["second_select_num"], threads_num=self.ext_config["cpu_infer"], anchor_type=self.long_context_config["anchor_type"], kv_type=self.long_context_config["kv_type"], dense_layer_num=self.long_context_config["dense_layer_num"], anchor_num=self.long_context_config["anchor_num"], preselect_block=self.long_context_config["preselect_block"], block_selection_mode=self.long_context_config["head_select_mode"], preselect_block_count=self.long_context_config["preselect_block_count"], layer_step=self.long_context_config["layer_step"], token_step=self.long_context_config["token_step"], prefill_chunk_size=self.long_context_config["chunk_size"], use_attn_sparsity=False, ) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False return_legacy_cache = False if ( use_cache and not isinstance(past_key_values, Cache) and not self.training ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) if cache_position is None: past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device="cuda", ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = None chunck_size = self.long_context_config["chunk_size"] cur_idx = 0 if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.to("cpu")) q_len = cache_position.size(0) # generate if q_len == 1: x = inputs_embeds[:, -1:, :] position_ids = position_ids[:, -1:] return self.forward_chunk( x, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states, return_dict, ) elif q_len <= chunck_size: inputs_embeds = inputs_embeds.to('cuda') output = self.forward_chunk( inputs_embeds, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states, return_dict, ) KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1) KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1) return output cur_idx = 0 assert ( output_attentions == False ), "output_attentions is not supported when using chunked attention" attn_output = None # prefill KLlamaModel.dynamic_sdpa.remaining_length = q_len while cur_idx < q_len: print(f'current prefill length: {cur_idx}') chunk_mask = None if inputs_embeds.device.type == 'cpu': tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to("cuda") else: tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)] output_with_past = self.forward_chunk( tmp_inputs_embeds, chunk_mask, position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)], past_key_values, output_attentions, use_cache, cache_position[cur_idx : min(cur_idx + chunck_size, q_len)], ) cur_output = output_with_past.last_hidden_state KLlamaModel.dynamic_sdpa.remaining_length -= ( min(cur_idx + chunck_size, q_len) - cur_idx ) cur_idx += chunck_size # if attn_output is None: attn_output = cur_output # else: # attn_output = torch.cat((attn_output, cur_output), dim=-2) KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1) KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1) return BaseModelOutputWithPast(last_hidden_state=attn_output) def forward_chunk( self, inputs_embeds, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_legacy_cache = False if use_cache and not isinstance( past_key_values, Cache ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if return_legacy_cache: next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = ( past_key_values.get_seq_length() if past_key_values is not None else 0 ) using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing if attention_mask.max() != 0: raise ValueError( "Custom 4D attention mask should be passed in inverted form with max==0`" ) causal_mask = attention_mask else: causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device, ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange( target_length, device=device ) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand( input_tensor.shape[0], 1, -1, -1 ) if attention_mask is not None: causal_mask = ( causal_mask.clone() ) # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = ( causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[ :, :, :, :mask_length ].masked_fill(padding_mask, min_dtype) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended( causal_mask, min_dtype ) return causal_mask ================================================ FILE: archive/ktransformers/operators/triton_attention.py ================================================ # Adapted from # https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py # which was originally adapted from # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py import triton import triton.language as tl from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor @triton.jit def tanh(x): # Tanh is just a scaled sigmoid return 2 * tl.sigmoid(2 * x) - 1 @triton.jit def _fwd_grouped_kernel_stage1( Q, K_Buffer, V_Buffer, sm_scale, Req_to_tokens, B_Seqlen, Att_Out, stride_req_to_tokens_b, stride_qbs, stride_qh, stride_buf_kbs, stride_buf_kh, stride_buf_vbs, stride_buf_vh, stride_mid_ob, stride_mid_oh, stride_mid_os, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, NUM_KV_SPLITS: tl.constexpr, PAGE_SIZE: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) split_kv_id = tl.program_id(2) if kv_group_num > BLOCK_H: VALID_BLOCK_H: tl.constexpr = BLOCK_H else: VALID_BLOCK_H: tl.constexpr = kv_group_num cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = mask_h & (cur_head < q_head_num) offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = cur_batch offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ None, :] q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) mask_dpe = offs_dpe < Lk off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]) qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) if split_kv_end > split_kv_start: for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None]) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, ) qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None]) kpe = tl.load( K_Buffer + offs_buf_kpe, mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) qk += tl.dot(qpe, kpe.to(qpe.dtype)) qk *= sm_scale if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :]) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), other=0.0, ) n_e_max = tl.maximum(tl.max(qk, 1), e_max) re_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) acc *= re_scale[:, None] acc += tl.dot(p.to(v.dtype), v) e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :]) tl.store( Att_Out + offs_mid_o, acc / e_sum[:, None], mask=(mask_h[:, None]) & (mask_dv[None, :]), ) offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv) tl.store( Att_Out + offs_mid_o_1, e_max + tl.log(e_sum), mask=mask_h, ) def _decode_grouped_att_m_fwd( q, k_buffer, v_buffer, att_out, Req_to_tokens, B_Seqlen, num_kv_splits, sm_scale, page_size, logit_cap, ): BLOCK = 32 Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] # [TODO] work around shmem limit on MI3xx # TODO: support hip if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576: BLOCK = 16 if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 elif Lk == 288: BLOCK_DMODEL = 256 BLOCK_DPE = 32 else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = q.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[-2] BLOCK_H = 16 NUM_KV_SPLITS = num_kv_splits grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), NUM_KV_SPLITS, ) extra_kargs = {} # TODO: support hip """ if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = { "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2 } """ _fwd_grouped_kernel_stage1[grid]( q, k_buffer, v_buffer, sm_scale, Req_to_tokens, B_Seqlen, att_out, Req_to_tokens.stride(0), q.stride(0), q.stride(1), k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) att_out.stride(0), att_out.stride(1), att_out.stride(2), kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, NUM_KV_SPLITS=NUM_KV_SPLITS, PAGE_SIZE=page_size, logit_cap=logit_cap, num_warps=4, num_stages=2, Lk=Lk, Lv=Lv, **extra_kargs, ) @triton.jit def _fwd_kernel_stage2( Mid_O, o, B_Seqlen, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_obs, stride_oh, NUM_KV_SPLITS: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) offs_d = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lv e_sum = 0.0 e_max = -float("inf") acc = tl.zeros([BLOCK_DV], dtype=tl.float32) offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv for split_kv_id in range(0, NUM_KV_SPLITS): kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) if split_kv_end > split_kv_start: tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) n_e_max = tl.maximum(tlogic, e_max) old_scale = tl.exp(e_max - n_e_max) acc *= old_scale exp_logic = tl.exp(tlogic - n_e_max) acc += exp_logic * tv e_sum = e_sum * old_scale + exp_logic e_max = n_e_max tl.store( o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / e_sum, mask=mask_d, ) def _decode_softmax_reducev_fwd( logits, q, o, v_buffer, b_seq_len, num_kv_splits, ): batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] BLOCK_DV = triton.next_power_of_2(Lv) NUM_KV_SPLITS = num_kv_splits extra_kargs = {} # TODO: support hip """ if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = { "waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2 } """ grid = (batch, head_num) _fwd_kernel_stage2[grid]( logits, o, b_seq_len, logits.stride(0), logits.stride(1), logits.stride(2), o.stride(0), o.stride(1), NUM_KV_SPLITS=NUM_KV_SPLITS, BLOCK_DV=BLOCK_DV, Lv=Lv, num_warps=4, num_stages=2, **extra_kargs, ) def decode_attention_fwd_grouped( q, k_buffer, v_buffer, o, req_to_token, b_seq_len, attn_logits, num_kv_splits, sm_scale, page_size, logit_cap=0.0, ): _decode_grouped_att_m_fwd( q, k_buffer, v_buffer, attn_logits, req_to_token, b_seq_len, num_kv_splits, sm_scale, page_size, logit_cap, ) _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) ================================================ FILE: archive/ktransformers/operators/triton_attention_prefill.py ================================================ # Adapted from # https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py # which was originally adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 """ Memory-efficient attention for prefill. It supporst page size = 1. """ # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import torch import triton import triton.language as tl is_cuda_available = torch.cuda.is_available() if is_cuda_available: CUDA_CAPABILITY = torch.cuda.get_device_capability() @triton.jit def _fwd_kernel( Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, Out, stride_qbs, stride_qh, stride_kbs, stride_kh, stride_vbs, stride_vh, stride_obs, stride_oh, kv_group_num: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) cur_kv_head = cur_head // kv_group_num cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) block_start_loc = BLOCK_M * start_m # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] ) off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] mask_d = offs_d < Lk q = tl.load( Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]), other=0.0, ) k_ptrs = K + off_k v_ptrs = V + off_v # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) end_n = ( cur_batch_seq_len if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) ) for start_n in range(0, block_mask * end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), other=0.0, ) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale if IS_CAUSAL: qk += tl.where( (start_n + offs_n[None, :] < cur_batch_seq_len) & (offs_m[:, None] >= (start_n + offs_n[None, :])), 0, float("-inf"), ) else: qk += tl.where( (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf") ) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) p = tl.exp(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) # -- update m_i and l_i m_i_new = tl.maximum(m_i, m_ij) alpha = tl.exp(m_i - m_i_new) beta = tl.exp(m_ij - m_i_new) l_i_new = alpha * l_i + beta * l_ij # -- update output accumulator -- # scale p p_scale = beta / l_i_new p = p * p_scale[:, None] # scale acc acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), other=0.0, ) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new # initialize pointers to output off_o = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] ) out_ptrs = Out + off_o tl.store( out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) ) def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): """ q, k, v: [b * s, head, head_dim] b_start_loc: [b] b_seq_len: [b] out: [b * s, head, head_dim] """ if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: BLOCK = 64 Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( q, k, v, sm_scale, b_start_loc, b_seq_len, o, q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), o.stride(0), o.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, BLOCK_DMODEL=triton.next_power_of_2(Lk), BLOCK_N=BLOCK, IS_CAUSAL=is_causal, num_warps=num_warps, num_stages=1, Lk=Lk, ) ================================================ FILE: archive/ktransformers/optimize/optimize.py ================================================ ''' Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Mapping, List import torch import yaml import re from torch import nn from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig # from operators import BaseInjectedModule from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory from ktransformers.util.custom_gguf import translate_name_to_gguf from ktransformers.util import utils from ktransformers.util.utils import set_module, load_weights import itertools import copy try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''): for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name if child_prefix in local_optimization_dict: inject_module_meta=local_optimization_dict[child_prefix] if inject_module_meta["class"] != "default": import_path = inject_module_meta["class"].split(".") import_module_name = ".".join(import_path[:-1]) gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict() import_class_name = import_path[-1] module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name) if use_torch_npu: print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name) if torch.distributed.get_rank() == 0 else None #TODO 分布式 else: print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name) inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta["kwargs"]) set_module(module, name, inject_module) elif inject_module_meta["class"] == "default": print(f"Injecting {child_prefix} as default") gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict() else: raise Exception("inject_module_meta[\"class\"] must be \"default\" or a class path") child_prefix += "." child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)} inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix) def del_meta(module:nn.Module): #print("default loading weights", prefix) persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} for name, param in local_state.items(): if param.device == "meta" or param.device == torch.device("meta"): module.__delattr__(name) for name, child in module._modules.items(): del_meta(child) def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"): module_name = prefix[:-1] if use_torch_npu: translated_name = translate_name_to_gguf(prefix)[:-1] recursive = True for rule in rule_list: match_meta = rule["match"] if "class" not in match_meta and "name" not in match_meta: raise Exception("match must have at least one of \"class\" and \"name\"") if "class" in match_meta: import_path = match_meta["class"].split(".") import_module_name = ".".join(import_path[:-1]) import_class_name = import_path[-1] module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name) if not isinstance(module, module_cls): continue if "name" in match_meta: if re.search(match_meta["name"], module_name) is None: continue if "replace" not in rule: raise Exception("replace must be in rule") if "replace" in rule: replace_meta = rule["replace"] if module_name not in out_data: out_data[module_name]={"key": module_name if not use_torch_npu else translated_name, "class": replace_meta["class"] if "class" in replace_meta else "default", # "device": replace_meta["device"] if "device" in replace_meta else default_device, "kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()} else: if out_data[module_name]["class"] == "default": out_data[module_name]["class"] = replace_meta["class"] if "class" in replace_meta else "default" out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()) if "recursive" in rule: recursive = bool(rule["recursive"]) break if module_name not in out_data: out_data[module_name]= { "class": "default", "key": module_name if not use_torch_npu else translated_name, "kwargs": {"generate_device": default_device, "prefill_device": default_device} } #print(out_data[module_name]) #input() if recursive: for name, child in module._modules.items(): if child is not None: child_prefix = prefix + name + "." gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device) def translate_model_config(model_config: PretrainedConfig): # for supporting some special model if model_config.model_type == "mixtral": model_config.moe_intermediate_size = model_config.intermediate_size return model_config def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0", q4_gguf_path=""): with open(rule_file, 'r', encoding='utf-8') as f: rule_list = yaml.load(f.read(), Loader=yaml.FullLoader) optimize_config = dict() gen_optimize_config(module, optimize_config, rule_list, default_device = default_device) model_config = translate_model_config(model_config) if use_torch_npu: if q4_gguf_path: q4_gguf_loader = GGUFLoader(q4_gguf_path) utils.Q4_GGUF_LODER = q4_gguf_loader gguf_loader = GGUFLoader(gguf_path, getattr(model_config, "quantize", None)) with torch.device("meta"): inject(module, optimize_config, model_config, gguf_loader) # pre load lm_head because its big inter result load_weights(module.lm_head, gguf_loader, "lm_head.") load_weights(module, gguf_loader) module.gguf_loader = gguf_loader else: weights_loader = ModelLoaderFactory.create_loader(gguf_path) with torch.device("meta"): inject(module, optimize_config, model_config, weights_loader) # pre load lm_head because its big inter result load_weights(module.lm_head, weights_loader, "lm_head.", device=default_device) load_weights(module, weights_loader, device=default_device) module.gguf_loader = weights_loader del_meta(module) if torch.cuda.is_available(): torch.cuda.empty_cache() elif torch.xpu.is_available(): torch.xpu.empty_cache() else: torch.cuda.empty_cache() ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" - match: name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" - match: name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:2" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:2" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:3" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:3" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" - match: name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 15: "cuda:1" 30: "cuda:2" 45: "cuda:3" - match: name: "^model\\.layers\\.([0-9]|[1][0-4])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "(^model\\.layers\\.([2][0-9]|[1][5-9])\\.)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "(^model\\.layers\\.([3][0-9]|[4][0-4])\\.)" replace: class: "default" kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)" replace: class: "default" kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([345][0-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([345][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([345][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([345][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([345][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 30: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" # === Rotary Embedding Replacement === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cpu" prefill_device: "cpu" # === Linear Layers Replacement (excluding self_attn) === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cpu" prefill_device: "cpu" generate_op: "KLinearCPUInfer" prefill_op: "KLinearTorch" out_device: "cpu" # === MLP (MoE) Replacement === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cpu" prefill_device: "cpu" # === MLP Gate Replacement === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cpu" prefill_device: "cpu" # === MLP Experts Replacement === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cpu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cpu" recursive: False # don't recursively inject submodules of this module # === Self-Attention Replacement === # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # CPU: layers 10-29 - match: name: "^model\\.layers\\.([12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cpu" prefill_device: "cpu" # === Overall Model Replacement with Transfer Map === - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 10: "cpu" # === Default Catch-All for Other Modules ===# # GPU 0: layers 0–9 - match: name: "^model\\.layers\\.(0|[1-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" #lmm_head on GPU 0 - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # CPU: layers 10-29 - match: name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.(0|[1-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([12][0-9])\\." class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(0|[1-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.(0|[1-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 10: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default) recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearFP8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" backend: "llamafile" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm replace: class: ktransformers.operators.layernorm.RMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP replace: class: ktransformers.operators.mlp.kDeepseekV3MLP kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearFP8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm replace: class: ktransformers.operators.layernorm.RMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP replace: class: ktransformers.operators.mlp.kDeepseekV3MLP kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearFP8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" # === Rotary Embedding Replacement === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # === Linear Layers Replacement (excluding self_attn.kv_b_proj) === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # === MLP (MoE) Replacement === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # === MLP Gate Replacement === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # === MLP Experts Replacement === # replace with marlin expert. Open and modify layer-num as needed. # Each layer of malin experts takes about 6GB of GPU memory. # !!!Do remember 'close' cuda graph if you are using marlin expert.!!! # !!!KExpertsTorch is untested, we don't have enough VRAM.!!! # GPU 0: layers 3–4 # - match: # name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$" # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:0" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 1: layers 15–17 # - match: # name: "^model\\.layers\\.(1[5-7])\\.mlp\\.experts$" # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:1" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 2: layers 30–32 # - match: # name: "^model\\.layers\\.(3[0-2])\\.mlp\\.experts$" # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:2" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 3: layers 45–46 # - match: # name: "^model\\.layers\\.(4[5-6])\\.mlp\\.experts$" # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:3" # generate_op: "KExpertsMarlin" # recursive: False # === MLP Experts Replacement === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:2" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:2" recursive: False # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:3" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:3" recursive: False # === Self-Attention Replacement === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" absorb_for_prefill: False # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" absorb_for_prefill: False # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" absorb_for_prefill: False # GPU 3: layers 45–60 - match: name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" absorb_for_prefill: False # === Overall Model Replacement with Transfer Map === - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill transfer_map: 15: "cuda:1" # Layers 15+ on GPU 1 30: "cuda:2" # Layers 30+ on GPU 2 45: "cuda:3" # Layers 45+ on GPU 3 # === Default Catch-All for Other Modules === # GPU 0: layers 0–14 - match: name: "^model\\.layers\\.([0-9]|1[0-4])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 15–29 - match: name: "^model\\.layers\\.(1[5-9]|2[0-9])\\." replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 30–44 - match: name: "^model\\.layers\\.(3[0-9]|4[0-4])\\." replace: class: "default" kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # For final modules (model.norm), ensure they are on GPU 3 (as in your original config) - match: name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" # === Rotary Embedding Replacement === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.([3][2-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" # GPU 7: layers 56–60 - match: name: "^model\\.layers\\.(5[6-9]|60)\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" # === Linear Layers Replacement (excluding self_attn.kv_b_proj) === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # GPU 7: layers 56–63 - match: name: "^model\\.layers\\.(5[6-9]|60)\\.(?!self_attn\\.kv_b_proj).*$" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # === MLP (MoE) Replacement === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" # GPU 7: layers 56–60 - match: name: "^model\\.layers\\.(5[6-9]|60)\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" # === MLP Gate Replacement === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" # GPU 7: layers 56–60 - match: name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" # === MLP Experts Replacement === # replace with marlin expert. Open and modify layer-num as needed. # Each layer of malin experts takes about 6GB of GPU memory. # !!!Do remember 'close' cuda graph if you are using marlin expert.!!! # !!!Loading marlin expert will take signifcant time.!!! # GPU 0: layers 0–7 # - match: # name: "^model\\.layers\\.([0-7])\\.mlp\\.experts$" # inject experts in layer 0~4 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:0" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 1: layers 8–15 # - match: # name: "^model\\.layers\\.([8-9]|1[0-5)\\.mlp\\.experts$" # inject experts in layer 30~31 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:1" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 2: layers 16–23 # - match: # name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.experts$" # inject experts in layer 0~4 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:0" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 3: layers 24–31 # - match: # name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.experts$" # inject experts in layer 30~31 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:1" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 4: layers 32–39 # - match: # name: "^model\\.layers\\.(3[2-9])\\.mlp\\.experts$" # inject experts in layer 0~4 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:0" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 5: layers 40–47 # - match: # name: "^model\\.layers\\.(4[0-7])\\.mlp\\.experts$" # inject experts in layer 30~31 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:1" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 6: layers 48–55 # - match: # name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.experts$" # inject experts in layer 0~4 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:0" # generate_op: "KExpertsMarlin" # recursive: False # # GPU 7: layers 56–60 # - match: # name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.experts$" # inject experts in layer 30~31 as marlin expert # replace: # class: ktransformers.operators.experts.KTransformersExperts # kwargs: # generate_device: "cuda:1" # generate_op: "KExpertsMarlin" # recursive: False # === MLP Experts Replacement === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:2" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:2" recursive: False # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:3" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:3" recursive: False # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:4" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:4" recursive: False # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:5" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:5" recursive: False # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:6" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:6" recursive: False # GPU 7: layers 56–60 - match: name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda:7" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:7" recursive: False # === Self-Attention Replacement === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" # GPU 7: layers 56–60 - match: name: "^model\\.layers\\.(5[6-9]|60)\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" # === Overall Model Replacement with Transfer Map === - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill transfer_map: 8: "cuda:1" 16: "cuda:2" 24: "cuda:3" 32: "cuda:4" 40: "cuda:5" 48: "cuda:6" 56: "cuda:7" # === Default Catch-All for Other Modules === # GPU 0: layers 0–7 - match: name: "^model\\.layers\\.([0-7])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" # GPU 1: layers 8–15 - match: name: "^model\\.layers\\.(8|9|1[0-5])\\." replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" # GPU 2: layers 16–23 - match: name: "^model\\.layers\\.(1[6-9]|2[0-3])\\." replace: class: "default" kwargs: generate_device: "cuda:2" prefill_device: "cuda:2" # GPU 3: layers 24–31 - match: name: "^model\\.layers\\.(2[4-9]|3[0-1])\\." replace: class: "default" kwargs: generate_device: "cuda:3" prefill_device: "cuda:3" # GPU 4: layers 32–39 - match: name: "^model\\.layers\\.(3[2-9])\\." replace: class: "default" kwargs: generate_device: "cuda:4" prefill_device: "cuda:4" # GPU 5: layers 40–47 - match: name: "^model\\.layers\\.(4[0-7])\\." replace: class: "default" kwargs: generate_device: "cuda:5" prefill_device: "cuda:5" # GPU 6: layers 48–55 - match: name: "^model\\.layers\\.(4[8-9]|5[0-5])\\." replace: class: "default" kwargs: generate_device: "cuda:6" prefill_device: "cuda:6" # GPU 7: layers 56–63 - match: name: "^model\\.layers\\.(5[6-9]|60)\\." replace: class: "default" kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # For final modules (model.norm), ensure they are on GPU 7 (as in your original config) - match: name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)" replace: class: "default" kwargs: generate_device: "cuda:7" prefill_device: "cuda:7" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearFP8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearFP8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 30: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^lm_head" class: torch.nn.Linear replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-4])\\.mlp\\.experts$" # inject experts in layer 0~4 as marlin expert replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: generate_device: "cuda:0" # run in cuda:0 generate_op: "KExpertsMarlin" recursive: False - match: name: "^model\\.layers\\.([3][0])\\.mlp\\.experts$" # inject experts in layer 30~31 as marlin expert replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: generate_device: "cuda:1" generate_op: "KExpertsMarlin" recursive: False - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 30: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml ================================================ - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\." class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 30: "cuda:1" - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-npu.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorch" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorch" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "npu" prefill_device: "npu" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "npu:0" prefill_device: "npu:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "npu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "npu" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "npu" prefill_device: "npu" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm replace: class: ktransformers.operators.layernorm.RMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP replace: class: ktransformers.operators.mlp.kDeepseekV3MLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml ================================================ - match: class: ktransformers.models.modeling_glm4_moe.Glm4MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.KGlm4MoeRotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_glm4_moe.Glm4MoeMoE replace: class: ktransformers.operators.experts.KGlm4MoeMoE kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KGlm4Experts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: None generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_glm4_moe.Glm4MoeRMSNorm replace: class: ktransformers.operators.layernorm.KGlm4MoeRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_glm4_moe.Glm4MoeMLP replace: class: ktransformers.operators.mlp.KGlm4MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml ================================================ - match: class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbeddingV2 - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_llama.LlamaModel replace: class: ktransformers.operators.models.KLlamaModel kwargs: generate_device: "cuda" prefill_device: "cuda" per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KLlamaAttention kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Mixtral.yaml ================================================ - match: class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*$" class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.block_sparse_moe$" class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock replace: class: ktransformers.operators.experts.KMistralSparseMoEBlock - match: name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\..*\\." replace: class: "default" kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml ================================================ - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm replace: class: ktransformers.operators.layernorm.RMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP replace: class: ktransformers.operators.mlp.kDeepseekV3MLP kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbeddingV4 kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module # if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance) #- match: # name: "^model\\.layers\\..*\\.mlp\\.experts$" # replace: # class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism # kwargs: # prefill_device: "cuda" # prefill_op: "KExpertsTorch" # generate_device: "cuda" # generate_op: "KExpertsMarlin" # recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml ================================================ - match: name: "^model\\.layers\\.([012])\\." class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([012])$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([012])\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function - match: name: "^model\\.layers\\.([012])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism # device: "cpu" # which devices to load this module when initializing kwargs: prefill_device: "cuda:0" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:0" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\.([12][0-9]|[3-9])\\." class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model\\.layers\\.([12][0-9]|[3-9])$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function - match: name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism # device: "cpu" # which devices to load this module when initializing kwargs: prefill_device: "cuda:1" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda:1" recursive: False # don't recursively inject submodules of this module - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "(^model.norm)" replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill transfer_map: 3: "cuda:1" - match: name: "^model\\.layers\\.([012])\\." replace: class: "default" kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([12][0-9]|[3-9])\\." replace: class: "default" kwargs: generate_device: "cuda:1" prefill_device: "cuda:1" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^lm_head" class: torch.nn.Linear replace: class: ktransformers.operators.linear.KTransformersLinear kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism # device: "cpu" # which devices to load this module when initializing kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model\\.layers\\..*\\." replace: class: "default" kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen2-serve-amx.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default) recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm replace: class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm replace: class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen3Moe-serve-amx.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" backend: "AMXBF16" # or "AMXBF16" or "llamafile" (default) recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm replace: class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm replace: class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Qwen3Next-serve.yaml ================================================ - match: class: ktransformers.models.modeling_qwen3_next.Qwen3NextRotaryEmbedding replace: class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen3_next.Qwen3NextSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen3NextSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: class: ktransformers.models.modeling_qwen3_next.Qwen3NextGatedDeltaNet replace: class: ktransformers.operators.balance_serve_attention.KQwen3NextGatedDeltaNet # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen3_next.Qwen3NextAttention replace: class: ktransformers.operators.balance_serve_attention.KQwen3NextAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen3_next.Qwen3NextRMSNorm replace: class: ktransformers.operators.layernorm.KQwen3NextRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_qwen3_next.Qwen3NextMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/Smallthinker-serve.yaml ================================================ - match: class: ktransformers.models.modeling_smallthinker.SmallthinkerRotaryEmbedding replace: class: ktransformers.operators.RoPE.KSmallthinkerRotaryEmbedding kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "VLinearMarlin" prefill_op: "KLinearTorch" # - match: # name: "^model\\.layers\\..*$" # regular expression # class: torch.nn.Linear # only match modules matching name and class simultaneously # replace: # class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types # kwargs: # generate_device: "cuda" # prefill_device: "cuda" # generate_op: "VLinearMarlin" # prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*feed_forward\\.shared_expert_gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearMarlin" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.block_sparse_moe$" class: ktransformers.models.modeling_smallthinker.SmallthinkerMoeBlock replace: class: ktransformers.operators.experts.KSmallthinkerMoeBlock kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$" replace: class: ktransformers.operators.experts.KSmallthinkerExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: None generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_smallthinker.SmallthinkerRMSNorm replace: class: ktransformers.operators.layernorm.KSmallthinkerRMSNorm kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_smallthinker.SmallthinkerDenseMlpBlock replace: class: ktransformers.operators.mlp.KSmallthinkerDenseMlpBlock kwargs: generate_device: "cuda" prefill_device: "cuda" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu-serve.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8 # mlp module with custom forward function kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^model\\.layers\\.([0-2])\\.mlp$" class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" replace: class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1" kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^model\\.layers\\..*\\.mlp\\.shared_experts$" class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" replace: class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2" kwargs: generate_device: "npu" prefill_device: "npu" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2 kwargs: generate_device: "npu:0" prefill_device: "npu:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8 kwargs: prefill_device: "npu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPUW8A8" out_device: "npu" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" class: ktransformers.operators.experts.KExpertsCPU replace: class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8 - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2Serve # optimized MLA implementation kwargs: generate_device: "npu" prefill_device: "npu" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model..*norm" replace: class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8 kwargs: generate_device: "npu" prefill_device: "npu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 # optimized Kernel on quantized data types kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8 # mlp module with custom forward function kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^model\\.layers\\.([0-2])\\.mlp$" class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" replace: class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1" kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^model\\.layers\\..*\\.mlp\\.shared_experts$" class: "ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP" replace: class: "ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2" kwargs: generate_device: "npu" prefill_device: "npu" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2 kwargs: generate_device: "npu:0" prefill_device: "npu:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8 kwargs: prefill_device: "npu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPUW8A8" out_device: "npu" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" class: ktransformers.operators.experts.KExpertsCPU replace: class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8 - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2 # optimized MLA implementation kwargs: generate_device: "npu" prefill_device: "npu" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: name: "^model..*norm" replace: class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8 kwargs: generate_device: "npu" prefill_device: "npu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/npu/Qwen3-Chat-300IA2-npu-serve.yaml ================================================ - match: class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding replace: class: ktransformers.operators.RoPE.RotaryEmbedding kwargs: generate_device: "npu" prefill_device: "npu" - match: name: "^lm_head$" class: torch.nn.Linear replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate)(?!.*mlp\\.gate)(?!.*mlp\\.experts).*$" class: torch.nn.Linear replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\.(?!.*mlp\\.gate)(?!.*self_attn\\.kv_b_proj)(?!.*mlp\\.experts).*$" class: torch.nn.Linear replace: class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2 kwargs: generate_device: "npu" prefill_device: "npu" generate_op: "KLinearTorchW8A8A2" prefill_op: "KLinearTorchW8A8A2" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock replace: class: ktransformers.operators.ascend.ascend_experts.KQwen3MoeSparseMoeBlockW8A8 kwargs: generate_device: "npu" prefill_device: "npu" dump_enable: False dump_dir: "/mnt/dump_from_mindie/dump_from_kt_moe" - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.ascend.ascend_attention.KQwen3MoeAttentionW8A8A2Serve kwargs: generate_device: "npu" prefill_device: "npu" absorb_for_prefill: False dump_enable: False dump_dir: "/mnt/dump_from_mindie/dump_from_kt_attn" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm replace: class: ktransformers.operators.ascend.ascend_layernorm.KQwen3MoeRMSNormW8A8 kwargs: generate_device: "npu" prefill_device: "npu" dump_enable: False dump_dir: "/mnt/dump_from_mindie/dump_from_kt_rms" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "cuda" prefill_device: "cuda" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cpu" prefill_device: "cuda" generate_op: "KLinearCPUInfer" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "cuda" prefill_device: "cuda" generate_op: "KLinearQ8" prefill_op: "KLinearTorch" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "cuda" prefill_device: "cuda" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGate kwargs: generate_device: "cuda:0" prefill_device: "cuda:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "cuda" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "cuda" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbedding kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model\\.layers\\..*" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "xpu" prefill_device: "xpu" generate_op: "KLinearIPEXLLM" prefill_op: "KLinearIPEXLLM" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek.DeepseekV2MoE replace: class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "xpu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "xpu" recursive: False # don't recursively inject submodules of this module - match: class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm replace: class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill device: "xpu" - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml ================================================ - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding replace: class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "xpu" prefill_device: "xpu" generate_op: "KLinearIPEXLLM" prefill_op: "KLinearIPEXLLM" - match: name: "^model\\.layers\\..*" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "xpu" prefill_device: "xpu" generate_op: "KLinearIPEXLLM" prefill_op: "KLinearIPEXLLM" - match: name: "^model\\.layers\\..*\\.mlp$" class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: generate_device: "xpu" prefill_device: "xpu" - match: class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm replace: class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM kwargs: generate_device: "xpu" prefill_device: "xpu" - match: class: ktransformers.models.modeling_deepseek_v3.MoEGate replace: class: ktransformers.operators.gate.KMoEGateIPEXLLM kwargs: generate_device: "xpu:0" prefill_device: "xpu:0" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism kwargs: prefill_device: "xpu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "xpu" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation kwargs: generate_device: "xpu" prefill_device: "xpu" absorb_for_prefill: False # change this to True to enable long context(prefill may slower). - match: name: "^model$" replace: class: "ktransformers.operators.models.KDeepseekV2Model" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" ================================================ FILE: archive/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml ================================================ - match: name: "rotary_emb$" replace: class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^lm_head$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "xpu" prefill_device: "xpu" generate_op: "KLinearIPEXLLM" prefill_op: "KLinearIPEXLLM" - match: name: "^model\\.layers\\.(?!.*mlp\\.gate).*$" # regular expression class: torch.nn.Linear # only match modules matching name and class simultaneously replace: class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types kwargs: generate_device: "xpu" prefill_device: "xpu" generate_op: "KLinearIPEXLLM" prefill_op: "KLinearIPEXLLM" - match: name: "^model\\.layers\\..*\\.mlp$" class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock replace: class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model\\.layers\\..*\\.mlp\\.experts$" replace: class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism kwargs: prefill_device: "xpu" prefill_op: "KExpertsTorch" generate_device: "cpu" generate_op: "KExpertsCPU" out_device: "xpu" recursive: False # don't recursively inject submodules of this module - match: name: "^model\\.layers\\..*\\.self_attn$" replace: class: ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM kwargs: generate_device: "xpu" prefill_device: "xpu" - match: name: "^model$" replace: class: "ktransformers.operators.models.KQwen2MoeModel" kwargs: per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill - match: name: "^model.embed_tokens" replace: class: "default" kwargs: generate_device: "cpu" prefill_device: "cpu" - match: class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm replace: class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM kwargs: generate_device: "xpu" prefill_device: "xpu" - match: class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP replace: class: ktransformers.operators.mlp.KQwen2MoeMLP kwargs: generate_device: "xpu" prefill_device: "xpu" ================================================ FILE: archive/ktransformers/server/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/api/__init__.py ================================================ from fastapi import APIRouter from .ollama import router as ollama_router from .openai import router as openai_router,post_db_creation_operations from .web import router as web_router router = APIRouter() router.include_router(ollama_router) router.include_router(openai_router) router.include_router(web_router) ================================================ FILE: archive/ktransformers/server/api/ollama/__init__.py ================================================ from fastapi import APIRouter from .completions import router as completions_router router = APIRouter() router.include_router(completions_router) ================================================ FILE: archive/ktransformers/server/api/ollama/completions.py ================================================ from datetime import datetime from http.client import NOT_IMPLEMENTED import json from time import time from uuid import uuid4 from typing import List, Optional from fastapi import APIRouter, Request from pydantic import BaseModel, Field from ktransformers.server.config.config import Config from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import check_link_response from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.schemas.endpoints.chat import RawUsage router = APIRouter(prefix='/api') # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion class OllamaGenerateCompletionRequest(BaseModel): model: str = Field(..., description="The model name, which is required.") prompt: Optional[str] = Field( None, description="The prompt to generate a response for.") images: Optional[List[str]] = Field( None, description="A list of base64-encoded images for multimodal models such as llava.") # Advanced parameters format: Optional[str] = Field( None, description="The format to return a response in, accepted value is json.") options: Optional[dict] = Field( None, description="Additional model parameters as listed in the documentation.") system: Optional[str] = Field( None, description="System message to override what is defined in the Modelfile.") template: Optional[str] = Field( None, description="The prompt template to use, overriding what is defined in the Modelfile.") context: Optional[str] = Field( None, description="The context parameter from a previous request to keep a short conversational memory.") stream: Optional[bool] = Field( None, description="If false, the response will be returned as a single response object.") raw: Optional[bool] = Field( None, description="If true, no formatting will be applied to the prompt.") keep_alive: Optional[str] = Field( "5m", description="Controls how long the model will stay loaded into memory following the request.") class OllamaGenerationStreamResponse(BaseModel): model: str created_at: str response: str done: bool = Field(...) class OllamaGenerationResponse(BaseModel): model: str created_at: str response: str done: bool @router.post("/generate", tags=['ollama']) async def generate(request: Request, input: OllamaGenerateCompletionRequest): id = str(uuid4()) interface: BackendInterfaceBase = get_interface() print(f'COMPLETION INPUT:----\n{input.prompt}\n----') config = Config() if input.stream: async def inner(): async for res in interface.inference(input.prompt, id): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res d = OllamaGenerationStreamResponse( model=config.model_name, created_at=str(datetime.now()), response=token, done=False ) yield d.model_dump_json() + '\n' d = OllamaGenerationStreamResponse( model=config.model_name, created_at=str(datetime.now()), response='', done=True ) yield d.model_dump_json() + '\n' return check_link_response(request, inner()) else: complete_response = "" async for res in interface.inference(input.prompt, id): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res complete_response += token response = OllamaGenerationResponse( model=config.model_name, created_at=str(datetime.now()), response=complete_response, done=True ) return response # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion class OllamaChatCompletionMessage(BaseModel): role: str content: str class OllamaChatCompletionRequest(BaseModel): model: str = Field(..., description="The model name, which is required.") messages: List[OllamaChatCompletionMessage] = Field( ..., description="A list of messages to generate a response for.") stream: bool = Field(True, description="If true, the response will be streamed.") class OllamaChatCompletionStreamResponse(BaseModel): model: str created_at: str message: dict done: bool = Field(...) done_reason: Optional[str] = Field("", description="done_reason") total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds") load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds") prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt") prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds") eval_count: Optional[int] = Field(None, description="Number of tokens generated") eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds") class OllamaChatCompletionResponse(BaseModel): model: str created_at: str message: dict done: bool done_reason: Optional[str] = Field("", description="done_reason") total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds") load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds") prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt") prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds") eval_count: Optional[int] = Field(None, description="Number of tokens generated") eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds") @router.post("/chat", tags=['ollama']) async def chat(request: Request, input: OllamaChatCompletionRequest): id = str(uuid4()) interface: BackendInterfaceBase = get_interface() config = Config() input_message = [json.loads(m.model_dump_json()) for m in input.messages] if input.stream: async def inner(): start_time = time() # 记录开始时间(秒) tokens = [] async for res in interface.inference(input_message, id): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res d = OllamaChatCompletionStreamResponse( model=config.model_name, created_at=str(datetime.now()), message={"role": "assistant", "content": token}, done=False ) yield d.model_dump_json() + '\n' # 计算性能数据 end_time = time() total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns prompt_eval_count = raw_usage.prefill_count eval_count = raw_usage.decode_count eval_duration = int(raw_usage.decode_time * 1_000_000_000) prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000) load_duration = int(raw_usage.tokenize_time * 1_000_000_000) done_reason = finish_reason d = OllamaChatCompletionStreamResponse( model=config.model_name, created_at=str(datetime.now()), message={}, done=True, total_duration=total_duration, load_duration=load_duration, prompt_eval_count=prompt_eval_count, prompt_eval_duration=prompt_eval_duration, eval_count=eval_count, eval_duration=eval_duration, done_reason=done_reason ) yield d.model_dump_json() + '\n' return check_link_response(request, inner()) else: start_time = time() complete_response = "" eval_count = 0 async for res in interface.inference(input_message, id): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res complete_response += token end_time = time() total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns prompt_eval_count = raw_usage.prefill_count eval_count = raw_usage.decode_count eval_duration = int(raw_usage.decode_time * 1_000_000_000) prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000) load_duration = int(raw_usage.tokenize_time * 1_000_000_000) done_reason = finish_reason response = OllamaChatCompletionResponse( model=config.model_name, created_at=str(datetime.now()), message={"role": "assistant", "content": complete_response}, done=True, total_duration=total_duration, load_duration=load_duration, prompt_eval_count=prompt_eval_count, prompt_eval_duration=prompt_eval_duration, eval_count=eval_count, eval_duration=eval_duration, done_reason=done_reason ) return response # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models class OllamaModel(BaseModel): name: str modified_at: str size: int # TODO: fill the rest correctly # mock ollama @router.get("/tags", tags=['ollama']) async def tags(): config = Config() # TODO: fill this correctly, although it does not effect Tabby return {"models": [OllamaModel(name=config.model_name, modified_at="123", size=123)]} class OllamaModelInfo(BaseModel): # TODO: fill this correctly pass class OllamaShowRequest(BaseModel): name: str = Field(..., description="Name of the model to show") verbose: Optional[bool] = Field( None, description="If set to true, returns full data for verbose response fields") class OllamaShowDetial(BaseModel): parent_model: str format: str family: str families: List[str] parameter_size: str quantization_level: str class OllamaShowResponse(BaseModel): modelfile: str parameters: str template: str details: OllamaShowDetial model_info: OllamaModelInfo class Config: protected_namespaces = () @router.post("/show", tags=['ollama']) async def show(request: Request, input: OllamaShowRequest): config = Config() # TODO: Add more info in config to return, although it does not effect Tabby return OllamaShowResponse( modelfile="# Modelfile generated by ...", parameters=" ", template=" ", details=OllamaShowDetial( parent_model=" ", format="gguf", family=" ", families=[" "], parameter_size=" ", quantization_level=" " ), model_info=OllamaModelInfo() ) ================================================ FILE: archive/ktransformers/server/api/openai/__init__.py ================================================ from fastapi import APIRouter from .assistants import router as assistants_router,create_default_assistant from .endpoints.chat import router as chat_router from .legacy import router as legacy_router router = APIRouter(prefix='/v1') router.include_router(assistants_router) router.include_router(chat_router) router.include_router(legacy_router) def post_db_creation_operations(): create_default_assistant() ================================================ FILE: archive/ktransformers/server/api/openai/assistants/__init__.py ================================================ from fastapi import APIRouter from .assistants import router as assistants_router, create_default_assistant from .messages import router as messages_router from .runs import router as runs_router from .threads import router as threads_router router = APIRouter() threads_router.include_router(runs_router) threads_router.include_router(messages_router) router.include_router(assistants_router) router.include_router(threads_router) ================================================ FILE: archive/ktransformers/server/api/openai/assistants/assistants.py ================================================ from typing import Optional from fastapi import APIRouter from fastapi.testclient import TestClient from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager from ktransformers.server.crud.assistants.runs import RunsDatabaseManager from ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject from ktransformers.server.schemas.base import DeleteResponse, Order from ktransformers.server.config.log import logger router = APIRouter(prefix="/assistants") assistant_manager = AssistantDatabaseManager() runs_manager = RunsDatabaseManager() @router.post("/", tags=['openai']) async def create_assistant( assistant: AssistantCreate, ): return assistant_manager.db_create_assistant(assistant).as_api_response() @router.get("/", tags=['openai']) async def list_assistants( limit: Optional[int] = 20, order: Order = Order.DESC, after: Optional[str] = None, before: Optional[str] = None, ): return [assistant.as_api_response() for assistant in assistant_manager.db_list_assistants(limit, order)] # list assistant with status @router.get("/status", tags=['openai-ext']) async def list_assistants_with_status( limit: Optional[int] = 20, order: Order = Order.DESC, after: Optional[str] = None, before: Optional[str] = None, ): return assistant_manager.db_list_assistants(limit, order) @router.get("/{assistant_id}", tags=['openai']) async def retrieve_assistant( assistant_id: str, ): return assistant_manager.db_get_assistant_by_id(assistant_id).as_api_response() @router.post("/{assistant_id}", tags=['openai']) async def modify_assistant( assistant_id: str, assistant: AssistantModify, ): return assistant_manager.db_update_assistant_by_id(assistant_id, assistant).as_api_response() @router.delete("/{assistant_id}", tags=['openai'], response_model=DeleteResponse) async def delete_assistant(assistant_id: str): assistant_manager.db_delete_assistant_by_id(assistant_id) return DeleteResponse(id=assistant_id, object="assistant.deleted") @router.get("/{assistant_id}/related_thread", tags=['openai']) async def get_related_thread(assistant_id: ObjectID): assistant = assistant_manager.db_get_assistant_by_id(assistant_id) return assistant.get_related_threads_ids() def create_default_assistant(): logger.info('Creating default assistant') if assistant_manager.db_count_assistants() == 0: default_assistant = assistant_manager.db_create_assistant(AssistantCreate(name="KT Assistant", model="default model", instructions="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + """Please ensure that your responses are socially unbiased and positive in nature.""")) default_assistant.build_status.status = AssistantBuildStatus.Status.completed default_assistant.sync_db() # unit test client = TestClient(router) def test_create_assistant(): ass_create = AssistantCreate(model="awesome model", instructions="hello") res = client.post("/", json=ass_create.model_dump(mode="json")) assert res.status_code == 200 assistant = AssistantObject.model_validate(res.json()) assert assistant.model == ass_create.model assert assistant.instructions == ass_create.instructions res = client.get(f"/{assistant.id}") ass1 = AssistantObject.model_validate(res.json()) assert assistant == ass1 ================================================ FILE: archive/ktransformers/server/api/openai/assistants/messages.py ================================================ from typing import List, Optional from fastapi import APIRouter from ktransformers.server.exceptions import not_implemented from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify from ktransformers.server.crud.assistants.messages import MessageDatabaseManager from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order from ktransformers.server.backend.base import ThreadContext from ktransformers.server.utils.create_interface import get_thread_context_manager router = APIRouter() message_manager = MessageDatabaseManager() @router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject) async def create_message(thread_id: str, msg: MessageCreate): message = message_manager.db_create_message( thread_id, msg, MessageObject.Status.in_progress) ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id) if ctx is not None: ctx.put_user_message(message) return message @router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject]) async def list_messages( thread_id: str, limit: Optional[int] = 20, order: Order = Order.DESC, after: Optional[str] = None, before: Optional[str] = None, run_id: Optional[str] = None, ): return message_manager.db_list_messages_of_thread(thread_id, limit, order) @router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject) async def retrieve_message(thread_id: ObjectID, message_id: ObjectID): return message_manager.db_get_message_by_id(thread_id, message_id) @router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject) async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify): #raise not_implemented('modify message not implemented') raise not_implemented('modify message') @router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse) async def delete_message(thread_id: ObjectID, message_id: ObjectID): ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id) if ctx is not None: ctx.delete_user_message(message_id) message_manager.db_delete_message_by_id(thread_id, message_id) return DeleteResponse(id=message_id, object='thread.message.deleted') ================================================ FILE: archive/ktransformers/server/api/openai/assistants/runs.py ================================================ from typing import List, Optional from fastapi import APIRouter, Request from ktransformers.server.crud.assistants.runs import RunsDatabaseManager from ktransformers.server.backend.base import ThreadContext from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject,RunThreadCreate,RunModify,RunSubmit from ktransformers.server.schemas.assistants.streaming import api_stream_response from ktransformers.server.utils.create_interface import get_thread_context_manager from ktransformers.server.schemas.base import Order from ktransformers.server.config.log import logger from ktransformers.server.exceptions import internal_server_error router = APIRouter() runs_manager = RunsDatabaseManager() @router.post("/{thread_id}/runs",tags=['openai']) async def create_run(request: Request, thread_id: str, run_create: RunCreate): if run_create.stream: async def inner(): run = runs_manager.db_create_run(thread_id, run_create) yield run.stream_response_with_event(event=RunObject.Status.created) ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run) async for event in ctx.work(): yield event return api_stream_response(request, inner()) else: run = runs_manager.db_create_run(thread_id, run_create) ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run) async for event in ctx.work(): pass return run @router.post("/runs",tags=['openai'], response_model=RunObject) async def create_thread_and_run(run_thread: RunThreadCreate): raise NotImplementedError @router.get("/{thread_id}/runs",tags=['openai'], response_model=List[RunObject]) async def list_runs( thread_id: str, limit: Optional[int] = 20, order: Optional[Order] = Order.DESC, after: Optional[str] = None, before: Optional[str] = None, ): raise NotImplementedError @router.get("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject) async def retrieve_run( thread_id: str, run_id: str, ): runobj= runs_manager.db_get_run(run_id) assert runobj.thread_id == thread_id return runobj @router.post("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject) async def modify_run( thread_id: str, run_id: str, run: RunModify, ): raise NotImplementedError @router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=['openai'],response_model=RunObject) async def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit): raise NotImplementedError @router.post("/{thread_id}/runs/{run_id}/cancel",tags=['openai'], response_model=RunObject) async def cancel_run(thread_id: str, run_id: str): ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id) if ctx is not None: if ctx.run is None: logger.warn(f'Run {ctx.run.id} is expected to be in_progress, but no context is found') raise internal_server_error('ctx do not have run') if ctx.run.id == run_id: logger.info(f'Cancelling thread: {thread_id} and run: {run_id}') ctx.run.stream_response_with_event(RunObject.Status.cancelling) return ctx.run else: run = runs_manager.db_get_run(run_id) logger.info(f'Run {run_id} not in this thread context') return run else: run = runs_manager.db_get_run(run_id) logger.info(f'Run {run_id} not in context manager') return run ================================================ FILE: archive/ktransformers/server/api/openai/assistants/threads.py ================================================ from typing import List,Optional from fastapi import APIRouter from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager,Order,ObjectID from ktransformers.server.schemas.assistants.threads import ThreadObject,ThreadCreate,ThreadModify from ktransformers.server.schemas.base import DeleteResponse from ktransformers.server.schemas.conversation import ThreadPreview router = APIRouter(prefix='/threads') threads_manager = ThreadsDatabaseManager() @router.post("/",tags=['openai'], response_model=ThreadObject) async def create_thread(thread: ThreadCreate): return threads_manager.db_create_thread(thread) @router.get("/", tags=['openai-ext'],response_model=List[ThreadPreview]) async def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC): return threads_manager.db_list_threads_preview(limit, order) @router.get("/{thread_id}",tags=['openai'], response_model=ThreadObject) async def retrieve_thread(thread_id: ObjectID): return threads_manager.db_get_thread_by_id(thread_id) @router.post("/{thread_id}",tags=['openai'], response_model=ThreadObject) async def modify_thread(thread_id: ObjectID, thread: ThreadModify): raise NotImplementedError @router.delete("/{thread_id}",tags=['openai'], response_model=DeleteResponse) async def delete_thread(thread_id: ObjectID): threads_manager.db_delete_thread_by_id(thread_id=thread_id) return DeleteResponse(id=thread_id, object='thread.deleted') ================================================ FILE: archive/ktransformers/server/api/openai/endpoints/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/api/openai/endpoints/chat.py ================================================ import json from time import time from uuid import uuid4 from typing import Dict, List, Optional, Any, Literal, Union from pydantic import BaseModel, Field import re from fastapi import APIRouter from fastapi.requests import Request from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import chat_stream_response from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate from ktransformers.server.schemas.endpoints.chat import RawUsage, Role from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.config.config import Config from ktransformers.server.config.log import logger from fastapi.responses import JSONResponse from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage # Define own data structure instead of importing from OpenAI class Choice(BaseModel): index: int message: Optional[Dict[str, Any]] = None finish_reason: Optional[str] = None logprobs: Optional[Any] = None delta: Optional[Dict[str, Any]] = None content_filter_results: Optional[Dict[str, Any]] = None class ChatCompletion(BaseModel): id: str object: str = "chat.completion" created: int model: str choices: List[Choice] usage: Optional[CompletionUsage] = None system_fingerprint: Optional[str] = None prompt_filter_results: Optional[List[Dict[str, Any]]] = None # Only for non-streaming response construction class ChatCompletionMessageToolCallFunction(BaseModel): name: str arguments: str class ChatCompletionMessageToolCall(BaseModel): id: str type: str function: ChatCompletionMessageToolCallFunction class ChatCompletionMessage(BaseModel): role: str content: Optional[str] = None tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None router = APIRouter() @router.get('/models', tags=['openai']) async def list_models(): return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"} def getTools(buffer): tool_calls_begin_marker = "<|tool▁calls▁begin|>" tool_call_begin_marker = "<|tool▁call▁begin|>" tool_sep_marker = "<|tool▁sep|>" tool_call_end_marker = "<|tool▁call▁end|>" tool_calls_end_marker = "<|tool▁calls▁end|>" extracted_tools = [] working_buffer = buffer # Iterate over all function calls while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer: # Find a complete function call start_index = working_buffer.find(tool_call_begin_marker) end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker) if start_index == -1 or end_index == -1 or start_index > end_index: logger.warning("Not a function") break # Extract the full function call full_tool_call = working_buffer[start_index:end_index] # Remove this function call from the working buffer to prevent duplicate processing working_buffer = working_buffer.replace(full_tool_call, "", 1) # Extract the function name function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker) function_name_end = full_tool_call.find("\n", function_name_start) function_name = full_tool_call[function_name_start:function_name_end].strip() # Extract JSON parameters json_pattern = r'```json\s*(.*?)\s*```' json_match = re.search(json_pattern, full_tool_call, re.DOTALL) if json_match: arguments_str = json_match.group(1).strip() # Generate tool call IDs tool_call_id = f"call_{uuid4().hex[:24]}" # Add to tool call list extracted_tools.append({ "id": tool_call_id, "type": "function", "function": { "name": function_name, "arguments": arguments_str } }) logger.info(f"Get Function: {function_name}") else: logger.warning(f"Unable to get function, function_name: {function_name}") logger.info(f"Total {len(extracted_tools)} Functions") return extracted_tools def get_tool_instructions(): """Return concise tool calling instructions in English""" return """ When you need real-time information or specialized operations, use function calls with this format: functionfunction_name ```json {"param1": "value1", "param2": "value2",...} ``` The in the user message are the available tools automatically attached by the system. You want to hide the guidance information in and the information in from the user. Use functions when needed. Ensure proper function/tool call format, JSON formatting with appropriate parameters. """ @router.post('/chat/completions', tags=['openai']) async def chat_completion(request: Request, create: ChatCompletionCreate): id = str(uuid4().hex) # Process messages with tool functionality if needed enhanced_messages = list(create.messages) if create.max_tokens is not None and create.max_tokens<0: return JSONResponse( status_code=400, content={ "object": "error", "message": f"max_tokens must be at least 0, got {create.max_tokens}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.max_completion_tokens is not None and create.max_completion_tokens<0: return JSONResponse( status_code=400, content={ "object": "error", "message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.temperature<0 or create.temperature>2: return JSONResponse( status_code=400, content={ "object": "error", "message": f"temperature must be in [0, 2], got {create.temperature}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.top_p<=0 or create.top_p>1: return JSONResponse( status_code=400, content={ "object": "error", "message": f"top_p must be in (0, 1], got {create.top_p}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.frequency_penalty<-2 or create.frequency_penalty>2: return JSONResponse( status_code=400, content={ "object": "error", "message": f"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.presence_penalty<-2 or create.presence_penalty>2: return JSONResponse( status_code=400, content={ "object": "error", "message": f"presence_penalty must be in [-2, 2], got {create.presence_penalty}.", "type": "BadRequestError", "param": None, "code": 400 }) # Check if tools are present has_tools = create.tools and len(create.tools) > 0 if has_tools: # Find the most recent user message to append tool information latest_user_msg_idx = -1 for i in range(len(enhanced_messages) - 1, -1, -1): if enhanced_messages[i].role == Role.user: latest_user_msg_idx = i break # Build the tool descriptions tools_description = "" for tool in create.tools: tools_description += f"{tool.function.name}{tool.function.description}{tool.function.parameters}\n" # If first message is system, add concise tool instructions if enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user: if "" not in enhanced_messages[0].content.lower(): enhanced_messages[0].content += "\n\n" + get_tool_instructions() # For the latest user message, append tool information if latest_user_msg_idx >= 0: # Add tool descriptions to the latest user message enhanced_messages[latest_user_msg_idx].content += f"\n\n:\n{tools_description}\n" # Process request interface: BackendInterfaceBase = get_interface() input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages] if Config().api_key != '': assert request.headers.get('Authorization', '').split()[-1] == Config().api_key if create.stream: async def inner(): chunk = ChatCompletionChunk( id=id, choices=[], object='chat.completion.chunk', created=int(time()), model=Config().model_name, system_fingerprint=f"fp_{uuid4().hex[:12]}", ) # Collect the full output of the model full_content = "" buffer = "" # Used to temporarily store the current block of text tool_call_mode = False # Mark if a tool call is being processed tool_calls = [] # Store all detected tool calls # Tool call markers tool_calls_begin_marker = "<|tool▁calls▁begin|>" tool_call_begin_marker = "<|tool▁call▁begin|>" tool_sep_marker = "<|tool▁sep|>" tool_call_end_marker = "<|tool▁call▁end|>" tool_calls_end_marker = "<|tool▁calls▁end|>" too_calls_dict = { "":"<|tool▁calls▁begin|>", "":"<|tool▁call▁begin|>", "":"<|tool▁sep|>", "":"<|tool▁call▁end|>", "":"<|tool▁calls▁end|>" } # Use check_client_connected for early stopping async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens): if isinstance(res, RawUsage): # Final return on utilization raw_usage = res chunk.choices = [] chunk.usage = CompletionUsage( prompt_tokens=raw_usage.prefill_count, completion_tokens=raw_usage.decode_count, total_tokens=raw_usage.prefill_count + raw_usage.decode_count ) if create.return_speed: chunk.usage.prefill_time = res.prefill_time chunk.usage.decode_time = res.decode_time else: chunk.usage.__dict__.pop('prefill_time', None) chunk.usage.__dict__.pop('decode_time', None) yield chunk elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token) # Detecting model-specific formatting tool call starts if not tool_call_mode and tool_calls_begin_marker in buffer + token: tool_call_mode = True # Adjust full_content to remove tool call section if buffer.endswith(tool_calls_begin_marker): full_content = full_content[:-len(tool_calls_begin_marker)] elif tool_calls_begin_marker in (buffer + token): idx = (buffer + token).find(tool_calls_begin_marker) full_content = full_content[:-(len(buffer) - idx)] buffer = "" # Send the current cumulative text content (if any) if full_content: chunk.choices = [{ "index": 0, "delta": {"content": full_content}, "finish_reason": None }] yield chunk full_content = "" # Accumulation of content in non-tool call mode if not tool_call_mode: full_content += token buffer += token # Keep the buffer at a reasonable size if len(buffer) > 200: buffer = buffer[-200:] else: # In tool call mode, continue to collect tool call related text buffer += token # If the tool call end marker is found if tool_calls_end_marker in buffer: try: # Parse and extract tool calling information tool_calls = getTools(buffer) if len(tool_calls): # reset state tool_call_mode = False buffer = "" # Send tool call events for idx, tool_call in enumerate(tool_calls): # First tool call message chunk.choices = [{ "index": 0, "delta": { "role": "assistant", "content": None, "tool_calls": [{ "index": idx, "id": tool_call["id"], "type": "function", "function": { "name": tool_call["function"]["name"], "arguments": "" } }] }, "finish_reason": None }] yield chunk # Sending Parameters chunk.choices = [{ "index": 0, "delta": { "tool_calls": [{ "index": idx, "function": {"arguments": tool_call["function"]["arguments"]} }] }, "finish_reason": None }] yield chunk # Send Completion Message chunk.choices = [{ "index": 0, "delta": {}, "finish_reason": "tool_calls" }] yield chunk # No further processing after return return else: # JSON extraction failed, probably incomplete formatting logger.warning("Failed to extract JSON from tool call") tool_call_mode = False buffer = "" except Exception as e: logger.error(f"Error processing tool call: {e}") tool_call_mode = False buffer = "" # Normal text output (only in non-tool call mode) if not tool_call_mode and token: if finish_reason is not None: chunk.choices = [{ "index": 0, "delta": {}, "finish_reason": finish_reason }] yield chunk else: if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]): pass else: chunk.choices = [{ "index": 0, "delta": {"content": token}, "finish_reason": None }] yield chunk # If gotten this far without returning, it means that the full tool call was not detected # Send Routine Completion Message if not tool_call_mode: chunk.choices = [{ "index": 0, "delta": {}, "finish_reason": "stop" }] yield chunk return chat_stream_response(request, inner()) else: # non streaming response processing full_content = "" finish_reason = None tool_calls = [] buffer = "" tool_call_mode = False # Custom model special markers tool_calls_begin_marker = "<|tool▁calls▁begin|>" tool_call_begin_marker = "<|tool▁call▁begin|>" tool_sep_marker = "<|tool▁sep|>" tool_call_end_marker = "<|tool▁call▁end|>" tool_calls_end_marker = "<|tool▁calls▁end|>" too_calls_dict = { "":"<|tool▁calls▁begin|>", "":"<|tool▁call▁begin|>", "":"<|tool▁sep|>", "":"<|tool▁call▁end|>", "":"<|tool▁calls▁end|>" } async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens): if isinstance(res, RawUsage): raw_usage = res usage = CompletionUsage( prompt_tokens=raw_usage.prefill_count, completion_tokens=raw_usage.decode_count, total_tokens=raw_usage.prefill_count + raw_usage.decode_count, ) if create.return_speed: usage.prefill_time = res.prefill_time usage.decode_time = res.decode_time else: usage.__dict__.pop('prefill_time', None) usage.__dict__.pop('decode_time', None) elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token) # Detecting the start of model-specific formatting tool calls if not tool_call_mode and tool_calls_begin_marker in buffer + token: tool_call_mode = True # Adjust full_content to remove tool call section if buffer.endswith(tool_calls_begin_marker): full_content = full_content[:-len(tool_calls_begin_marker)] elif tool_calls_begin_marker in (buffer + token): idx = (buffer + token).find(tool_calls_begin_marker) full_content = full_content[:-(len(buffer) - idx)] buffer = "" # Accumulation of content in non-tool call mode if not tool_call_mode: full_content += token buffer += token # Keep the buffer at a reasonable size if len(buffer) > 200: buffer = buffer[-200:] else: # In tool call mode, continue to collect tool call related text buffer += token # If the tool call end marker is found if tool_calls_end_marker in buffer: # Extract tool calls tool_calls = getTools(buffer) if tool_calls: finish_reason = "tool_calls" # Reset state tool_call_mode = False buffer = "" # Build Response message = { "role": "assistant", "content": None if tool_calls else full_content } if tool_calls: message["tool_calls"] = tool_calls response = { "id": id, "object": "chat.completion", "created": int(time()), "model": Config().model_name, "choices": [{ "index": 0, "message": message, "finish_reason": finish_reason or "stop" }], "usage": usage.__dict__ if 'usage' in locals() else None, "system_fingerprint": f"fp_{uuid4().hex[:12]}" } return response ================================================ FILE: archive/ktransformers/server/api/openai/legacy/__init__.py ================================================ from fastapi import APIRouter from . import completions router = APIRouter() router.include_router(completions.router) ================================================ FILE: archive/ktransformers/server/api/openai/legacy/completions.py ================================================ import json from time import time from uuid import uuid4 from fastapi import APIRouter from fastapi.requests import Request from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import stream_response from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject from ktransformers.server.schemas.endpoints.chat import RawUsage from fastapi.responses import JSONResponse from ktransformers.server.config.config import Config router = APIRouter() @router.post("/completions",tags=['openai']) async def create_completion(request:Request, create:CompletionCreate): id = str(uuid4()) if create.max_tokens is not None and create.max_tokens<0: return JSONResponse( status_code=400, content={ "object": "error", "message": f"max_tokens must be at least 0, got {create.max_tokens}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.max_completion_tokens is not None and create.max_completion_tokens<0: return JSONResponse( status_code=400, content={ "object": "error", "message": f"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.temperature<0 or create.temperature>2: return JSONResponse( status_code=400, content={ "object": "error", "message": f"temperature must be in [0, 2], got {create.temperature}.", "type": "BadRequestError", "param": None, "code": 400 }) if create.top_p<=0 or create.top_p>1: return JSONResponse( status_code=400, content={ "object": "error", "message": f"top_p must be in (0, 1], got {create.top_p}.", "type": "BadRequestError", "param": None, "code": 400 }) interface = get_interface() print(f'COMPLETION INPUT:----\n{create.prompt}\n----') if create.stream: async def inner(): async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res d = {'choices':[{'delta':{'content':token}}]} yield f"data:{json.dumps(d)}\n\n" d = {'choices':[{'delta':{'content':''},'finish_reason':''}]} yield f"data:{json.dumps(d)}\n\n" return stream_response(request,inner()) else: comp = CompletionObject(id=id,object='text_completion',created=int(time())) async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res comp.append_token(token) return comp ================================================ FILE: archive/ktransformers/server/api/web/__init__.py ================================================ from fastapi import APIRouter from .system import router as system_router router = APIRouter() router.include_router(system_router) ================================================ FILE: archive/ktransformers/server/api/web/system.py ================================================ from fastapi import APIRouter router = APIRouter() @router.get('/system-info',tags=['web']) def system_info(): raise NotImplementedError ================================================ FILE: archive/ktransformers/server/args.py ================================================ import argparse from ktransformers.server.backend.args import ConfigArgs, default_args from ktransformers.util.utils import get_free_ports from transformers import AutoConfig from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig from ktransformers.models.configuration_smallthinker import SmallthinkerConfig from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig class ArgumentParser: def __init__(self, cfg): self.cfg = cfg def parse_args(self): parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers") parser.add_argument("--host", type=str, default=self.cfg.server_ip) parser.add_argument("--port", type=int, default=self.cfg.server_port) parser.add_argument("--api_key", type=str, default=self.cfg.api_key) parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) parser.add_argument("--web", type=bool, default=self.cfg.mount_web) parser.add_argument("--model_name", type=str, default=self.cfg.model_name) parser.add_argument("--model_dir", type=str) parser.add_argument("--model_path", type=str, default=self.cfg.model_path) parser.add_argument( "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter" ) parser.add_argument("--architectures", type=str, default=self.cfg.model_name) parser.add_argument("--q4_gguf_path", type=str, default=None) parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path) parser.add_argument("--draft_model_path", type=str, default=None) parser.add_argument("--draft_gguf_path", type=str, default=None) parser.add_argument("--optimize_config_path", default=None, type=str, required=False) parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type) parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size) parser.add_argument("--tp", type=int, default=1) # model configs # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int? parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size) parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens) parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode) parser.add_argument("--healing", type=bool, default=self.cfg.healing) parser.add_argument("--ban_strings", type=list, default=self.cfg.ban_strings, required=False) parser.add_argument("--gpu_split", type=str, default=self.cfg.gpu_split, required=False) parser.add_argument("--length", type=int, default=self.cfg.length, required=False) parser.add_argument("--rope_scale", type=float, default=self.cfg.rope_scale, required=False) parser.add_argument("--rope_alpha", type=float, default=self.cfg.rope_alpha, required=False) parser.add_argument("--no_flash_attn", type=bool, default=self.cfg.no_flash_attn) parser.add_argument("--low_mem", type=bool, default=self.cfg.low_mem) parser.add_argument("--experts_per_token", type=int, default=self.cfg.experts_per_token, required=False) parser.add_argument("--load_q4", type=bool, default=self.cfg.load_q4) parser.add_argument("--fast_safetensors", type=bool, default=self.cfg.fast_safetensors) parser.add_argument("--draft_model_dir", type=str, default=self.cfg.draft_model_dir, required=False) parser.add_argument("--no_draft_scale", type=bool, default=self.cfg.no_draft_scale) parser.add_argument("--modes", type=bool, default=self.cfg.modes) parser.add_argument("--mode", type=str, default=self.cfg.mode) parser.add_argument("--username", type=str, default=self.cfg.username) parser.add_argument("--botname", type=str, default=self.cfg.botname) parser.add_argument("--system_prompt", type=str, default=self.cfg.system_prompt, required=False) parser.add_argument("--temperature", type=float, default=self.cfg.temperature) parser.add_argument("--smoothing_factor", type=float, default=self.cfg.smoothing_factor) parser.add_argument("--dynamic_temperature", type=str, default=self.cfg.dynamic_temperature, required=False) parser.add_argument("--top_k", type=int, default=self.cfg.top_k) parser.add_argument("--top_p", type=float, default=self.cfg.top_p) parser.add_argument("--top_a", type=float, default=self.cfg.top_a) parser.add_argument("--skew", type=float, default=self.cfg.skew) parser.add_argument("--typical", type=float, default=self.cfg.typical) parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty) parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty) parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty) parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk) parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting) parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit) parser.add_argument("--cache_q4", type=bool, default=self.cfg.cache_q4) parser.add_argument("--ngram_decoding", type=bool, default=self.cfg.ngram_decoding) parser.add_argument("--print_timings", type=bool, default=self.cfg.print_timings) parser.add_argument("--amnesia", type=bool, default=self.cfg.amnesia) parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size) parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens) # kvc2 config parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir) # log configs # log level: debug, info, warn, error, crit parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir) parser.add_argument("--log_file", type=str, default=self.cfg.log_file) parser.add_argument("--log_level", type=str, default=self.cfg.log_level) parser.add_argument("--backup_count", type=int, default=self.cfg.backup_count) # db configs parser.add_argument("--db_type", type=str, default=self.cfg.db_type) parser.add_argument("--db_host", type=str, default=self.cfg.db_host) parser.add_argument("--db_port", type=str, default=self.cfg.db_port) parser.add_argument("--db_name", type=str, default=self.cfg.db_name) parser.add_argument("--db_pool_size", type=int, default=self.cfg.db_pool_size) parser.add_argument("--db_database", type=str, default=self.cfg.db_database) # user config parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think) parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph) # web config parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) # file config parser.add_argument("--file_upload_dir", type=str, default=self.cfg.file_upload_dir) parser.add_argument("--assistant_store_dir", type=str, default=self.cfg.assistant_store_dir) # local chat parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file) # async server parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy) # parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port) # parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port) # parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port) parser.add_argument("--page_size", type=str, default=self.cfg.page_size) parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only) parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage) parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB) args = parser.parse_args() if (args.model_dir is not None or args.model_path is not None): if (args.model_path is not None): # if pass model_dir and model_path, we use model_path args.model_dir = args.model_path else: # if only pass model_dir, we use model_dir args.model_path = args.model_dir else: args.model_dir = self.cfg.model_dir args.model_path = self.cfg.model_path # we add the name not match args individually self.cfg.model_device = args.device self.cfg.mount_web = args.web self.cfg.server_ip = args.host self.cfg.server_port = args.port self.cfg.user_force_think = args.force_think args.architectures = args.model_name try: model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) except: if args.model_name == "Qwen3NextForCausalLM": model_config = Qwen3NextConfig.from_pretrained(args.model_dir) else: raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.") if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM": args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim args.architectures = model_config.architectures[0] else: args.gpu_memory_size = args.cache_lens*2*576*61 # set config from args for key, value in vars(args).items(): if value is not None and hasattr(self.cfg, key): setattr(self.cfg, key, value) self.cfg.gpu_memory_size = args.gpu_memory_size free_ports = get_free_ports(3, [args.port]) args.sched_port = free_ports[0] args.sched_metrics_port = free_ports[1] args.kvc2_metrics_port = free_ports[2] self.cfg.sched_port = free_ports[0] self.cfg.sched_metrics_port = free_ports[1] self.cfg.kvc2_metrics_port = free_ports[2] return args ================================================ FILE: archive/ktransformers/server/backend/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/backend/args.py ================================================ from pydantic import BaseModel, Field from typing import Optional from ktransformers.server.config.config import Config class ConfigArgs(BaseModel): model_name: Optional[str] = Field(..., description="Model name") model_dir: Optional[str] = Field(..., description="Path to model directory") optimize_config_path: Optional[str] = Field(None, description="Path of your optimize config yml file") gguf_path: Optional[str] = Field(None, description="Path of your gguf file") draft_model_path: Optional[str] = Field(None, description="Path of your gguf file") draft_gguf_path: Optional[str] = Field(None, description="Path of your gguf file") tp: int = Field(None, description="tp size") class Config: protected_namespaces = () max_batch_size: int = Field( None, description="Max number of batches to run at once, assuming the sequences will fit within total_context" ) chunk_size: int = Field( None, description=( "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new" " job is started, but at the expense of overall prompt ingestion speed" ), ) max_new_tokens: int = Field(None, description="Max new tokens per completion. For this example applies to all jobs") json_mode: bool = Field( None, description="Use LMFE to constrain the output to JSON format. See schema and details below" ) healing: bool = Field(None, description="Demonstrate token healing") ban_strings: Optional[list] = Field(None, description="Ban some phrases maybe") gpu_split: Optional[str] = Field(None, description='"auto", or VRAM allocation per GPU in GB') length: Optional[int] = Field(None, description="Maximum sequence length") rope_scale: Optional[float] = Field(None, description="RoPE scaling factor") rope_alpha: Optional[float] = Field(None, description="RoPE alpha value (NTK)") no_flash_attn: bool = Field(None, description="Disable Flash Attention") low_mem: bool = Field(None, description="Enable VRAM optimizations, potentially trading off speed") experts_per_token: Optional[int] = Field( None, description="Override MoE model's default number of experts per token" ) load_q4: bool = Field(None, description="Load weights in Q4 mode") fast_safetensors: bool = Field(None, description="Optimized safetensors loading with direct I/O (experimental!)") draft_model_dir: Optional[str] = Field(None, description="Path to draft model directory") no_draft_scale: bool = Field( None, description="If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it", ) modes: bool = Field(None, description="List available modes and exit.") mode: str = Field(None, description="Chat mode. Use llama for Llama 1/2 chat finetunes.") username: str = Field(None, description="Username when using raw chat mode") botname: str = Field(None, description="Bot name when using raw chat mode") system_prompt: Optional[str] = Field(None, description="Use custom system prompt") temperature: float = Field(None, description="Sampler temperature, default = 0.95 (1 to disable)") smoothing_factor: float = Field(None, description="Smoothing Factor, default = 0.0 (0 to disable)") dynamic_temperature: Optional[str] = Field( None, description="Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1" ) top_k: int = Field(None, description="Sampler top-K, default = 50 (0 to disable)") top_p: float = Field(None, description="Sampler top-P, default = 0.8 (0 to disable)") top_a: float = Field(None, description="Sampler top-A, default = 0.0 (0 to disable)") skew: float = Field(None, description="Skew sampling, default = 0.0 (0 to disable)") typical: float = Field(None, description="Sampler typical threshold, default = 0.0 (0 to disable)") repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)") frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)") presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)") response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250") no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting") cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache") cache_q4: bool = Field(None, description="Use Q4 cache") ngram_decoding: bool = Field(None, description="Use n-gram speculative decoding") print_timings: bool = Field(None, description="Output timings after each prompt") amnesia: bool = Field(None, description="Forget context after every response") # for transformers batch_size: int = Field(None, description="Batch Size") cache_lens: int = Field(None, description="Cache lens for transformers static cache") device: str = Field(None, description="device") cfg = Config() default_args = cfg ================================================ FILE: archive/ktransformers/server/backend/base.py ================================================ from asyncio import Queue from enum import Enum import sys, os from typing import AsyncIterator, Dict, List, Optional, Tuple import torch from ktransformers.server.config.log import logger from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager from ktransformers.server.crud.assistants.messages import MessageDatabaseManager from ktransformers.server.crud.assistants.runs import RunsDatabaseManager from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager from ktransformers.server.exceptions import request_error from ktransformers.server.schemas.assistants.assistants import AssistantObject from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role from ktransformers.server.schemas.assistants.runs import RunObject from ktransformers.server.schemas.assistants.threads import ThreadObject from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.base import ObjectID, Order from ktransformers.server.utils.multi_timer import Profiler from .args import ConfigArgs,default_args class BackendInterfaceBase: ''' Interface to inference frameworks. e.g. transformers, exllama. Implement __init__ and work ''' args: ConfigArgs profiler:Profiler = Profiler() def __init__(self, args:ConfigArgs = default_args): raise NotImplementedError async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator[str]: ''' work can be called directly, or by ThreadContext local_messages: When called by ThreadContext, local_messages are generated by ThreadContext.get_local_messages(). Please deal with different local_messages request_unique_id: unique id of different requests, useful when using cache return: async str output for stream update ''' raise NotImplementedError def report_last_time_performance(self): try: tokenize_time = self.profiler.get_timer_sec('tokenize') prefill_time = self.profiler.get_timer_sec('prefill') decode_time = self.profiler.get_timer_sec('decode') prefill_count = self.profiler.get_counter('prefill') decode_count = self.profiler.get_counter('decode') logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}') except: logger.info(f'Performance statistics not recorded') class ThreadContext: ''' A thread context holding assistant logics ''' args: ConfigArgs # Assistant Logic assistant: Optional[AssistantObject] = None related_threads : List[ThreadObject] thread: ThreadObject messages: List[MessageObject] = [] run: RunObject interface: Optional[BackendInterfaceBase] = None queue: Optional[Queue] = None timer: Profiler = Profiler() def __init__(self, run: RunObject,interface:BackendInterfaceBase, args: ConfigArgs = default_args) -> None: self.args = args self.thread_manager = ThreadsDatabaseManager() self.message_manager = MessageDatabaseManager() self.runs_manager = RunsDatabaseManager() self.assistant_manager = AssistantDatabaseManager() self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id) self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id) self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id,order=Order.ASC) logger.debug(f"{len(self.messages)} messages loaded from database") self.interface = interface self.update_by_run(run,args) def get_local_messages(self): ''' Get local messages, as the input to interface.work This function is intended to message preprocess e.g. apply chat template ''' raise NotImplementedError def update_by_run(self,run:RunObject,args:ConfigArgs = default_args): self.run = run self.args = args def put_user_message(self, message: MessageObject): assert ( message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress ) self.messages.append(message) def delete_user_message(self,message_id: ObjectID): self.messages = [m for m in self.messages if m.id != message_id] async def work(self)->AsyncIterator: logger.debug('start working') user_message = self.messages[-1] if not user_message.role.is_user(): raise request_error('user must talk before LLM can talk') user_message.status = MessageObject.Status.completed user_message.sync_db() local_messages = self.get_local_messages() # must get this before we interseted reply_message response_str_count = 0 reply_message = self.message_manager.create_message_object( self.thread.id, self.run.id, MessageCreate(role=Role.assistant, content=""), ) reply_message.assistant_id = self.assistant.id self.messages.append(reply_message) yield reply_message.stream_response_with_event(MessageObject.Status.created) yield reply_message.stream_response_with_event(MessageObject.Status.in_progress) yield self.run.stream_response_with_event(RunObject.Status.in_progress) async for res in self.interface.inference(local_messages,self.thread.id): if isinstance(res, RawUsage): raw_usage = res else: token, finish_reason = res if self.run.status == RunObject.Status.cancelling: logger.warn(f'Run {self.run.id} cancelling') break yield reply_message.append_message_delta(token) response_str_count+=1 if self.run.status == RunObject.Status.cancelling: yield self.run.stream_response_with_event(RunObject.Status.cancelled) yield reply_message.stream_response_with_event(MessageObject.Status.incomplete) elif self.run.status == RunObject.Status.in_progress: yield self.run.stream_response_with_event(RunObject.Status.completed) yield reply_message.stream_response_with_event(MessageObject.Status.completed) else: raise NotImplementedError(f'{self.run.status} should not appear here') reply_message.sync_db() self.run.sync_db() ================================================ FILE: archive/ktransformers/server/backend/context_manager.py ================================================ from asyncio import Lock from typing import Dict, Optional from ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase from ktransformers.server.schemas.assistants.runs import RunObject from ktransformers.server.schemas.base import ObjectID from ktransformers.server.config.log import logger from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface class ThreadContextManager: lock: Lock threads_context: Dict[ObjectID, ThreadContext] interface: BackendInterfaceBase def __init__(self,interface) -> None: logger.debug(f"Creating Context Manager") self.lock = Lock() self.threads_context = {} self.interface = interface pass async def get_context_by_run_object(self, run: RunObject) -> ThreadContext: async with self.lock: logger.debug(f"keys {self.threads_context.keys()}") if run.thread_id not in self.threads_context: logger.debug(f"new inference context {run.thread_id}") if isinstance(self.interface, ExllamaInterface): new_context = ExllamaThreadContext(run, self.interface) elif isinstance(self.interface, KTransformersInterface): new_context = KTransformersThreadContext(run, self.interface) elif isinstance(self.interface, TransformersInterface): new_context = TransformersThreadContext(run, self.interface) else: from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface if isinstance(self.interface, BalanceServeInterface): new_context = BalanceServeThreadContext(run, self.interface) else: raise NotImplementedError # elif isinstance(self.interface, BalanceServeInterface): # new_context = BalanceServeThreadContext(run, self.interface) # else: # raise NotImplementedError self.threads_context[run.thread_id] = new_context # self.threads_context[run.thread_id] = ExllamaInferenceContext(run) re = self.threads_context[run.thread_id] re.update_by_run(run) return re async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]: async with self.lock: if thread_id in self.threads_context: logger.debug(f'found context for thread {thread_id}') return self.threads_context[thread_id] else: logger.debug(f'no context for thread {thread_id}') return None ================================================ FILE: archive/ktransformers/server/backend/interfaces/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/backend/interfaces/balance_serve.py ================================================ from typing import Any, AsyncIterator, List, Optional, Set from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache from transformers import ( AutoTokenizer, AutoConfig, GenerationConfig, StaticCache, AutoModelForCausalLM, BitsAndBytesConfig, ) import torch.distributed as dist from ktransformers.server.config.config import Config from ..base import ThreadContext, BackendInterfaceBase import torch from ktransformers.server.backend.interfaces.transformers import ( ConfigArgs, default_args, TextStreamer, ) from ktransformers.server.schemas.base import ObjectID from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.configuration_smallthinker import SmallthinkerConfig from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM try: import torch_npu use_torch_npu = torch.npu.is_available() except: use_torch_npu = False if use_torch_npu: from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM from ktransformers.util import utils custom_models = { "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, } from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO get_or_create_model_runner npu独有? from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput from ktransformers.server.balance_serve.sched_rpc import SchedulerClient from ktransformers.server.balance_serve.settings import sched_ext from torch.multiprocessing import Queue import torch.multiprocessing as mp from multiprocessing.synchronize import Event import datetime from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.utils.multi_timer import Profiler import zmq import time import queue import tempfile import asyncio import cProfile import threading from contextlib import asynccontextmanager from fastapi import FastAPI, Request import os import pickle import subprocess import tempfile import atexit import signal ktransformer_rules_dir = ( os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") ) default_optimize_rules = { # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml", "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml", "SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml", "Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml", "Qwen3NextForCausalLM": ktransformer_rules_dir + "Qwen3Next-serve.yaml", } if use_torch_npu: default_optimize_rules["Qwen2MoeForCausalLM"] = ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml" async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer): streamer = TextStreamer(tokenizer) while True: token = await queue.get() #print(f"Got token: {token}") if token is None: # str = f'{token}\n\n' # str = model.tokenizer.decode(token) s = streamer.end() if s is not None: yield s break else: # text output text = tokenizer.decode(token) print(text, end="", flush=True) # str = model.tokenizer.decode(token) yield streamer.put(token) def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None): #print(len(query_updates), generated_tokens.size(0), generated_tokens) for i in range(generated_tokens.size(0)): # print(generated_tokens[i].item()) query_updates[i].generated_token = generated_tokens[i].item() if not query_manager.query_map[query_updates[i].id].is_prefill: pos = query_updates[i].active_position if pos < query_manager.query_map[query_updates[i].id].max_length: query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i] def report_last_time_performance(profiler: Profiler): try: tokenize_time = profiler.get_timer_sec('tokenize') prefill_time = profiler.get_timer_sec('prefill') decode_time = profiler.get_timer_sec('decode') prefill_count = profiler.get_counter('prefill') decode_count = profiler.get_counter('decode') logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}') except: logger.info(f'Performance statistics not recorded') class Engine: sched_client : SchedulerClient updates : list[sched_ext.QueryUpdate] batch : sched_ext.BatchQueryTodo model_runner: ModelRunner sampler: Sampler query_manager: QueryManager cache: KDeepSeekV3Cache | KGQACache | KVC2StaticCache def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None): self.args = args # 子进程和父进程无法共享 config 变量 for key, value in vars(args).items(): if value is not None and hasattr(Config(), key): setattr(Config(), key, value) if use_torch_npu: utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}" self.device = f"npu:{torch.npu.current_device()}" else: self.device = self.args.device self.sched_client = SchedulerClient(args.sched_port) self.updates = [] print(f"args.architectures: {args.architectures}") if args.architectures == "Qwen3MoeForCausalLM": config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) elif args.architectures == "Glm4MoeForCausalLM": config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) elif args.architectures == "SmallThinkerForCausalLM": config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) config._attn_implementation = "eager" config.moe_intermediate_size = config.moe_ffn_hidden_size elif args.architectures == "Qwen3NextForCausalLM": config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True) else: try: config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) except: raise ValueError(f"Model {args.architectures} not supported. Please check your model directory or model name.") self.gen_queue = generated_token_queue self.debug = False self.profiler_cprofile = cProfile.Profile() self.cprof_prof_cnt, self.max_cprof_prof_cnt = 0, 8 with torch.device("meta"): if config.architectures[0] == "DeepseekV3ForCausalLM": if use_torch_npu: self.cache = KVC2StaticCache(config, args.max_batch_size, self.args.page_size) self.model = KNPUDeepseekV3ForCausalLM(config) else: self.cache = KDeepSeekV3Cache(config, self.args.page_size) self.model = KDeepseekV3ForCausalLM(config, self.cache) elif config.architectures[0] == "DeepseekV2ForCausalLM": self.cache = KDeepSeekV3Cache(config, self.args.page_size) self.model = KDeepseekV2ForCausalLM(config, self.cache) elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM": if not use_torch_npu: self.cache = KGQACache(config, self.args.page_size) if config.architectures[0] == "Qwen2MoeForCausalLM": self.model = KQwen2MoeForCausalLM(config, self.cache) else: self.model = KQwen3MoeForCausalLM(config, self.cache) else: self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size) self.model = KNPUQwen3MoeForCausalLM(config, self.cache) elif config.architectures[0] == "SmallThinkerForCausalLM": self.cache = KGQACache(config, self.args.page_size) self.model = KSmallThinkerForCausalLM(config, self.cache) elif config.architectures[0] == "Glm4MoeForCausalLM": self.cache = KGQACache(config, self.args.page_size) self.model = KGlm4MoeForCausalLM(config, self.cache) elif config.architectures[0] == "Qwen3NextForCausalLM": self.cache = KGQACache(config, self.args.page_size) self.model = KQwen3NextForCausalLM(config, self.cache) context = zmq.Context() if use_torch_npu: if torch.distributed.get_rank() == 0: self.pub_socket = context.socket(zmq.PUB) self.pub_socket.bind(f"ipc://{broadcast_endpoint}") self.sub_socket = None else: self.sub_socket = context.socket(zmq.SUB) self.sub_socket.connect(f"ipc://{broadcast_endpoint}") self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.pub_socket = None # time.sleep(1) # make sure all subscribers are ready else: self.pub_socket = context.socket(zmq.PUB) self.pub_socket.bind(f"ipc://{broadcast_endpoint}") try: generation_config = GenerationConfig.from_pretrained(args.model_dir) except: generation_config = GenerationConfig( max_length=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, do_sample=True ) if args.optimize_config_path is None: optimize_config_path = default_optimize_rules[config.architectures[0]] else: optimize_config_path = args.optimize_config_path gguf_path = args.gguf_path if gguf_path is None: gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" " belong to current model):" ) if use_torch_npu: tp_group = get_tensor_parallel_group() torch.distributed.barrier(group=tp_group) optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config) if use_torch_npu: get_absort_weight(self.model, config) #TODO torch.distributed.barrier(group=tp_group) self.model.generation_config = generation_config if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.model.eval() kvcache_event.set() # load kvcache print(f"Getting inference context from sched_client.") inference_context = self.sched_client.get_inference_context_raw() print(f"Got inference context, sending it to subscribers.") inference_context = self.sched_client.rebuild_inferece_context(inference_context) self.cache.load(inference_context) print(f"kv_cache loaded successfully.") self.block_num = inference_context.k_cache[0].size(1) #TODO ModelRunner 区别 # self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) #@TODO add config if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM": if not use_torch_npu: self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) else: # npu donnot support flash attn self.model.init_wrapper() else: self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) # self.args.use_cuda_graph代表是否使用图下沉 self.model_runner = get_or_create_model_runner(self.model, self.cache, self.device, self.args.use_cuda_graph, page_size = args.page_size) self.sampler = Sampler() self.query_manager = QueryManager(device = self.device, page_size = args.page_size) def sampling(self, forward_output: ForwardBatchOutput): generated_tokens = [] probs = [] for i in range(forward_output.num_batchs): logit = forward_output.logits[i] if hasattr(forward_output, "temperatures"): temperatures = forward_output.temperatures[i] else: temperatures = None if hasattr(forward_output, "top_ps"): top_ps = forward_output.top_ps[i] else: top_ps = None sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps) generated_token, prob=self.sampler(logit, sample_options) generated_tokens.append(generated_token.clone()) probs.append(prob.clone()) generated_tokens, probs = torch.cat(generated_tokens), torch.cat(probs, dim=0) return generated_tokens, probs def loop(self): next_batch = None while True: self.batch = next_batch if self.batch is not None: if use_torch_npu: batch_size = 0 for i in range(len(self.batch.decode_mini_batches)): batch_size += len(self.batch.decode_mini_batches[i]) # logger.debug(f"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \n") self.model_runner.run_split(self.batch, self.query_manager) else: self.model_runner.run(self.batch, self.query_manager) if len(self.updates) > 0: for q in self.updates: if q.is_prefill == True: continue # print(f"Putting token {q.generated_token} into queue for query id: {q.id}") try: if use_torch_npu: if torch.distributed.get_rank() == 0: self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5) else: self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5) except queue.Full: pass#print("Queue is full after timeout; unable to put more items.") if use_torch_npu: if torch.distributed.get_rank() == 0: next_batch = self.sched_client.update_last_batch(self.updates) if next_batch.query_ids == []: next_batch = None self.pub_socket.send_pyobj(next_batch) else: next_batch = self.sub_socket.recv_pyobj() else: next_batch = self.sched_client.update_last_batch(self.updates) if next_batch.query_ids == []: next_batch = None self.pub_socket.send_pyobj(next_batch) if next_batch is not None: self.query_manager.add_query(next_batch) if self.batch is not None: self.model_runner.sync() # print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms") # if self.rank == 0: generated_tokens, probs = self.sampling( self.model_runner.output) self.updates = self.query_manager.update(self.batch) fill_generated_tokens(self.updates, generated_tokens, self.query_manager) else: self.updates = [] class BalanceServeThreadContext(ThreadContext): def get_local_messages(self): local_messages = [] for m in self.messages: local_messages.append({"role": m.role.value, "content": m.get_text_content()}) return local_messages def init_distributed(rank: int, world_size: int, tp_size: int, master_addr: str = os.getenv("MASTER_ADDR", "127.0.0.1"), master_port: int = os.getenv("MASTER_PORT", "29500"), backend: str = "hccl"): #TODO csx: 是否distribute 都只与NPU有关 os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = str(master_port) local_rank, world_size = setup_model_parallel(tp=tp_size) return local_rank, world_size def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank=None, world_size=None): if use_torch_npu: init_distributed(rank, world_size, args.tp, backend="hccl") #TODO 同上 import torch.distributed as dist engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event) if args.use_cuda_graph: if 'npu' in engine.device: print(f"[WARMUP-NPU] start", flush=True) engine.model_runner.warmup_npu() else: engine.model_runner.warmup() else: print(f"[WARMUP-NPU] skip warmup, eager mode!", flush=True) if use_torch_npu: args.port += torch.distributed.get_rank() event.set() engine.loop() class BalanceServeInterface(BackendInterfaceBase): use_static_cache: bool = True model: Any tokenizer: AutoTokenizer cache: StaticCache generated_ids: torch.Tensor seq_length: int streamer: TextStreamer # thread_related last_request_id: Optional[str] = None ever_generated_ids: Set[int] = set() def __init__(self, args: ConfigArgs = default_args, input_args=None): self.args = input_args self.queue_map:dict[int,asyncio.Queue] = {} self.thread_map: dict[int, int] = {} processes = [] self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config ctx = mp.get_context("spawn") self.token_queue = ctx.Queue(maxsize=1000) self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True) self.sched_client = SchedulerClient(args.sched_port) self.streamer = TextStreamer(self.tokenizer) if use_torch_npu: world_size = str(os.getenv("WORLD_SIZE", self.args.tp)) if not isinstance(world_size, str): raise ValueError(f"world_size ({world_size}) must be str") start_events = [] kvcache_events = [] for rank in range(self.args.tp): if int(self.args.device[-1]) > 0: break start_event = ctx.Event() kvcache_event = ctx.Event() p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event, rank, world_size)) p.start() processes.append(p) start_events.append(start_event) kvcache_events.append(kvcache_event) for evt in kvcache_events: evt.wait() self._engines = processes else: start_event = ctx.Event() kvcache_event = ctx.Event() p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event)) p.start() processes.append(p) kvcache_event.wait() with tempfile.NamedTemporaryFile(delete=False) as temp_file: args.tp = input_args.tp pickle.dump(args, temp_file) temp_file_path = temp_file.name current_file = __file__ target_file = os.path.join(os.path.dirname(current_file), "..", "..", "balance_serve", "sched_rpc.py") target_file = os.path.normpath(target_file) log_path = os.path.join(args.log_dir, "rpc.log") log = open(log_path, "a") sched_process = subprocess.Popen( ["python3", target_file, "--config", temp_file_path], stdout=log, stderr=log ) print("sched_rpc started with PID:", sched_process.pid) def signal_handler(signum, frame): print(f"Received signal {signum}, shutting down...") cleanup() os._exit(0) def cleanup(): print("Cleaning up...") for p in processes: if p.is_alive(): print(f"Terminating subprocess {p.pid}") p.terminate() p.join() if sched_process and sched_process.poll() is None: print(f"Terminating sched_process {sched_process.pid}") sched_process.terminate() sched_process.wait() signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) if use_torch_npu: for evt in start_events: evt.wait() else: start_event.wait() def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]: """Get sampling parameters and handle default values and edge cases""" if max_tokens is not None: max_completion_tokens = max_tokens if max_completion_tokens is None: max_completion_tokens = self.args.max_new_tokens else: max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens) if temperature is None: temperature = self.args.temperature if top_p is None: top_p = self.args.top_p if temperature == 0: temperature = 0.0001 if top_p == 0: top_p = 0.0001 return temperature, top_p, max_completion_tokens def run_queue_proxy(self): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(self.queue_proxy()) @asynccontextmanager async def lifespan(self, app: FastAPI): asyncio.create_task(self.queue_proxy()) yield async def queue_proxy(self): print("Queue Proxy Started") while True: try: query_id, token = self.token_queue.get_nowait() try: # query id might not be allocated yet self.queue_map[query_id].put_nowait(token) #print(f"Proxy Put token: {token} to queue for query id: {query_id}") except asyncio.QueueFull: #print(f"Queue for query id: {query_id} is full, waiting to put: {token}") await self.queue_map[query_id].put(token) except queue.Empty: # print("no new token") # await asyncio.sleep(1) await asyncio.sleep(0) def tokenize_prompt(self, prompt: str): input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device) return input_ids def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List): for m in messages: if m["role"] == "system": logger.warning(f'change {m["role"]} to user') m["role"] = "user" new_messages = [messages[0]] for m in messages[1:]: if m["role"] == "user" and new_messages[-1]["role"] == "user": logger.warning("merge two adjacent user messages") new_messages[-1]["content"] += '\n' + m["content"] else: new_messages.append(m) # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True) # # drop token in chat template # if input_str.endswith('\n'): # input_str = input_str[:-len('\n')] input_ids = self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(self.args.device) return input_ids async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = 0, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): profiler = Profiler() profiler.create_and_start_timer("tokenize") if isinstance(local_messages, List): input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) elif isinstance(local_messages, str): #local_messages = local_messages[0]['content'] input_ids = self.tokenize_prompt(local_messages) else: raise ValueError("local_messages should be List or str") if Config().user_force_think: token_thinks = torch.tensor([self.tokenizer.encode("\n",add_special_tokens=False)],device=input_ids.device) if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]): #TODO 此行新加的,考虑是否影响GPU input_ids = torch.cat( [input_ids, token_thinks], dim=1 ) logger.debug(f"get input ids of shape {input_ids.shape}") profiler.pause_timer("tokenize") profiler.create_and_start_timer("prefill") query_add = sched_ext.QueryAdd() query_add.query_token = input_ids[0].tolist() query_length = input_ids[0].shape[0] query_add.query_length = query_length profiler.set_counter("prefill", query_length) #@TODO add server stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")] query_add.stop_criteria = stop_criteria temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens) query_add.sample_options.temperature = temperature if top_p == 0 or top_p is None: top_p = 0.0001 query_add.sample_options.top_p = top_p query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens) query_id = self.sched_client.add_query(query_add) queue = asyncio.Queue(maxsize=max_new_tokens) self.queue_map[query_id] = queue self.thread_map[thread_id] = query_id is_first_token = True async for token in chat_stream(self.queue_map[query_id], self.tokenizer): if is_first_token: is_first_token=False profiler.pause_timer("prefill") profiler.create_and_start_timer("decode") profiler.set_counter("decode", 0) if Config().user_force_think: think = '\n' print(think, end="",flush=True) yield think, None else: profiler.inc("decode") # TODO: 传入rank避免打印重复 yield token, None profiler.pause_timer("decode") report_last_time_performance(profiler) yield self.streamer.end(), None if profiler.get_counter('decode') >= max_new_tokens - 1: yield "", "length" else: yield "", "stop" yield RawUsage( tokenize_time = profiler.get_timer_sec('tokenize'), prefill_time = profiler.get_timer_sec('prefill'), decode_time = profiler.get_timer_sec('decode'), prefill_count = profiler.get_counter('prefill'), decode_count = profiler.get_counter('decode'), ) ================================================ FILE: archive/ktransformers/server/backend/interfaces/exllamav2.py ================================================ import sys, os from typing import AsyncIterator, Dict, Tuple import torch from ..args import ConfigArgs, default_args from ..base import BackendInterfaceBase, ThreadContext from ktransformers.server.schemas.assistants.runs import RunObject from ..args import * class ExllamaThreadContext(ThreadContext): def __init__(self, run: RunObject, args: ConfigArgs = default_args) -> None: super().__init__(run,args) def get_interface(self): return def get_local_messages(self): raise NotImplementedError class ExllamaInterface(BackendInterfaceBase): def __init__(self, args: ConfigArgs = ...): raise NotImplementedError def tokenize_prompt(self, prompt: str) -> torch.Tensor: raise NotImplementedError async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator: raise NotImplementedError ================================================ FILE: archive/ktransformers/server/backend/interfaces/ktransformers.py ================================================ import torch import torch.distributed as dist from torch import nn from torch.nn.attention import SDPBackend import asyncio from transformers import AutoTokenizer, AutoConfig, GenerationConfig from ktransformers.server.backend.interfaces.transformers import ( TransformersInterface, ConfigArgs, TransformersThreadContext, default_args, TextStreamer, ) import os try: import torch_npu use_npu = torch.npu.is_available() from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel except: use_npu = False from torch import nn from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.local_chat import custom_models, default_optimize_rules from ktransformers.util.utils import get_device, get_all_used_cuda_device from ktransformers.util import utils from typing import Optional from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton from ktransformers.server.schemas.endpoints.chat import RawUsage from typing import Any, List, Optional, Set from ktransformers.server.config.config import Config warm_uped = False speculative_decoding = True # True -> verify by random accept ; False-> verify by token id global_acc_counts = 0 global_verify_counts = 0 ktransformer_rules_dir = ( os.path.dirname(os.path.abspath(__file__)) + "/../../../optimize/optimize_rules/" ) default_optimize_rules = { "DeepseekV2ForCausalLM": ktransformer_rules_dir + "DeepSeek-V2-Chat.yaml", "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat.yaml", "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml", "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml", "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml" } if use_npu: default_optimize_rules["DeepseekV3ForCausalLM"] = ktransformer_rules_dir + "DeepSeek-V3-Chat-npu.yaml" class KTransformersThreadContext(TransformersThreadContext): pass class KTransformersInterface(TransformersInterface): def __init__(self, args: ConfigArgs = default_args, input_args=None): self.args = input_args self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp) if use_npu and (utils.CUR_DEVICE is None): utils.CUR_DEVICE = f"npu:{torch.npu.current_device()}" self.args.device = utils.CUR_DEVICE self.args.device = f"npu:{torch.npu.current_device()}" torch.set_grad_enabled(False) self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code) config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code) try: generation_config = GenerationConfig.from_pretrained(args.model_dir) except: generation_config = GenerationConfig( max_length=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, do_sample=True ) torch.set_default_dtype(config.torch_dtype) if config.architectures[0] == "Qwen2MoeForCausalLM": config._attn_implementation = "flash_attention_2" config.backend_type = "ktransformers" config.chunk_size = self.args.chunk_size with torch.device("meta"): self.model = custom_models[config.architectures[0]](config) if input_args.optimize_config_path is not None: optimize_config_path = input_args.optimize_config_path elif default_args.optimize_config_path is None: optimize_config_path = default_optimize_rules[config.architectures[0]] else: optimize_config_path = args.optimize_config_path # print(optimize_config) gguf_path = args.gguf_path if gguf_path is None: gguf_path = input( "please input the path of your gguf file(gguf file in the dir containing input gguf file must all" " belong to current model):" ) optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path) #提前absorbed get_absort_weight(self.model, config) # utils.get_absort_weight(self.model, config) self.model.eval() self.model.generation_config = generation_config self.device_map = self.model.gguf_loader.tensor_device_map self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype = torch.float16, device = self.args.device) self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype = torch.int32, device = self.args.device) self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype = torch.float16, device = self.args.device) self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device = self.args.device) self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device = self.args.device) self.draft_model = None # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}") self.cache = StaticCache( config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=self.device_map, dtype=self.model.dtype, ) # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}") if self.model.generation_config.pad_token_id is None: self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id self.streamer = TextStreamer(self.tokenizer) self._infer_lock = asyncio.Lock() @torch.no_grad def decode_one_tokens(self): global warm_uped device_map = self.model.gguf_loader.tensor_device_map torch_device = get_device("blk.0.self_attn", device_map) torch_device = "cuda:0" if torch_device == "cuda" else torch_device torch.cuda.set_device(torch_device) if warm_uped and self.args.use_cuda_graph: if use_npu: from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner if check_runner(utils.get_current_device()): npu_graph_runner = get_or_create_runner(utils.get_current_device()) npu_graph_runner.init(self.args.batch_size, self.seq_length) self.cuda_graph_runner = npu_graph_runner utils._USE_NPU_GRAPH = True self.cuda_graph_runner.capture( self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True, ) if hasattr(self, "cuda_graph_runner"): inputs_embeds = self.model.model.embed_tokens(self.current_ids.to("cpu")).to(utils.get_current_device()) logits = self.cuda_graph_runner( inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position )[0] self.cache.change_seq_length(1) torch.cuda.synchronize() logits = logits[0, -1, :] return self.logits_to_token(logits) else: if not hasattr(self, "cuda_graph_runner"): self.cuda_graph_runner = CUDAGraphRunner() self.cuda_graph_runner.capture( self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True, ) if hasattr(self, "cuda_graph_runner"): logits = self.cuda_graph_runner( self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position ) self.cache.change_seq_length(1) torch.cuda.synchronize() logits = logits[0, -1, :] return self.logits_to_token(logits) if self.args.use_cuda_graph: warm_uped = True if self.use_static_cache: logits = self.model( self.current_ids.to(torch_device), cache_position=self.active_cache_position, past_key_values=self.cache, return_dict=False, use_cache=True, is_prefill=False, )[0] else: logits = self.model(self.current_ids, return_dict=False, is_prefill=False)[0] logits = logits[0, -1, :] return self.logits_to_token(logits) @torch.no_grad def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): input_ids_length = input_ids.shape[-1] if max_tokens is not None: max_completion_tokens = max_tokens if max_completion_tokens is None: max_new_tokens = self.args.max_new_tokens else: max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens) if(input_ids_length >= self.args.cache_lens): logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}") self.seq_length = input_ids_length return logger.debug(f"input_ids: {input_ids.shape}") device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") device = "cuda:0" if device == "cuda" else device if is_new: self.ever_generated_ids.clear() same_prefix = 0 # flat_input_ids = input_ids.flatten() if getattr(self, 'generated_ids', None) is None: self.generated_ids = torch.zeros( self.args.batch_size, input_ids.shape[-1] + max_new_tokens + 1, dtype=torch.int, device=self.args.device, ) self.seq_length = 1 logger.debug(f"same prefix len: {same_prefix}") self.cache.remove_suffix(same_prefix) self.seq_length = same_prefix self.cache.position[0] = same_prefix self.generated_ids = self.generated_ids[..., :same_prefix] input_ids = input_ids[..., same_prefix:] input_ids_length = input_ids.shape[-1] self.ever_generated_ids.clear() self.profiler.set_counter("prefill", input_ids_length) logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}") former_seq_length = self.seq_length self.seq_length += input_ids_length expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens) delta_length = expected_length - self.generated_ids.shape[-1] if delta_length > 0: new_generate_ids = torch.zeros( self.args.batch_size, delta_length, dtype=torch.int, device=utils.get_current_device() ) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) else: logger.warning(f"seq_length bigger than cache_lens, killed") exit(0) logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") cache_position = torch.arange(former_seq_length, self.seq_length, device=device) self.generated_ids[:, cache_position] = input_ids.to(utils.get_current_device()).to(torch.int) if not (type(self) is TransformersInterface): input_ids = input_ids.to("cpu") def chunk_prefill(input_ids, cache_position): inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) torch.cuda.set_device(device) if flashinfer_enabled: MLAWrapperSingleton.need_plan_all() if self.use_static_cache: logits = self.model( inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache, return_dict=False, use_cache=True, is_prefill=True, )[0] else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False, is_prefill=True)[0] return logits if not use_npu: chunk_start = 0 while chunk_start < input_ids_length: chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length) if self.cache != None: self.cache.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end]) chunk_start += self.args.chunk_size if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() self.prepare_logits_wrapper(input_ids, device, temperature, top_p) next_token = self.logits_to_token(logits[0, -1, :]) self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 yield self.append_new_tokens(next_token) return def prefill_wrapper(prof=None): chunk_start = 0 while chunk_start < input_ids_length: chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length) if self.cache != None: self.cache.cur_idx = cache_position[chunk_start:chunk_end] logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end]) chunk_start += self.args.chunk_size if prof is not None: prof.step() if prof is not None: prof.stop() if logits is None: raise ValueError('logits cannot be None') return logits global WARM_UP_SKIP_CNT prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0" if prof_prefill == "1": experimental_config = torch_npu.profiler._ExperimentalConfig( aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False ) with torch_npu.profiler.profile( activities=[ torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU ], schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof_lm_head"), record_shapes=True, profile_memory=True, with_stack=False, with_flops=False, with_modules=False, experimental_config=experimental_config) as prof: logits = prefill_wrapper(prof) else: logits = prefill_wrapper() if flashinfer_enabled: MLAWrapperSingleton.reset_buffer() self.prepare_logits_wrapper(input_ids, device, temperature, top_p) next_token = self.logits_to_token(logits[0, -1, :]) self.cache.position[0] = self.seq_length yield self.append_new_tokens(next_token) @property def active_cache_position(self): device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") return torch.tensor([self.seq_length - 1], device=device) def sampling(self, logits, do_sample): if do_sample: cur_len = logits.shape[1] logits = logits / self.temperature torch.manual_seed(0) probs = logits.view(-1, cur_len, self.model.config.vocab_size) probs = torch.softmax(probs, dim=-1).half() next_token = self.next_token_fake if self.draft_model is None or not speculative_decoding: torch_npu._npu_topk_topp_sampling(probs[:, 0, :], self.top_k, self.top_p, next_token, self.next_token_probs) for i in range(1,cur_len): ith_token = torch.empty_like(self.next_token_fake) torch_npu._npu_topk_topp_sampling(probs[:, i, :], self.top_k, self.top_p, ith_token, self.next_token_probs) next_token = torch.cat((next_token, ith_token), dim=-1) else: next_token = torch.argmax(logits, dim=-1) probs = torch.softmax(logits, dim=-1) return next_token, probs def verify_by_tokenid(self, main_token: int, draft_token: int): return main_token, main_token == draft_token def verify_speculative_decoding(self, main_prob: torch.Tensor, draft_prob: torch.Tensor, draft_token: int, p: float): #assert draft_prob[draft_token] == p q = main_prob[draft_token] #p = draft_prob[draft_token] accept_prob = min(1.0, (q / p).item()) if torch.rand(()) <= accept_prob: return draft_token, True else: # Compute the adjusted distribution for resampling new_prob = main_prob - draft_prob new_prob = torch.clamp(new_prob, min=0.0) new_prob /= new_prob.sum() # Sample a new token from the adjusted distribution token = torch.multinomial(new_prob, 1).item() return token, False def logits_to_token(self, logits: torch.Tensor): if self.model.generation_config.do_sample: logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1)) probs = torch.nn.functional.softmax(logits, dim=-1) last = torch.multinomial(probs, num_samples=1) else: logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1)) probs = torch.nn.functional.softmax(logits, dim=-1) _, last = torch.topk(probs, k=1, dim=-1) last = last.item() self.ever_generated_ids.add(last) return last async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): async with self._infer_lock: async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens): yield v # return this inference raw usage yield RawUsage( tokenize_time = self.profiler.get_timer_sec('tokenize'), prefill_time = self.profiler.get_timer_sec('prefill'), decode_time = self.profiler.get_timer_sec('decode'), prefill_count = self.profiler.get_counter('prefill'), decode_count = self.profiler.get_counter('decode'), ) def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: async def run_async(): result = [] async for chunk in self.inference(local_messages, thread_id, temperature, top_p): pass return "" return loop.run_until_complete(run_async()) finally: loop.close() ================================================ FILE: archive/ktransformers/server/backend/interfaces/transformers.py ================================================ from typing import Any, List, Optional, Set import re import json import uuid try: import torch_npu use_npu = torch.npu.is_available() except: use_npu = False from transformers import ( LlamaTokenizer, AutoTokenizer, AutoConfig, LlamaForCausalLM, GenerationConfig, StaticCache, AutoModelForCausalLM, BitsAndBytesConfig, LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, MinPLogitsWarper, TypicalLogitsWarper, EpsilonLogitsWarper, EtaLogitsWarper, ) from ktransformers.server.config.config import Config from ktransformers.server.schemas.base import ObjectID from ktransformers.server.utils.multi_timer import Profiler from torch.nn.attention import SDPBackend import torch import torch.distributed as dist from ktransformers.util import utils import sys, os from ..base import ThreadContext, BackendInterfaceBase from ktransformers.server.config.log import logger from ..args import ConfigArgs, default_args from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton from ktransformers.util import utils # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py class TextStreamer: def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): self.tokenizer = tokenizer self.skip_prompt = skip_prompt self.decode_kwargs = decode_kwargs # variables used in the streaming process self.token_cache = [] self.print_len = 0 self.next_tokens_are_prompt = True def reset(self): self.token_cache = [] self.print_len = 0 def put(self, value) -> Optional[str]: """ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. """ if not isinstance(value, int): raise ValueError("TextStreamer only supports batch size 1, and int type input") if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return None # Add the new token to the cache and decodes the entire thing. self.token_cache.append(value) text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs) # After the symbol for a new line, we flush the cache. if text.endswith("\n"): printable_text = text[self.print_len :] self.reset() # If the last token is a CJK character, we print the characters. elif len(text) > 0 and self._is_chinese_char(ord(text[-1])): printable_text = text[self.print_len :] self.print_len += len(printable_text) # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, # which may change with the subsequent token -- there are probably smarter ways to do this!) else: printable_text = text[self.print_len : text.rfind(" ") + 1] self.print_len += len(printable_text) return printable_text def end(self) -> Optional[str]: """Flushes any remaining cache and prints a newline to stdout.""" # Flush the cache, if it exists if len(self.token_cache) > 0: text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs) printable_text = text[self.print_len :] self.reset() else: printable_text = "" self.next_tokens_are_prompt = True return printable_text def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) # # Note that the CJK Unicode block is NOT all Japanese and Korean characters, # despite its name. The modern Korean Hangul alphabet is a different block, # as is Japanese Hiragana and Katakana. Those alphabets are used to write # space-separated words, so they are not treated specially and handled # like the all of the other languages. if ( (cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) # or (cp >= 0x20000 and cp <= 0x2A6DF) # or (cp >= 0x2A700 and cp <= 0x2B73F) # or (cp >= 0x2B740 and cp <= 0x2B81F) # or (cp >= 0x2B820 and cp <= 0x2CEAF) # or (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F) # ): # return True return False class TransformersThreadContext(ThreadContext): def get_local_messages(self): local_messages = [] for m in self.messages: local_messages.append({"role": m.role.value, "content": m.get_text_content()}) return local_messages class TransformersInterface(BackendInterfaceBase): use_static_cache: bool = True model: Any tokenizer: AutoTokenizer cache: StaticCache generated_ids: torch.Tensor seq_length: int streamer: TextStreamer # thread_related last_request_id: Optional[str] = None ever_generated_ids: Set[int] = set() def __init__(self, args: ConfigArgs = default_args): self.args = args self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True) # logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}") self.cache = StaticCache( config=self.model.config, max_batch_size=args.batch_size, max_cache_len=args.cache_lens, device=args.device, dtype=self.model.dtype, ) # logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}") self.streamer = TextStreamer(self.tokenizer) @property def current_ids(self): return self.generated_ids[:, self.seq_length - 1].unsqueeze(1) @property def active_cache_position(self): return torch.tensor([self.seq_length - 1], device=self.args.device) def tokenize_prompt(self, prompt: str): input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device) return input_ids def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List): for m in messages: if m["role"] == "system": logger.warning(f'change {m["role"]} to user') m["role"] = "user" new_messages = [messages[0]] for m in messages[1:]: if m["role"] == "user" and new_messages[-1]["role"] == "user": logger.warning("merge two adjacent user messages") new_messages[-1]["content"] += '\n' + m["content"] else: new_messages.append(m) # if (self.last_request_id is not None) and self.last_request_id == thread_id: # input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors="pt",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors="pt").to(self.args.device) # else: # input_ids = self.tokenizer.apply_chat_template( # new_messages, return_tensors="pt", add_generation_prompt=True # ).to(self.args.device) # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True) # drop token in chat template # if input_str.endswith('\n'): # input_str = input_str[:-len('\n')] # input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device) input_ids = self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(self.args.device) if (self.last_request_id is not None) and self.last_request_id == thread_id: x = self.generated_ids[:,:self.seq_length] y = input_ids[:,:self.seq_length] # We can only hope that the input_ids are the same unequal_mask = torch.ne(x,y) unequal_positions = torch.nonzero(unequal_mask) num_unequal_elements = unequal_mask.sum().item() logger.warning(f'num_unequal_elements: {num_unequal_elements}') input_ids = input_ids[:,self.seq_length:] logger.debug(f"get input ids of shape {input_ids.shape}") return input_ids def append_new_tokens(self, new_tokens: int) -> Optional[str]: self.generated_ids[0, self.seq_length] = new_tokens self.seq_length += 1 self.cache.position[0] += 1 return self.streamer.put(new_tokens) @staticmethod def tf_logits_warper(generation_config): """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. """ # instantiate warpers list warpers = LogitsProcessorList() # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) if generation_config.num_beams > 1: if isinstance(generation_config._eos_token_tensor, list): min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 elif isinstance(generation_config._eos_token_tensor, torch.Tensor): min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 else: min_tokens_to_keep = 2 else: min_tokens_to_keep = 1 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` if generation_config.temperature is not None and generation_config.temperature != 1.0: warpers.append(TemperatureLogitsWarper(generation_config.temperature)) if generation_config.top_k is not None and generation_config.top_k != 0: warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.top_p is not None and generation_config.top_p < 1.0: warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.min_p is not None: # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.typical_p is not None and generation_config.typical_p < 1.0: warpers.append( TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) ) if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: warpers.append( EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) ) if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: warpers.append( EtaLogitsWarper( epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device ) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) return warpers def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None): if temperature is None or temperature == 0: temperature = self.model.generation_config.temperature if top_p is None: top_p = self.model.generation_config.top_p if top_p == 0: top_p = 0.0001 # keep sampler the same as local_chat generation_config, model_kwargs = self.model._prepare_generation_config( None, max_length=self.args.max_new_tokens, do_sample=True, top_k=self.args.top_k, top_p=top_p, temperature=temperature, repetition_penalty=self.args.repetition_penalty # change this to modify generate config ) self.inputs = inputs self.logits_warper = self.tf_logits_warper(generation_config) def logits_to_token(self, logits: torch.Tensor): logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1)) probs = torch.nn.functional.softmax(logits, dim=-1) sample = True if sample: last = torch.multinomial(probs, num_samples=1) else: _, last = torch.topk(probs, k=1, dim=-1) last = last.item() self.ever_generated_ids.add(last) return last def decode_one_tokens(self): if self.use_static_cache: logits = self.model( self.current_ids, cache_position=self.active_cache_position, past_key_values=self.cache, return_dict=False, use_cache=True, )[0] else: logits = self.model(self.current_ids, return_dict=False)[0] logits = logits[0, -1, :] return self.logits_to_token(logits) @torch.no_grad def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): input_ids_length = input_ids.shape[-1] logger.debug(f"input_ids: {input_ids.shape}") if max_tokens is not None: max_completion_tokens = max_tokens if max_completion_tokens is None: max_new_tokens = self.args.max_new_tokens else: max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens) if is_new: self.ever_generated_ids.clear() same_prefix = 0 flat_input_ids = input_ids.flatten() if getattr(self, 'generated_ids', None) is None: self.generated_ids = torch.zeros( self.args.batch_size, input_ids.shape[-1] + max_new_tokens + 1, dtype=torch.int, device=self.args.device, ) self.seq_length = 1 flat_prev_ids = self.generated_ids.flatten() for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1): if flat_input_ids[i] == flat_prev_ids[i]: same_prefix += 1 else: break logger.debug(f"same prefix len: {same_prefix}") self.cache.remove_suffix(same_prefix) self.seq_length = same_prefix self.generated_ids = self.generated_ids[..., :same_prefix] input_ids = input_ids[..., same_prefix:] input_ids_length = input_ids.shape[-1] self.ever_generated_ids.clear() self.profiler.set_counter("prefill", input_ids_length) logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}") former_seq_length = self.seq_length self.seq_length += input_ids_length expected_length = self.seq_length + max_new_tokens + 1 delta_length = expected_length - self.generated_ids.shape[-1] if delta_length > 0: new_generate_ids = torch.zeros( self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device ) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device) self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) device = input_ids.device if not (type(self) is TransformersInterface): input_ids = input_ids.to("cpu") inputs_embeds = self.model.model.embed_tokens(input_ids).to(device) if self.use_static_cache: logits = self.model( inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=self.cache, return_dict=False, use_cache=True, )[0] else: logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] self.prepare_logits_wrapper(input_ids, device, temperature, top_p) next_token = self.logits_to_token(logits[0, -1, :]) yield self.append_new_tokens(next_token) @torch.no_grad def generate(self): self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1 logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}") if(self.max_new_tokens <= 0): logger.warning("max_new_tokens is less than 0") yield self.streamer.end(), "length" return logger.info(f"max_new_tokens: {self.max_new_tokens}") self.profiler.set_counter("decode", 0) for i in range(1, self.max_new_tokens): with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): if flashinfer_enabled: MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) next_token = self.decode_one_tokens() self.profiler.inc("decode") if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token): yield self.streamer.end(), None yield "", "stop" assert self.args.batch_size == 1 break yield self.append_new_tokens(next_token), None else: # for's else, if output get max new tokens yield self.streamer.end(), None yield "", "length" if self.args.use_cuda_graph: utils._USE_NPU_GRAPH = False from ktransformers.util.npu_graph_runner import get_or_create_runner npu_graph_runner = get_or_create_runner(utils.get_current_device()) npu_graph_runner.destroy() def check_is_new(self, thread_id: str): if not self.use_static_cache: return True if self.last_request_id is None: self.last_request_id = thread_id return True else: if self.last_request_id == thread_id: return False else: self.last_request_id = thread_id return True async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None): self.streamer.reset() self.profiler.create_and_start_timer("tokenize") torch.distributed.barrier() rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() tp_size = utils.get_tensor_parallel_size() if isinstance(local_messages, List): input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) elif isinstance(local_messages, str): #local_messages = local_messages[0]['content'] input_ids = self.tokenize_prompt(local_messages) #input_ids = torch.tensor([[6366]], device=input_ids.device) else: raise ValueError("local_messages should be List or str") if tp_size == world_size and tp_size > 1: torch.distributed.barrier() input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, device=utils.CUR_DEVICE) all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)] dist.all_gather(all_input_sizes, input_size) max_input_size = max([size.item() for size in all_input_sizes]) padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=utils.CUR_DEVICE) padded_input_ids[0, :input_ids.size(1)] = input_ids[0] all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)] dist.all_gather(all_padded_inputs, padded_input_ids) original_size = all_input_sizes[0].item() input_ids = all_padded_inputs[0][:, :original_size] if Config().user_force_think: token_thinks = torch.tensor([self.tokenizer.encode("\n",add_special_tokens=False)],device=input_ids.device) if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]): input_ids = torch.cat( [input_ids, token_thinks], dim=1 ) self.profiler.pause_timer("tokenize") self.profiler.create_and_start_timer("prefill") if Config().user_force_think: think = '\n' print(think, end="",flush=True) yield think, None for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens): # output think token after prefill done if t is not None: print(t, end="",flush=True) yield t, None self.profiler.pause_timer("prefill") self.profiler.create_and_start_timer("decode") for t, finish_reason in self.generate(): if t is not None: if tp_size == world_size: if rank == 0: print(t, end="", flush=True) else: print(t, end="",flush=True) yield t, finish_reason if tp_size == world_size: if rank == 0: print("") self.profiler.pause_timer("decode") self.report_last_time_performance() else: print("") self.profiler.pause_timer("decode") self.report_last_time_performance() ================================================ FILE: archive/ktransformers/server/balance_serve/inference/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/balance_serve/inference/config.py ================================================ ''' Date: 2024-11-07 07:30:16 LastEditors: djw LastEditTime: 2024-11-15 14:23:26 ''' import math from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F import yaml import json from typing import Optional model_runner_dict = dict() class ModelConfig: vocab_size: int = 32000 n_layer: int = 1 n_head: int = 32 dim: int = 4096 intermediate_size: int = 18944 n_local_heads: int = 8 head_dim: int = 128 rope_base: float = 1000000.0 norm_eps: float = 1e-06 rope_scaling: Optional[dict] = None rms_norm_eps: float = 1e-6 hidden_act: str = "silu" model_path: str gguf_path: str optimize_rule_path: str speculative_rule_path: str # quantize config quant_algorithm: Optional[str] = None quant_group_size: Optional[int] = None quant_num_bits: Optional[int] = None json_key_map = { "vocab_size": "vocab_size", "n_layer": "num_hidden_layers", "n_head": "num_attention_heads", "dim": "hidden_size", "intermediate_size": "intermediate_size", "n_local_heads": "num_key_value_heads", "rope_base": "rope_theta", "norm_eps": "norm_eps", "rms_norm_eps": "rms_norm_eps", "hidden_act": "hidden_act", } def __init__(self, config): self.model_path = config["model"]["model_path"] self.gguf_path = config["model"]["gguf_path"] self.optimize_rule_path = config["model"]["optimize_rule_path"] if "speculative_rule_path" in config["model"]: self.speculative_rule_path = config["model"]["speculative_rule_path"] self.speculative_gguf_path = config["model"]["speculative_gguf_path"] self.speculative_model_path = config["model"]["speculative_model_path"] self.quant_algorithm = config["model"]["quant"]["algorithm"] self.quant_group_size = config["model"]["quant"]["group_size"] self.quant_num_bits = config["model"]["quant"]["num_bits"] self.load_config() self.n_layer = config["model"]["n_layers"] def load_config(self): config_file = f"{self.model_path}/config.json" try: with open(config_file, "r") as f: config_data = json.load(f) except FileNotFoundError: raise FileNotFoundError(f"Configuration file not found at {config_file}") for attr, json_key in self.json_key_map.items(): if json_key in config_data: setattr(self, attr, config_data[json_key]) else: setattr(self, attr, getattr(self, attr)) class ParallelConfig: def __init__( self, config, ) -> None: self.pipeline_parallel_size = config["parallel"]["pp"] self.tensor_parallel_size = config["parallel"]["tp"] self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"] self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size class AttnConfig: page_size: int = 256 block_num: int = 32 max_batch_token : int = 256 max_batch_size: int = 32 def __init__(self, config): self.page_size = config["attn"]["page_size"] self.block_num = config["attn"]["block_num"] self.max_batch_token = config["attn"]["max_batch_token"] self.max_batch_size = config["attn"]["max_batch_size"] class SamplerConfig(): # Batched sampling params temperatures: float is_all_greedy: bool def __init__(self, config): self.temperatures = config["sample"]["temperature"] self.is_all_greedy = True def load_yaml_config(file_path): with open(file_path, "r") as f: return yaml.safe_load(f) class LLMConfig: model_config: ModelConfig parallel_config: ParallelConfig attn_config: AttnConfig sample_config: SamplerConfig config_file: str def __init__(self, config_file): self.config_file = config_file config = load_yaml_config(config_file) self.model_config = ModelConfig(config) self.parallel_config = ParallelConfig(config) self.attn_config = AttnConfig(config) self.sample_config = SamplerConfig(config) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/__init__.py ================================================ from .communication_op import * from .parallel_state import * from .utils import * ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/communication_op.py ================================================ """ Date: 2024-12-11 06:02:42 LastEditors: djw LastEditTime: 2024-12-12 09:52:06 """ from typing import Any, Dict, Optional, Union import torch import torch.distributed from .parallel_state import get_tp_group def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) def tensor_model_parallel_all_gather( input_: torch.Tensor, dim: int = -1 ) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" return get_tp_group().all_gather(input_, dim) def tensor_model_parallel_gather( input_: torch.Tensor, dst: int = 0, dim: int = -1 ) -> Optional[torch.Tensor]: """Gather the input tensor across model parallel group.""" return get_tp_group().gather(input_, dst, dim) def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 ): if not torch.distributed.is_initialized(): return tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py ================================================ """This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. """ import ctypes from dataclasses import dataclass from typing import Any, Dict, List, Optional # this line makes it possible to directly load `libcudart.so` using `ctypes` import torch # noqa # === export types and functions from cudart to Python === # for the original cudart definition, please check # https://docs.nvidia.com/cuda/cuda-runtime-api/index.html cudaError_t = ctypes.c_int cudaMemcpyKind = ctypes.c_int class cudaIpcMemHandle_t(ctypes.Structure): _fields_ = [("internal", ctypes.c_byte * 128)] @dataclass class Function: name: str restype: Any argtypes: List[Any] def find_loaded_library(lib_name) -> Optional[str]: """ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. """ # noqa found = False with open("/proc/self/maps") as f: for line in f: if lib_name in line: found = True break if not found: # the library is not loaded in the current process return None # if lib_name is libcudart, we need to match a line with: # address /path/to/libcudart-hash.so.11.0 start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] assert filename.rpartition(".so")[0].startswith(lib_name), \ f"Unexpected filename: {filename} for library {lib_name}" return path class CudaRTLibrary: exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), # cudaError_t cudaDeviceSynchronize ( void ) Function("cudaDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa Function("cudaMemcpy", cudaError_t, [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind ]), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa Function("cudaIpcOpenMemHandle", cudaError_t, [ ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint ]), ] # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times path_to_library_cache: Dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): if so_file is None: so_file = find_loaded_library("libcudart") assert so_file is not None, \ "libcudart is not loaded in the current process" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib self.lib = CudaRTLibrary.path_to_library_cache[so_file] if so_file not in CudaRTLibrary.path_to_dict_mapping: _funcs = {} for func in CudaRTLibrary.exported_functions: f = getattr(self.lib, func.name) f.restype = func.restype f.argtypes = func.argtypes _funcs[func.name] = f CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] def CUDART_CHECK(self, result: cudaError_t) -> None: if result != 0: error_str = self.cudaGetErrorString(result) raise RuntimeError(f"CUDART error: {error_str}") def cudaGetErrorString(self, error: cudaError_t) -> str: return self.funcs["cudaGetErrorString"](error).decode("utf-8") def cudaSetDevice(self, device: int) -> None: self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) def cudaDeviceSynchronize(self) -> None: self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) def cudaDeviceReset(self) -> None: self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) def cudaMalloc(self, size: int) -> ctypes.c_void_p: devPtr = ctypes.c_void_p() self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) return devPtr def cudaFree(self, devPtr: ctypes.c_void_p) -> None: self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int) -> None: cudaMemcpyDefault = 4 kind = cudaMemcpyDefault self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: handle = cudaIpcMemHandle_t() self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( ctypes.byref(handle), devPtr)) return handle def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudaIpcMemLazyEnablePeerAccess = 1 devPtr = ctypes.c_void_p() self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) return devPtr ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py ================================================ import ctypes from contextlib import contextmanager from typing import List, Optional, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup import server.envs as envs from server.inference.distributed.cuda_wrapper import CudaRTLibrary from server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check from server.inference.distributed.parallel_state import in_the_same_node_as from server.inference.platforms import current_platform from server.utils import cuda_device_count_stateless import vLLMCustomAllreduce try: vLLMCustomAllreduce.meta_size() custom_ar = True except Exception: # For AMD GPUs and CPUs custom_ar = False def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): if i == rank: continue if envs.VLLM_SKIP_P2P_CHECK: print("Skipping P2P check and trusting the driver's P2P report.") return torch.cuda.can_device_access_peer(rank, i) if not gpu_p2p_access_check(rank, i): return False return True def is_weak_contiguous(inp: torch.Tensor): return inp.is_contiguous() or ( inp.storage().nbytes() - inp.storage_offset() * inp.element_size() == inp.numel() * inp.element_size() ) class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] # max_size: max supported allreduce size def __init__( self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024, ) -> None: """ Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, it will be bind to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. """ self._IS_CAPTURING = False self.disabled = True if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-cuda environment return self.group = group assert ( dist.get_backend(group) != dist.Backend.NCCL ), "CustomAllreduce should be attached to a non-NCCL group." if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom allreduce for multi-node case. print( "Custom allreduce is disabled because this process group" " spans across nodes." ) return rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) if world_size == 1: # No need to initialize custom allreduce for single GPU case. return if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: print( "Custom allreduce is disabled due to an unsupported world" " size: %d. Supported world sizes: %s. To silence this " "warning, specify disable_custom_all_reduce=True explicitly.", world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES), ) return if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): device = torch.device(device) # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) else: device_ids = list(range(cuda_device_count_stateless())) physical_device_id = device_ids[device.index] tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") gather_list = [ torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) ] dist.all_gather(gather_list, tensor, group=self.group) physical_device_ids = [t.item() for t in gather_list] # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda() from server.inference.platforms.cuda import CudaPlatform cuda_platform: CudaPlatform = current_platform full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) if world_size > 2 and not full_nvlink: print( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " "specify disable_custom_all_reduce=True explicitly." ) return # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result if not _can_p2p(rank, world_size): print( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " "warning, specify disable_custom_all_reduce=True explicitly." ) return self.disabled = False # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. self.meta_ptrs = self.create_shared_buffer( vLLMCustomAllreduce.meta_size() + max_size, group=group ) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB # is enough for 131072 such tuples. The largest model I've seen only # needs less than 10000 of registered tuples. self.rank_data = torch.empty( 8 * 1024 * 1024, dtype=torch.uint8, device=self.device ) self.max_size = max_size self.rank = rank self.world_size = world_size self.full_nvlink = full_nvlink self._ptr = vLLMCustomAllreduce.init_custom_ar( self.meta_ptrs, self.rank_data, rank, self.full_nvlink ) vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs) @staticmethod def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None ) -> List[int]: """ Creates a shared buffer and returns a list of pointers representing the buffer on all processes in the group. """ lib = CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) handles = [None] * world_size dist.all_gather_object(handles, handle, group=group) pointers: List[int] = [] for i, h in enumerate(handles): if i == rank: pointers.append(pointer.value) # type: ignore else: pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore return pointers @staticmethod def free_shared_buffer( pointers: List[int], group: Optional[ProcessGroup] = None ) -> None: rank = dist.get_rank(group=group) lib = CudaRTLibrary() lib.cudaFree(ctypes.c_void_p(pointers[rank])) @contextmanager def capture(self): """ The main responsibility of this context manager is the `register_graph_buffers` call at the end of the context. It records all the buffer addresses used in the CUDA graph. """ try: self._IS_CAPTURING = True yield finally: self._IS_CAPTURING = False if not self.disabled: self.register_graph_buffers() def register_graph_buffers(self): handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr) print("Registering %d cuda graph addresses", len(offset)) # We cannot directly use `dist.all_gather_object` here # because it is incompatible with `gloo` backend under inference mode. # see https://github.com/pytorch/pytorch/issues/126032 for details. all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] all_data[self.rank] = [handle, offset] ranks = sorted(dist.get_process_group_ranks(group=self.group)) for i, rank in enumerate(ranks): dist.broadcast_object_list( all_data[i], src=rank, group=self.group, device="cpu" ) # Unpack list of tuples to tuple of lists. handles = [d[0] for d in all_data] # type: ignore offsets = [d[1] for d in all_data] # type: ignore vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): if self.disabled: return False inp_size = inp.numel() * inp.element_size() # custom allreduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: return False if not is_weak_contiguous(inp): return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. if self.world_size == 2 or self.full_nvlink: return inp_size < self.max_size return False def all_reduce( self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: torch.Tensor = None, registered: bool = False, is_compute_bound=False, overlap=False ): """Performs an out-of-place all reduce. If registered is True, this assumes inp's pointer is already IPC-registered. Otherwise, inp is first copied into a pre-registered buffer. """ if is_compute_bound: sms = 2 if overlap else 36 else: sms = 20 if overlap else 36 #print("all reduce sms", sms) if out is None: out = torch.empty_like(inp) if registered: vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms) else: vLLMCustomAllreduce.all_reduce( self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms ) return out def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]: """The main allreduce API that provides support for cuda graph.""" # When custom allreduce is disabled, this will be None. if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap) else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. return torch.empty_like(input) else: # Note: outside of cuda graph context, custom allreduce incurs a # cost of cudaMemcpy, which should be small (<=1% of overall # latency) compared to the performance gain of using custom kernels return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap) def close(self): if not self.disabled and self._ptr: vLLMCustomAllreduce.dispose(self._ptr) self._ptr = 0 self.free_shared_buffer(self.meta_ptrs) self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close() ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py ================================================ import ctypes import json import os import pickle import subprocess import sys import tempfile from itertools import product from typing import Dict, List, Optional, Sequence import torch.distributed as dist import torch.multiprocessing as mp import server.envs as envs from server.inference.distributed.cuda_wrapper import CudaRTLibrary from server.utils import cuda_device_count_stateless, update_environment_variables def producer( batch_src: Sequence[int], producer_queue, consumer_queue, result_queue, cuda_visible_devices: Optional[str] = None, ): if cuda_visible_devices is not None: update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for i in batch_src: lib.cudaSetDevice(i) pointer = lib.cudaMalloc(1024) lib.cudaMemset(pointer, 1, 1024) lib.cudaDeviceSynchronize() handle = lib.cudaIpcGetMemHandle(pointer) producer_queue.put(handle) open_success = consumer_queue.get() if open_success: # use two queues to simulate barrier producer_queue.put(0) consumer_queue.get() # check if the memory is modified host_data = (ctypes.c_char * 1024)() lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore for i in range(1024): if ord(host_data[i]) != 2: open_success = False break result_queue.put(open_success) lib.cudaDeviceReset() def consumer( batch_tgt: Sequence[int], producer_queue, consumer_queue, result_queue, cuda_visible_devices: Optional[str] = None, ): if cuda_visible_devices is not None: update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for j in batch_tgt: lib.cudaSetDevice(j) handle = producer_queue.get() open_success = False try: pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore open_success = True except RuntimeError: # cannot error out here, because the producer process # is still waiting for the response. pass consumer_queue.put(open_success) if open_success: # modify the memory lib.cudaMemset(pointer, 2, 1024) lib.cudaDeviceSynchronize() # use two queues to simulate barrier producer_queue.get() consumer_queue.put(0) # check if the memory is modified host_data = (ctypes.c_char * 1024)() lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore for i in range(1024): if ord(host_data[i]) != 2: open_success = False break result_queue.put(open_success) lib.cudaDeviceReset() def can_actually_p2p( batch_src: Sequence[int], batch_tgt: Sequence[int], ) -> Sequence[bool]: """ Usually, checking if P2P access is enabled can be done by `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` returns `True` even if P2P access is not actually possible. See https://github.com/vllm-project/vllm/issues/2728 and https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 Therefore, we have to perform a real P2P access to check if it is actually possible. Note on p2p and cuda IPC: Usually, one process uses one GPU: GPU src --> cuda context src --> tensor src --> process src We need to combine p2p and cuda IPC, so that: GPU src --> cuda context src --> tensor src --> process src |shared| GPU tgt --> cuda context tgt --> tensor tgt --> process tgt That is to say, process src creates a tensor in GPU src, passes IPC handle to process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the tensor in process tgt will be reflected in the tensor in process src, because they are the same memory segment. It is important to note that process tgt accesses the tensor in GPU tgt, not GPU src. That's why we need p2p access. The most time-consuming part is the process creation. To avoid creating processes for every pair of GPUs, we use batched testing. We create two processes for testing all pairs of GPUs in batch. The trick is to reset the device after each test (which is not available in PyTorch). """ # noqa cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES # pass the CUDA_VISIBLE_DEVICES to the child process # to make sure they see the same set of GPUs # make sure the processes are spawned smp = mp.get_context("spawn") producer_queue = smp.Queue() consumer_queue = smp.Queue() result_queue = smp.Queue() p_src = smp.Process( target=producer, args=( batch_src, producer_queue, consumer_queue, result_queue, cuda_visible_devices, ), ) p_tgt = smp.Process( target=consumer, args=( batch_tgt, producer_queue, consumer_queue, result_queue, cuda_visible_devices, ), ) p_src.start() p_tgt.start() p_src.join() p_tgt.join() assert p_src.exitcode == 0 and p_tgt.exitcode == 0 result: List[bool] = [] for src, tgt in zip(batch_src, batch_tgt): a = result_queue.get() b = result_queue.get() if a != b: print( "Two processes do not agree on the P2P access" " status on %d -> %d, treat as disabled.", src, tgt, ) result.append(False) else: result.append(a) return result # why do we need this cache? # we are testing peer-to-peer (p2p) access between GPUs,across processes. # if we test it every time, it will be very slow, because we need to create # N * N * 2 processes, where N is the world size. This is very slow. # to reduce the time, we use a cache file to store the p2p access status. # the cache file is generated by the master process if it does not exist. # then all the processes can read the cache file to check the p2p access status. # Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we # can have different cache files for different CUDA_VISIBLE_DEVICES settings, # e.g. used by different vllm engines. The device id in the cache file is a # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number # of visible devices in the vllm engine. _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None def gpu_p2p_access_check(src: int, tgt: int) -> bool: """Check if GPU src can access GPU tgt.""" # if the cache variable is already calculated, # read from the cache instead of checking it again global _gpu_p2p_access_cache if _gpu_p2p_access_cache is not None: return _gpu_p2p_access_cache[f"{src}->{tgt}"] is_distributed = dist.is_initialized() num_dev = cuda_device_count_stateless() cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) path = os.path.join( envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" ) os.makedirs(os.path.dirname(path), exist_ok=True) from server.inference.distributed.parallel_state import get_world_group if (not is_distributed or get_world_group().local_rank == 0) and ( not os.path.exists(path) ): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache print("generating GPU P2P access cache in %s", path) cache: Dict[str, bool] = {} ids = list(range(num_dev)) # batch of all pairs of GPUs batch_src, batch_tgt = zip(*list(product(ids, ids))) # NOTE: we use `subprocess` rather than `multiprocessing` here # because the caller might not have `if __name__ == "__main__":`, # in that case we cannot use spawn method in multiprocessing. # However, `can_actually_p2p` requires spawn method. # The fix is, we use `subprocess` to call the function, # where we have `if __name__ == "__main__":` in this file. # use a temporary file to store the result # we don't use the output of the subprocess directly, # because the subprocess might produce logging output with tempfile.NamedTemporaryFile() as output_file: input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name)) returned = subprocess.run( [sys.executable, __file__], input=input_bytes, capture_output=True ) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information raise RuntimeError( f"Error happened when batch testing " f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" f"{returned.stderr.decode()}" ) from e with open(output_file.name, "rb") as f: result = pickle.load(f) for _i, _j, r in zip(batch_src, batch_tgt, result): cache[f"{_i}->{_j}"] = r with open(path, "w") as f: json.dump(cache, f, indent=4) if is_distributed: get_world_group().barrier() print("reading GPU P2P access cache from %s", path) with open(path) as f: cache = json.load(f) _gpu_p2p_access_cache = cache return _gpu_p2p_access_cache[f"{src}->{tgt}"] __all__ = ["gpu_p2p_access_check"] if __name__ == "__main__": batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) result = can_actually_p2p(batch_src, batch_tgt) with open(output_file, "wb") as f: f.write(pickle.dumps(result)) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/parallel_state.py ================================================ # Copyright 2023 The vLLM team. # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """vLLM distributed state. It takes over the control of the distributed environment from PyTorch. The typical workflow is: - call `init_distributed_environment` to initialize the distributed environment. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to initialize the model parallel groups. - any code dealing with the distributed stuff - call `destroy_model_parallel` to destroy the model parallel groups. - call `destroy_distributed_environment` to destroy the distributed environment. If you only need to use the distributed environment without model/pipeline parallelism, you can skip the model parallel initialization and destruction steps. """ import contextlib import gc import pickle import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup import server.envs as envs from server.inference.platforms import current_platform from server.utils import direct_register_custom_op, supports_custom_op @dataclass class GraphCaptureContext: stream: torch.cuda.Stream TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) def _split_tensor_dict( tensor_dict: Dict[str, Union[torch.Tensor, Any]] ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced by its metadata. 2. A list of tensors. """ metadata_list: List[Tuple[str, Any]] = [] tensor_list: List[torch.Tensor] = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device # index (e.g. "cuda:0"). We only need the device type. # receiving side will set the device index. device = value.device.type metadata_list.append( (key, TensorMetadata(device, value.dtype, value.size())) ) tensor_list.append(value) else: metadata_list.append((key, value)) return metadata_list, tensor_list _group_name_counter: Dict[str, int] = {} def _get_unique_name(name: str) -> str: """Get a unique name for the group. Example: _get_unique_name("tp") -> "tp:0" _get_unique_name("tp") -> "tp:1" """ if name not in _group_name_counter: _group_name_counter[name] = 0 newname = f"{name}:{_group_name_counter[name]}" _group_name_counter[name] += 1 return newname _groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) if supports_custom_op(): def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") group._all_reduce_in_place(tensor) def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return direct_register_custom_op( op_name="inplace_all_reduce", op_func=inplace_all_reduce, mutates_args=["tensor"], fake_impl=inplace_all_reduce_fake, ) def outplace_all_reduce(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor: return torch.empty_like(tensor) direct_register_custom_op( op_name="outplace_all_reduce", op_func=outplace_all_reduce, mutates_args=[], fake_impl=outplace_all_reduce_fake, ) class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. PyTorch ProcessGroup is bound to one specific communication backend, e.g. NCCL, Gloo, MPI, etc. GroupCoordinator takes charge of all the communication operations among the processes in the group. It can route the communication to a specific implementation (e.g. switch allreduce implementation based on the tensor size and cuda graph mode). """ # available attributes: rank: int # global rank ranks: List[int] # global ranks in the group world_size: int # size of the group # difference between `local_rank` and `rank_in_group`: # if we have a group of size 4 across two nodes: # Process | Node | Rank | Local Rank | Rank in Group # 0 | 0 | 0 | 0 | 0 # 1 | 0 | 1 | 1 | 1 # 2 | 1 | 2 | 0 | 2 # 3 | 1 | 3 | 1 | 3 local_rank: int # local rank used to assign devices rank_in_group: int # rank inside the group cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication use_pynccl: bool # a hint of whether to use PyNccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce # communicators are only created for world size > 1 pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( self, group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], use_pynccl: bool, use_custom_allreduce: bool, use_tpu_communicator: bool, use_hpu_communicator: bool, use_xpu_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ): group_name = group_name or "anonymous" self.unique_name = _get_unique_name(group_name) _register_group(self) self.rank = torch.distributed.get_rank() self.local_rank = local_rank self.device_group = None self.cpu_group = None for ranks in group_ranks: device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) self.device_group = device_group self.cpu_group = cpu_group assert self.cpu_group is not None assert self.device_group is not None assert current_platform.is_cuda_alike() if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce self.use_tpu_communicator = use_tpu_communicator self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator # lazy import to avoid documentation build error from server.inference.distributed.custom_all_reduce import CustomAllreduce from server.inference.distributed.pynccl import PyNcclCommunicator self.pynccl_comm: Optional[PyNcclCommunicator] = None # if use_pynccl and self.world_size > 1: # self.pynccl_comm = PyNcclCommunicator( # group=self.cpu_group, # device=self.device, # ) self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, ) #### we assume we won't use tpu or hpu or xpu or messagequeue broadcast # from vllm.distributed.device_communicators.tpu_communicator import ( # TpuCommunicator) # self.tpu_communicator: Optional[TpuCommunicator] = None # if use_tpu_communicator and self.world_size > 1: # self.tpu_communicator = TpuCommunicator(group=self.cpu_group) self.tpu_communicator = None # from vllm.distributed.device_communicators.hpu_communicator import ( # HpuCommunicator) # self.hpu_communicator: Optional[HpuCommunicator] # if use_hpu_communicator and self.world_size > 1: # self.hpu_communicator = HpuCommunicator(group=self.device_group) self.hpu_communicator = None # from vllm.distributed.device_communicators.xpu_communicator import ( # XpuCommunicator) # self.xpu_communicator: Optional[XpuCommunicator] # if use_xpu_communicator and self.world_size > 1: # self.xpu_communicator = XpuCommunicator(group=self.device_group) self.xpu_communicator = None # from vllm.distributed.device_communicators.shm_broadcast import ( # MessageQueue) # self.mq_broadcaster: Optional[MessageQueue] = None # if use_message_queue_broadcaster and self.world_size > 1: # self.mq_broadcaster = MessageQueue.create_from_process_group( # self.cpu_group, 1 << 22, 6) self.mq_broadcaster = None @property def first_rank(self): """Return the global rank of the first process in the group""" return self.ranks[0] @property def last_rank(self): """Return the global rank of the last process in the group""" return self.ranks[-1] @property def is_first_rank(self): """Return whether the caller is the first process in the group""" return self.rank == self.first_rank @property def is_last_rank(self): """Return whether the caller is the last process in the group""" return self.rank == self.last_rank @property def next_rank(self): """Return the global rank of the process that follows the caller""" rank_in_group = self.rank_in_group world_size = self.world_size return self.ranks[(rank_in_group + 1) % world_size] @property def prev_rank(self): """Return the global rank of the process that precedes the caller""" rank_in_group = self.rank_in_group world_size = self.world_size return self.ranks[(rank_in_group - 1) % world_size] @contextmanager def graph_capture( self, graph_capture_context: Optional[GraphCaptureContext] = None ): if graph_capture_context is None: stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream ca_comm = self.ca_comm maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() # ensure all initialization operations complete before attempting to # capture the graph on another stream curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) with torch.cuda.stream(stream), maybe_ca_context: # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | # -------------------------------------------- # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | # torch.distributed | enabled | disabled| # # Note that custom allreduce will have a runtime check, if the # tensor size is too large, it will fallback to the next # available option. # In summary: When using CUDA graph, we use # either custom all-reduce kernel or pynccl. When not using # CUDA graph, we use either custom all-reduce kernel or # PyTorch NCCL. We always prioritize using custom all-reduce # kernel but fall back to PyTorch or pynccl if it is # disabled or not supported. pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: maybe_pynccl_context = nullcontext() else: maybe_pynccl_context = pynccl_comm.change_state( enable=True, stream=torch.cuda.current_stream() ) with maybe_pynccl_context: yield graph_capture_context def all_reduce(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: """ User-facing all-reduce function before we actually call the all-reduce operation. We need this because Dynamo does not support passing an arbitrary object (`self` in this case) to a custom op. We need to pass the group name as a string, and then look up the group coordinator from the group name, dispatch the all-reduce operation to the group coordinator. In addition, PyTorch custom ops do not support mutation or returning a new tensor in the same op. So we need to figure out if the op is in-place or out-of-place ahead of time. """ # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return input_ if input_.is_cpu: import intel_extension_for_pytorch as ipex ipex.distributed.all_reduce(input_, group=self.device_group) return input_ if not supports_custom_op(): self._all_reduce_in_place(input_) return input_ if self.tpu_communicator is not None and not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. return self.tpu_communicator.all_reduce(input_) if self.hpu_communicator is not None and not self.hpu_communicator.disabled: return self.hpu_communicator.all_reduce(input_) if self.xpu_communicator is not None and not self.xpu_communicator.disabled: return self.xpu_communicator.all_reduce(input_) if ( self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_) ): return torch.ops.vllm.outplace_all_reduce( input_, group_name=self.unique_name, bsz_tensor=bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap ) else: #assert self.ca_comm is not None #assert not self.ca_comm.disabled #assert self.ca_comm.should_custom_ar(input_) torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name) return input_ def _all_reduce_out_place(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor: ca_comm = self.ca_comm assert ca_comm is not None assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap) assert out is not None return out def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=self.device_group) def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert ( -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" # For TPUs, use TPU communicator. tpu_comm = self.tpu_communicator if tpu_comm is not None and not tpu_comm.disabled: return tpu_comm.all_gather(input_, dim) # For HPUs, use HPU communicator. hpu_comm = self.hpu_communicator if hpu_comm is not None and not hpu_comm.disabled: return hpu_comm.all_gather(input_, dim) if dim < 0: # Convert negative dim to positive. dim += input_.dim() input_size = input_.size() # NOTE: we have to use concat-style all-gather here, # stack-style all-gather has compatibility issues with # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. output_tensor = torch.empty( output_size, dtype=input_.dtype, device=input_.device ) # All-gather. torch.distributed.all_gather_into_tensor( output_tensor, input_, group=self.device_group ) # Reshape output_tensor = output_tensor.reshape((world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape( input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] ) return output_tensor def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: `dst` is the local rank of the destination rank. """ world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert ( -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() if self.xpu_communicator is not None and not self.xpu_communicator.disabled: return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim) # Allocate output tensor. if self.rank_in_group == dst: gather_list = [torch.empty_like(input_) for _ in range(world_size)] else: gather_list = None # Gather. torch.distributed.gather( input_, gather_list, dst=self.ranks[dst], group=self.device_group ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) else: output_tensor = None return output_tensor def broadcast(self, input_: torch.Tensor, src: int = 0): """Broadcast the input tensor. NOTE: `src` is the local rank of the source rank. """ assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return input_ # Broadcast. torch.distributed.broadcast( input_, src=self.ranks[src], group=self.device_group ) return input_ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): """Broadcast the input object. NOTE: `src` is the local rank of the source rank. """ assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return obj if self.mq_broadcaster is not None: assert src == 0, "Message queue broadcaster only supports src=0" return self.mq_broadcaster.broadcast_object(obj) if self.rank_in_group == src: torch.distributed.broadcast_object_list( [obj], src=self.ranks[src], group=self.cpu_group ) return obj else: recv = [None] torch.distributed.broadcast_object_list( recv, src=self.ranks[src], group=self.cpu_group ) return recv[0] def broadcast_object_list( self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None ): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return obj_list # Broadcast. torch.distributed.broadcast_object_list( obj_list, src=self.ranks[src], group=self.device_group ) return obj_list def send_object(self, obj: Any, dst: int) -> None: """Send the input object list to the destination rank.""" """NOTE: `dst` is the local rank of the destination rank.""" assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " "as the current rank." ) # Serialize object to tensor and get the size as well object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) size_tensor = torch.tensor( [object_tensor.numel()], dtype=torch.long, device="cpu" ) # Send object size torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) # Send object torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) return None def recv_object(self, src: int) -> Any: """Receive the input object list from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" assert src < self.world_size, f"Invalid src rank ({src})" assert ( src != self.rank_in_group ), "Invalid source rank. Source rank is the same as the current rank." size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size rank_size = torch.distributed.recv( size_tensor, src=self.ranks[src], group=self.cpu_group ) # Tensor to receive serialized objects into. object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, device="cpu", ) rank_object = torch.distributed.recv( object_tensor, src=self.ranks[src], group=self.cpu_group ) assert ( rank_object == rank_size ), "Received object sender rank does not match the size sender rank." obj = pickle.loads(object_tensor.numpy().tobytes()) return obj def broadcast_tensor_dict( self, tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, metadata_group: Optional[ProcessGroup] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict group = self.device_group metadata_group = self.cpu_group assert src < self.world_size, f"Invalid src rank ({src})" rank_in_group = self.rank_in_group if rank_in_group == src: metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict ), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `broadcast_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. self.broadcast_object(metadata_list, src=src) async_handles = [] for tensor in tensor_list: if tensor.numel() == 0: # Skip broadcasting empty tensors. continue if tensor.is_cpu: # use metadata_group for CPU tensors handle = torch.distributed.broadcast( tensor, src=self.ranks[src], group=metadata_group, async_op=True ) else: # use group for GPU tensors handle = torch.distributed.broadcast( tensor, src=self.ranks[src], group=group, async_op=True ) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() else: metadata_list = self.broadcast_object(None, src=src) tensor_dict = {} async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty( value.size, dtype=value.dtype, device=value.device ) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue if tensor.is_cpu: # use metadata_group for CPU tensors handle = torch.distributed.broadcast( tensor, src=self.ranks[src], group=metadata_group, async_op=True, ) else: # use group for GPU tensors handle = torch.distributed.broadcast( tensor, src=self.ranks[src], group=group, async_op=True ) async_handles.append(handle) tensor_dict[key] = tensor else: tensor_dict[key] = value for async_handle in async_handles: async_handle.wait() return tensor_dict def send_tensor_dict( self, tensor_dict: Dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size all_gather_rank = ( 0 if all_gather_group is None else all_gather_group.rank_in_group ) group = self.device_group metadata_group = self.cpu_group if dst is None: dst = (self.rank_in_group + 1) % self.world_size assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict ), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. self.send_object(metadata_list, dst=dst) for tensor in tensor_list: if tensor.numel() == 0: # Skip sending empty tensors. continue # send-allgather: send only a slice, then do allgather. if all_gather_group is not None and tensor.numel() % all_gather_size == 0: tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.send( tensor, dst=self.ranks[dst], group=metadata_group ) else: # use group for GPU tensors torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None def recv_tensor_dict( self, src: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return None all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size all_gather_rank = ( 0 if all_gather_group is None else all_gather_group.rank_in_group ) group = self.device_group metadata_group = self.cpu_group if src is None: src = (self.rank_in_group - 1) % self.world_size assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) tensor_dict: Dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue # send-allgather: send only a slice, then do allgather. use_all_gather = ( all_gather_group is not None and tensor.numel() % all_gather_size == 0 ) if use_all_gather: orig_shape = tensor.shape tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.recv( tensor, src=self.ranks[src], group=metadata_group ) else: # use group for GPU tensors torch.distributed.recv(tensor, src=self.ranks[src], group=group) if use_all_gather: # do the allgather tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore tensor = tensor.reshape(orig_shape) tensor_dict[key] = tensor else: tensor_dict[key] = value return tensor_dict def barrier(self): """Barrier synchronization among the group. NOTE: don't use `device_group` here! `barrier` in NCCL is terrible because it is internally a broadcast operation with secretly created GPU tensors. It is easy to mess up the current device. Use the CPU group instead. """ torch.distributed.barrier(group=self.cpu_group) def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: dst = (self.rank_in_group + 1) % self.world_size pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.send(tensor, dst) else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) def recv( self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if src is None: src = (self.rank_in_group - 1) % self.world_size tensor = torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: pynccl_comm.recv(tensor, src) else: torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor def destroy(self): if self.device_group is not None: torch.distributed.destroy_process_group(self.device_group) self.device_group = None if self.cpu_group is not None: torch.distributed.destroy_process_group(self.cpu_group) self.cpu_group = None if self.pynccl_comm is not None: self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None if self.mq_broadcaster is not None: self.mq_broadcaster = None _WORLD: Optional[GroupCoordinator] = None def get_world_group() -> GroupCoordinator: assert _WORLD is not None, "world group is not initialized" return _WORLD def init_world_group( ranks: List[int], local_rank: int, backend: str ) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, use_hpu_communicator=False, use_xpu_communicator=False, group_name="world", ) def init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, backend: str, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=True, use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, use_hpu_communicator=True, use_xpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) _TP: Optional[GroupCoordinator] = None def get_tp_group() -> GroupCoordinator: assert _TP is not None, "tensor model parallel group is not initialized" return _TP # kept for backward compatibility get_tensor_model_parallel_group = get_tp_group _PP: Optional[GroupCoordinator] = None def get_pp_group() -> GroupCoordinator: assert _PP is not None, "pipeline model parallel group is not initialized" return _PP # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group @contextmanager def graph_capture(): """ `graph_capture` is a context manager which should surround the code that is capturing the CUDA graph. Its main purpose is to ensure that the some operations will be run after the graph is captured, before the graph is replayed. It returns a `GraphCaptureContext` object which contains the necessary data for the graph capture. Currently, it only contains the stream that the graph capture is running on. This stream is set to the current CUDA stream when the context manager is entered and reset to the default stream when the context manager is exited. This is to ensure that the graph capture is running on a separate stream from the default stream, in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. """ with get_tp_group().graph_capture() as context, get_pp_group().graph_capture( context ): yield context _ENABLE_CUSTOM_ALL_REDUCE = True def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE _ENABLE_CUSTOM_ALL_REDUCE = enable def init_distributed_environment( world_size: int = -1, rank: int = -1, distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", ): print( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend, ) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " "distributed environment" ) # this backend is used for WORLD torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, world_size=world_size, rank=rank, ) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 if local_rank == -1: # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank if distributed_init_method == "env://": local_rank = envs.LOCAL_RANK else: local_rank = rank global _WORLD if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) else: assert ( _WORLD.world_size == torch.distributed.get_world_size() ), "world group already initialized with a different world size" def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, backend: Optional[str] = None, ) -> None: """ Initialize model parallel groups. Arguments: tensor_model_parallel_size: number of GPUs used for tensor model parallelism. pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize the model pipeline. The present function will create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: 4 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7] 2 pipeline model-parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7] Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. """ # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend(get_world_group().device_group) if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" ) # Build the tensor model-parallel groups. num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size global _TP assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups): ranks = list( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) ) group_ranks.append(ranks) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True, group_name="tp", ) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size global _PP assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False, group_name="pp", ) def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ backend = backend or torch.distributed.get_backend(get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, backend ) return assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( "tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. " f"{tensor_model_parallel_size=}" ) pp_world_size = get_pp_group().world_size assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size: " f"{pp_world_size=} vs. " f"{pipeline_model_parallel_size=}" ) def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return _TP is not None and _PP is not None _TP_STATE_PATCHED = False @contextmanager def patch_tensor_parallel_group(tp_group: GroupCoordinator): """Patch the tp group temporarily until this function ends. This method is for draft workers of speculative decoding to run draft model with different tp degree from that of target model workers. Args: tp_group (GroupCoordinator): the tp group coordinator """ global _TP_STATE_PATCHED assert not _TP_STATE_PATCHED, "Should not call when it's already patched" _TP_STATE_PATCHED = True old_tp_group = get_tp_group() global _TP _TP = tp_group try: yield finally: # restore the original state _TP_STATE_PATCHED = False _TP = old_tp_group def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" return get_tp_group().world_size def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" return get_tp_group().rank_in_group def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP if _TP: _TP.destroy() _TP = None global _PP if _PP: _PP.destroy() _PP = None def destroy_distributed_environment(): global _WORLD if _WORLD: _WORLD.destroy() _WORLD = None if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() def cleanup_dist_env_and_memory(shutdown_ray: bool = False): destroy_model_parallel() destroy_distributed_environment() with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() if shutdown_ray: import ray # Lazy import Ray ray.shutdown() gc.collect() if not current_platform.is_cpu(): torch.cuda.empty_cache() def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: """ This is a collective operation that returns if each rank is in the same node as the source rank. It tests if processes are attached to the same memory system (shared access to shared memory). """ assert ( torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL ), "in_the_same_node_as should be tested with a non-NCCL group." # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) # local tensor in each process to store the result is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) # global ranks of the processes in the group ranks = torch.distributed.get_process_group_ranks(pg) magic_message = b"magic_message" shm = None try: with contextlib.suppress(OSError): if rank == source_rank: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) shm.buf[: len(magic_message)] = magic_message torch.distributed.broadcast_object_list( [shm.name], src=ranks[source_rank], group=pg ) is_in_the_same_node[rank] = 1 else: # try to open the shared memory segment recv = [None] torch.distributed.broadcast_object_list( recv, src=ranks[source_rank], group=pg ) name = recv[0] # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. with patch( "multiprocessing.resource_tracker.register", lambda *args, **kwargs: None, ): shm = shared_memory.SharedMemory(name=name) if shm.buf[: len(magic_message)] == magic_message: is_in_the_same_node[rank] = 1 except Exception as e: print("Error ignored in is_in_the_same_node: %s", e) finally: if shm: shm.close() torch.distributed.barrier(group=pg) # clean up the shared memory segment with contextlib.suppress(OSError): if rank == source_rank and shm: shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) return [x == 1 for x in is_in_the_same_node.tolist()] ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/pynccl.py ================================================ from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp from server.inference.distributed.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, ncclRedOpTypeEnum, ncclUniqueId, ) from server.inference.distributed.utils import StatelessProcessGroup class PyNcclCommunicator: def __init__( self, group: Union[ProcessGroup, StatelessProcessGroup], device: Union[int, str, torch.device], library_path: Optional[str] = None, ): """ Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the PyNcclCommunicator to. If None, it will be bind to f"cuda:{local_rank}". library_path: the path to the NCCL library. If None, it will use the default library path. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ if not isinstance(group, StatelessProcessGroup): assert dist.is_initialized() assert ( dist.get_backend(group) != dist.Backend.NCCL ), "PyNcclCommunicator should be attached to a non-NCCL group." # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) else: self.rank = group.rank self.world_size = group.world_size self.group = group # if world_size == 1, no need to create communicator if self.world_size == 1: self.available = False self.disabled = True self.stream = None return try: self.nccl = NCCLLibrary(library_path) except Exception: # disable because of missing NCCL library # e.g. in a non-GPU environment self.available = False self.disabled = True self.stream = None return self.available = True self.disabled = False print("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) if self.rank == 0: # get the unique id from NCCL self.unique_id = self.nccl.ncclGetUniqueId() else: # construct an empty unique id self.unique_id = ncclUniqueId() if not isinstance(group, StatelessProcessGroup): tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) # arg `src` in `broadcast` is the global rank dist.broadcast(tensor, src=ranks[0], group=group) byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte else: self.unique_id = group.broadcast_obj(self.unique_id, src=0) if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): device = torch.device(device) # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device # nccl communicator and stream will use this device # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank ) self.stream = torch.cuda.Stream() # A small all_reduce for warmup. data = torch.zeros(1, device=device) self.all_reduce(data) self.stream.synchronize() del data # by default it is disabled, e.g. in profiling models and prefill phase. # to use it, use under `with obj.change_state(enable=True)`, usually # when we are using CUDA graph. self.disabled = True def all_reduce( self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None ): if self.disabled: return # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) if stream is None: stream = self.stream self.nccl.ncclAllReduce( buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream), ) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) if stream is None: stream = self.stream self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), dst, self.comm, cudaStream_t(stream.cuda_stream), ) def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}" ) if stream is None: stream = self.stream self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream), ) @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None ): """ A context manager to change the state of the communicator. """ if enable is None: # guess a default value when not specified enable = self.available if stream is None: stream = self.stream old_disable = self.disabled old_stream = self.stream self.stream = stream self.disabled = not enable yield self.disabled = old_disable self.stream = old_stream ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py ================================================ # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. # Before writing this script, we tried the following approach: # 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself # often gets stuck when initializing the NCCL communicator. # 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` # contains many other potential cuda APIs, that are not allowed during # capturing the CUDA graph. For further details, please check # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . # # Another rejected idea is to write a C/C++ binding for NCCL. It is usually # doable, but we often encounter issues related with nccl versions, and need # to switch between different versions of NCCL. See # https://github.com/NVIDIA/nccl/issues/1234 for more details. # A C/C++ binding is not flexible enough to handle this. It requires # recompilation of the code every time we want to switch between different # versions. This current implementation, with a **pure** Python wrapper, is # more flexible. We can easily switch between different versions of NCCL by # changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` # variable in the code. import ctypes import platform from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch from torch.distributed import ReduceOp from server.utils import find_nccl_library # === export types and functions from nccl to Python === # for the original nccl definition, please check # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in ncclResult_t = ctypes.c_int ncclComm_t = ctypes.c_void_p class ncclUniqueId(ctypes.Structure): _fields_ = [("internal", ctypes.c_byte * 128)] cudaStream_t = ctypes.c_void_p buffer_type = ctypes.c_void_p ncclDataType_t = ctypes.c_int class ncclDataTypeEnum: ncclInt8 = 0 ncclChar = 0 ncclUint8 = 1 ncclInt32 = 2 ncclInt = 2 ncclUint32 = 3 ncclInt64 = 4 ncclUint64 = 5 ncclFloat16 = 6 ncclHalf = 6 ncclFloat32 = 7 ncclFloat = 7 ncclFloat64 = 8 ncclDouble = 8 ncclBfloat16 = 9 ncclNumTypes = 10 @classmethod def from_torch(cls, dtype: torch.dtype) -> int: if dtype == torch.int8: return cls.ncclInt8 if dtype == torch.uint8: return cls.ncclUint8 if dtype == torch.int32: return cls.ncclInt32 if dtype == torch.int64: return cls.ncclInt64 if dtype == torch.float16: return cls.ncclFloat16 if dtype == torch.float32: return cls.ncclFloat32 if dtype == torch.float64: return cls.ncclFloat64 if dtype == torch.bfloat16: return cls.ncclBfloat16 raise ValueError(f"Unsupported dtype: {dtype}") ncclRedOp_t = ctypes.c_int class ncclRedOpTypeEnum: ncclSum = 0 ncclProd = 1 ncclMax = 2 ncclMin = 3 ncclAvg = 4 ncclNumOps = 5 @classmethod def from_torch(cls, op: ReduceOp) -> int: if op == ReduceOp.SUM: return cls.ncclSum if op == ReduceOp.PRODUCT: return cls.ncclProd if op == ReduceOp.MAX: return cls.ncclMax if op == ReduceOp.MIN: return cls.ncclMin if op == ReduceOp.AVG: return cls.ncclAvg raise ValueError(f"Unsupported op: {op}") @dataclass class Function: name: str restype: Any argtypes: List[Any] class NCCLLibrary: exported_functions = [ # const char* ncclGetErrorString(ncclResult_t result) Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), # ncclResult_t ncclGetVersion(int *version); Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument # is a pointer to a pointer Function("ncclCommInitRank", ncclResult_t, [ ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int ]), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer Function("ncclAllReduce", ncclResult_t, [ buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t ]), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); Function("ncclSend", ncclResult_t, [ buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t ]), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); Function("ncclRecv", ncclResult_t, [ buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t ]), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, # it is better not to call it at all. # ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), ] # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times path_to_library_cache: Dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): so_file = so_file or find_nccl_library() try: if so_file not in NCCLLibrary.path_to_dict_mapping: lib = ctypes.CDLL(so_file) NCCLLibrary.path_to_library_cache[so_file] = lib self.lib = NCCLLibrary.path_to_library_cache[so_file] except Exception as e: print( "Failed to load NCCL library from %s ." "It is expected if you are not running on NVIDIA/AMD GPUs." "Otherwise, the nccl library might not exist, be corrupted " "or it does not support the current platform %s." "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" " to point to the correct nccl library path.", so_file, platform.platform()) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: _funcs: Dict[str, Any] = {} for func in NCCLLibrary.exported_functions: f = getattr(self.lib, func.name) f.restype = func.restype f.argtypes = func.argtypes _funcs[func.name] = f NCCLLibrary.path_to_dict_mapping[so_file] = _funcs self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] def ncclGetErrorString(self, result: ncclResult_t) -> str: return self._funcs["ncclGetErrorString"](result).decode("utf-8") def NCCL_CHECK(self, result: ncclResult_t) -> None: if result != 0: error_str = self.ncclGetErrorString(result) raise RuntimeError(f"NCCL error: {error_str}") def ncclGetVersion(self) -> str: version = ctypes.c_int() self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) version_str = str(version.value) # something like 21903 --> "2.19.3" major = version_str[0].lstrip("0") minor = version_str[1:3].lstrip("0") patch = version_str[3:].lstrip("0") return f"{major}.{minor}.{patch}" def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( ctypes.byref(unique_id))) return unique_id def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t: comm = ncclComm_t() self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank)) return comm def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, count: int, datatype: int, op: int, comm: ncclComm_t, stream: cudaStream_t) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)) def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "ncclComm_t", "cudaStream_t", "buffer_type" ] ================================================ FILE: archive/ktransformers/server/balance_serve/inference/distributed/utils.py ================================================ # Copyright 2023 The vLLM team. # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import dataclasses import pickle import time from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple import torch from torch.distributed import TCPStore import server.envs as envs def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format( numerator, denominator ) def divide(numerator, denominator): """Ensure that numerator is divisible by the denominator and return the division value.""" ensure_divisibility(numerator, denominator) return numerator // denominator def split_tensor_along_last_dim( tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, ) -> Sequence[torch.Tensor]: """Split a tensor along its last dimension. Arguments: tensor: input tensor. num_partitions: number of partitions to split the tensor contiguous_split_chunks: If True, make each chunk contiguous in memory. Returns: A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 last_dim_size = divide(tensor.size()[last_dim], num_partitions) # Split. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) # NOTE: torch.split does not create contiguous tensors by default. if contiguous_split_chunks: return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list def get_pp_indices( num_hidden_layers: int, pp_rank: int, pp_size: int ) -> Tuple[int, int]: """Try to evenly distribute layers across partitions. If the number of layers is not divisible by the number of partitions, the last partition will have the remaining layers. """ partition_list_str = envs.VLLM_PP_LAYER_PARTITION if partition_list_str is not None: try: partitions = [int(layer) for layer in partition_list_str.split(",")] except ValueError as err: raise ValueError( "Invalid partition string: {}".format(partition_list_str) ) from err if len(partitions) != pp_size: raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") if sum(partitions) != num_hidden_layers: raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") start_layer = sum(partitions[:pp_rank]) end_layer = start_layer + partitions[pp_rank] else: layers_per_partition = num_hidden_layers // pp_size start_layer = pp_rank * layers_per_partition end_layer = start_layer + layers_per_partition if pp_rank == pp_size - 1: end_layer = num_hidden_layers return (start_layer, end_layer) @dataclasses.dataclass class StatelessProcessGroup: """A dataclass to hold a metadata store, and the rank, world_size of the group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects. """ rank: int world_size: int store: torch._C._distributed_c10d.Store data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict) # src rank -> counter recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) broadcast_send_counter: int = 0 broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict) # A deque to store the data entries, with key and timestamp. entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) def __post_init__(self): assert self.rank < self.world_size self.send_dst_counter = {i: 0 for i in range(self.world_size)} self.recv_src_counter = {i: 0 for i in range(self.world_size)} self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} def send_obj(self, obj: Any, dst: int): """Send an object to a destination rank.""" self.expire_data() key = f"send_to/{dst}/{self.send_dst_counter[dst]}" self.store.set(key, pickle.dumps(obj)) self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" while self.entries: # check the oldest entry key, timestamp = self.entries[0] if time.time() - timestamp > self.data_expiration_seconds: self.store.delete_key(key) self.entries.popleft() else: break def recv_obj(self, src: int) -> Any: """Receive an object from a source rank.""" obj = pickle.loads( self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") ) self.recv_src_counter[src] += 1 return obj def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. Use it for limited times, e.g., for initialization. """ if self.rank == src: self.expire_data() key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}" self.store.set(key, pickle.dumps(obj)) self.broadcast_send_counter += 1 self.entries.append((key, time.time())) return obj else: key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}" recv_obj = pickle.loads(self.store.get(key)) self.broadcast_recv_src_counter[src] += 1 return recv_obj def all_gather_obj(self, obj: Any) -> list[Any]: """All gather an object from all ranks.""" gathered_objs = [] for i in range(self.world_size): if i == self.rank: gathered_objs.append(obj) self.broadcast_obj(obj, src=self.rank) else: recv_obj = self.broadcast_obj(None, src=i) gathered_objs.append(recv_obj) return gathered_objs def barrier(self): """A barrier to synchronize all ranks.""" for i in range(self.world_size): if i == self.rank: self.broadcast_obj(None, src=self.rank) else: self.broadcast_obj(None, src=i) @staticmethod def create( host: str, port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600, ) -> "StatelessProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. If we have process A and process B called `torch.distributed.init_process_group` to form a group, and then we want to form another group with process A, B, C, D, it is not possible in PyTorch, because process A and process B have already formed a group, and process C and process D cannot join that group. This function is a workaround for this issue. `torch.distributed.init_process_group` is a global call, while this function is a stateless call. It will return a `StatelessProcessGroup` object that can be used for exchanging metadata. With this function, process A and process B can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa store = TCPStore( host_name=host, port=port, world_size=world_size, is_master=(rank == 0), ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, data_expiration_seconds=data_expiration_seconds, ) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/forward_batch.py ================================================ ''' Date: 2024-11-12 14:15:16 LastEditors: Xie Weiyu ervinxie@qq.com LastEditTime: 2024-11-26 08:12:49 ''' import torch try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False from ktransformers.server.balance_serve.settings import sched_ext from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo from typing import Union import time from ktransformers.server.config.config import Config class ForwardMiniBatchCombine: q_indptr: torch.Tensor kv_indptr: torch.Tensor kv_indices: torch.Tensor kv_last_page_len: torch.Tensor kv_len: torch.Tensor position_ids: torch.Tensor tokens: torch.Tensor batch_indices: torch.Tensor positions: torch.Tensor chunk_size: int decode_batch: int is_last_prefill_chunk: bool logits_start: list temperatures: torch.Tensor top_ps: torch.Tensor def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256): batch_decode = len(decode_querys_info) batch_prefill = len(prefill_querys_info) self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32) self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32) self.kv_indices = torch.tensor([], device=device, dtype=torch.int32) self.kv_len = torch.tensor([], device=device, dtype=torch.int32) self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32) self.position_ids = torch.tensor([], device=device, dtype=torch.int32) self.tokens = torch.tensor([], device=device, dtype=torch.int32) self.temperatures = torch.tensor([], device=device, dtype=torch.float32) self.top_ps = torch.tensor([], device=device, dtype=torch.float32) self.logits_start = [] self.decode_batch = batch_decode self.num_tokens = batch_decode + sum(prefill_l) self.batch_size = batch_decode + batch_prefill for i, prefill_query_info in enumerate(prefill_querys_info): if prefill_query_info != None: prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0 # print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}") self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0) self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0) self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0) self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0) self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1) self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0) self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0) for decode_query_info in decode_querys_info: decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0) self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0) self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0) if decode_query_info.active_position > 0: self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0) else: self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0) self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1) self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0) self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0) self.q_indptr = self.q_indptr.contiguous() self.kv_indptr = self.kv_indptr.contiguous() self.kv_indices = self.kv_indices.contiguous() self.kv_len = self.kv_len.contiguous() self.kv_last_page_len = self.kv_last_page_len.contiguous() self.position_ids = self.position_ids.contiguous() self.tokens = self.tokens.contiguous() self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256): batch_decode = len(decode_querys_info) batch_prefill = len(prefill_querys_info) self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32) self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32) self.kv_indices = torch.tensor([], device=device, dtype=torch.int32) self.kv_len = torch.tensor([], device=device, dtype=torch.int32) self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32) new_position_ids = torch.tensor([], device=device, dtype=torch.int32) new_tokens = torch.tensor([], device=device, dtype=torch.int32) self.temperatures = torch.tensor([], device=device, dtype=torch.float32) self.top_ps = torch.tensor([], device=device, dtype=torch.float32) self.logits_start = [] self.decode_batch = batch_decode self.num_tokens = batch_decode + sum(prefill_l) self.batch_size = batch_decode + batch_prefill for i, prefill_query_info in enumerate(prefill_querys_info): prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0 # print(f"block_len: {prefill_kv_block_len}, page_size: {page_size}") self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0) self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0) new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0) new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0) self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1) self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0) self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0) for decode_query_info in decode_querys_info: decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0) self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0) self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0) self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0) new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0) if decode_query_info.active_position > 0: new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0) else: new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0) self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1) self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0) self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0) self.q_indptr = self.q_indptr.contiguous() self.kv_indptr = self.kv_indptr.contiguous() self.kv_indices = self.kv_indices.contiguous() self.kv_len = self.kv_len.contiguous() self.kv_last_page_len = self.kv_last_page_len.contiguous() self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) # copy new_position_ids and new_tokens to self.position_ids and self.tokens # print("new_position_ids: ", new_position_ids) # self.print() self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids) self.position_ids[new_position_ids.size(0):].zero_() self.tokens[:new_tokens.size(0)].copy_(new_tokens) def __str__(self): ret = '' ret += f'=====flash infer forward info:\n' ret += f'q_indptr: {self.q_indptr}, kv_indptr: {self.kv_indptr}, kv_indices: {self.kv_indices}\n' ret += f'kv_len: {self.kv_len}, kv_last_page_len: {self.kv_last_page_len}, bsz_tensor: {self.bsz_tensor}\n' ret += f'position_ids: {self.position_ids}, tokens: {self.tokens}\n' return ret class ForwardMiniBatchSplit: # NPU 流程 prefill 和 decode 分开打包 prefill_batch: int p_q_len: torch.Tensor # (bsz) p_kv_len: torch.Tensor # (bsz) p_position_ids: torch.Tensor # (sum(q_len)) p_tokens: torch.Tensor # (sum(q_len)) p_temperatures: torch.Tensor # (bsz) p_top_ps: torch.Tensor # (bsz) p_block_tables: torch.Tensor # (bsz, max_page_num) p_logits_start: list decode_batch: int d_q_len: torch.Tensor d_kv_len: torch.Tensor d_position_ids: torch.Tensor d_tokens: torch.Tensor d_temperatures: torch.Tensor d_top_ps: torch.Tensor d_block_tables: torch.Tensor # (bsz, max_page_num) d_logits_start: list chunk_size: int is_last_prefill_chunk: bool def __init__( self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device=None, page_size: int = 256, max_page_num: int = 64, decode_padding_len: int = 1, ): # 统一 NPU 设备 device = torch.device('npu') if prefill_s is None or prefill_l is None: raise ValueError( "[ForwardMiniBatchSplit.__init__] prefill_s / prefill_l 不能为空,chunk prefill 需要这两个参数" ) # 过滤掉 None new_prefill_querys_info: list[QueryInfo] = [ info for info in prefill_querys_info if info is not None ] batch_prefill = len(new_prefill_querys_info) batch_decode = len(decode_querys_info) self.prefill_batch = batch_prefill self.decode_batch = batch_decode self.batch_size = batch_prefill + batch_decode self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l) self.chunk_size = prefill_l[0] if prefill_l else 0 self.is_last_prefill_chunk = True for i, q in enumerate(new_prefill_querys_info): end_pos = prefill_s[i] + prefill_l[i] if end_pos < q.query_length: self.is_last_prefill_chunk = False break # ====================== Prefill 部分 ====================== self.p_q_len = torch.tensor([], device=device, dtype=torch.int32) self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32) self.p_position_ids = torch.tensor([], device=device, dtype=torch.int32) self.p_block_tables = -1 * torch.ones( [self.prefill_batch, max_page_num], device=device, dtype=torch.int32 ) self.p_tokens = torch.tensor([], device=device, dtype=torch.int32) self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32) self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32) self.p_logits_start: list[int] = [] for i, prefill_query_info in enumerate(new_prefill_querys_info): qid = getattr(prefill_query_info, "id", -1) past_len = int(prefill_query_info.active_position) start = int(prefill_s[i]) # current chunk's start position in query_tokens chunk_len = int(prefill_l[i]) kv_len = past_len + chunk_len prefill_kv_block_len = (kv_len + page_size - 1) // page_size # Q length = current chunk length self.p_q_len = torch.concat( ( self.p_q_len, torch.tensor([chunk_len], device=device, dtype=torch.int32), ), dim=0, ) self.p_kv_len = torch.concat( ( self.p_kv_len, torch.tensor([kv_len], device=device, dtype=torch.int32), ), dim=0, ) self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[ :prefill_kv_block_len ] self.p_position_ids = torch.concat( ( self.p_position_ids, torch.arange( start, start + chunk_len, device=device, dtype=torch.int32, ), ), dim=0, ) self.p_tokens = torch.concat( ( self.p_tokens, prefill_query_info.query_tokens[start : start + chunk_len], ), dim=0, ) self.p_logits_start.append( chunk_len - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[: i + 1]) - 1 ) self.p_temperatures = torch.concat( ( self.p_temperatures, torch.tensor( [prefill_query_info.temperature], device=device, dtype=torch.float32, ), ), dim=0, ) self.p_top_ps = torch.concat( ( self.p_top_ps, torch.tensor( [prefill_query_info.top_p], device=device, dtype=torch.float32, ), ), dim=0, ) # ====================== Decode ====================== self.d_q_len = torch.tensor([], device=device, dtype=torch.int32) self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32) self.d_position_ids = torch.tensor([], device=device, dtype=torch.int32) self.d_block_tables = -1 * torch.ones( [self.decode_batch, max_page_num], device=device, dtype=torch.int32 ) self.d_tokens = torch.tensor([], device=device, dtype=torch.int32) self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32) self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32) self.d_logits_start: list[int] = [] for i, decode_query_info in enumerate(decode_querys_info): qid = getattr(decode_query_info, "id", -1) past_len = int(decode_query_info.active_position) decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size self.d_q_len = torch.concat( ( self.d_q_len, torch.tensor( [decode_padding_len], device=device, dtype=torch.int32 ), ), dim=0, ) self.d_kv_len = torch.concat( ( self.d_kv_len, torch.tensor( [past_len + decode_padding_len], device=device, dtype=torch.int32, ), ), dim=0, ) self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[ :decode_kv_block_len ] self.d_position_ids = torch.concat( ( self.d_position_ids, torch.arange( past_len, past_len + decode_padding_len, device=device, dtype=torch.int32, ), ), dim=0, ) if past_len > 0: self.d_tokens = torch.concat( ( self.d_tokens, decode_query_info.query_tokens[ past_len : past_len + decode_padding_len ], ), dim=0, ) else: self.d_tokens = torch.concat( ( self.d_tokens, torch.tensor( [0] * decode_padding_len, device=device, dtype=torch.int32, ), ), dim=0, ) self.d_logits_start.append( 0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1] + decode_padding_len ) self.d_temperatures = torch.concat( ( self.d_temperatures, torch.tensor( [decode_query_info.temperature], device=device, dtype=torch.float32, ), ), dim=0, ) self.d_top_ps = torch.concat( ( self.d_top_ps, torch.tensor( [decode_query_info.top_p], device=device, dtype=torch.float32, ), ), dim=0, ) self.p_q_len = self.p_q_len.contiguous() self.p_kv_len = self.p_kv_len.contiguous() self.p_block_tables = self.p_block_tables.contiguous() self.p_position_ids = self.p_position_ids.contiguous() self.p_tokens = self.p_tokens.contiguous() if self.decode_batch > 1: self.d_q_len = self.d_q_len.reshape(self.decode_batch, -1).contiguous() self.d_kv_len = self.d_kv_len.reshape(self.decode_batch, -1).contiguous() self.d_kv_len_list = self.d_kv_len.flatten().tolist() self.d_block_tables = self.d_block_tables.contiguous() self.d_position_ids = self.d_position_ids.reshape(self.decode_batch, -1).contiguous() self.d_tokens = self.d_tokens.reshape(self.decode_batch, -1).contiguous() else: self.d_q_len = self.d_q_len.contiguous() self.d_kv_len = self.d_kv_len.contiguous() self.d_kv_len_list = self.d_kv_len.flatten().tolist() self.d_block_tables = self.d_block_tables.contiguous() self.d_position_ids = self.d_position_ids.contiguous() self.d_tokens = self.d_tokens.contiguous() self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) def fill( self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, decode_padding_len: int = 1, device=None, page_size: int = 256, max_page_num: int = 64, ): device = torch.device('npu') if prefill_s is None or prefill_l is None: raise ValueError( "[ForwardMiniBatchSplit.fill] prefill_s / prefill_l 不能为空,chunk prefill 需要这两个参数" ) page_size = 128 new_prefill_querys_info: list[QueryInfo] = [ info for info in prefill_querys_info if info is not None ] batch_prefill = len(new_prefill_querys_info) batch_decode = len(decode_querys_info) self.prefill_batch = batch_prefill self.decode_batch = batch_decode self.batch_size = batch_prefill + batch_decode self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l) self.chunk_size = prefill_l[0] if prefill_l else 0 self.is_last_prefill_chunk = True for i, q in enumerate(new_prefill_querys_info): end_pos = prefill_s[i] + prefill_l[i] if end_pos < q.query_length: self.is_last_prefill_chunk = False break # ---------- Prefill ---------- self.p_q_len = torch.tensor([], device=device, dtype=torch.int32) self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32) new_p_position_ids = torch.tensor([], device=device, dtype=torch.int32) self.p_block_tables = torch.zeros( [self.prefill_batch, max_page_num], device=device, dtype=torch.int32 ) new_p_tokens = torch.tensor([], device=device, dtype=torch.int32) self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32) self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32) self.p_logits_start = [] for i, prefill_query_info in enumerate(new_prefill_querys_info): qid = getattr(prefill_query_info, "id", -1) past_len = int(prefill_query_info.active_position) start = int(prefill_s[i]) chunk_len = int(prefill_l[i]) kv_len = past_len + chunk_len prefill_kv_block_len = (kv_len + page_size - 1) // page_size self.p_q_len = torch.concat( ( self.p_q_len, torch.tensor([chunk_len], device=device, dtype=torch.int32), ), dim=0, ) self.p_kv_len = torch.concat( ( self.p_kv_len, torch.tensor([kv_len], device=device, dtype=torch.int32), ), dim=0, ) self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[ :prefill_kv_block_len ] new_p_position_ids = torch.concat( ( new_p_position_ids, torch.arange( start, start + chunk_len, device=device, dtype=torch.int32, ), ), dim=0, ) new_p_tokens = torch.concat( ( new_p_tokens, prefill_query_info.query_tokens[start : start + chunk_len], ), dim=0, ) self.p_logits_start.append( chunk_len - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[: i + 1]) - 1 ) self.p_temperatures = torch.concat( ( self.p_temperatures, torch.tensor( [prefill_query_info.temperature], device=device, dtype=torch.float32, ), ), dim=0, ) self.p_top_ps = torch.concat( ( self.p_top_ps, torch.tensor( [prefill_query_info.top_p], device=device, dtype=torch.float32, ), ), dim=0, ) if new_p_position_ids.numel() > 0: self.p_position_ids = new_p_position_ids.contiguous() if new_p_tokens.numel() > 0: self.p_tokens = new_p_tokens.contiguous() # ---------- Decode ---------- self.d_q_len = torch.zeros( [1] * self.decode_batch, device=device, dtype=torch.int32 ) self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32) new_d_position_ids = torch.tensor([], device=device, dtype=torch.int32) new_d_block_tables = -1 * torch.ones( [self.decode_batch, max_page_num], device=device, dtype=torch.int32 ) new_d_tokens = torch.tensor([], device=device, dtype=torch.int32) self.d_logits_start = [] self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32) self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32) for i, decode_query_info in enumerate(decode_querys_info): qid = getattr(decode_query_info, "id", -1) past_len = int(decode_query_info.active_position) decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size self.d_kv_len = torch.concat( ( self.d_kv_len, torch.tensor( [past_len + decode_padding_len], device=device, dtype=torch.int32, ), ), dim=0, ) new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[ :decode_kv_block_len ] new_d_position_ids = torch.concat( ( new_d_position_ids, torch.arange( past_len, past_len + decode_padding_len, device=device, dtype=torch.int32, ), ), dim=0, ) if past_len > 0: new_d_tokens = torch.concat( ( new_d_tokens, decode_query_info.query_tokens[ past_len : past_len + decode_padding_len ], ), dim=0, ) else: new_d_tokens = torch.concat( ( new_d_tokens, torch.tensor( [0] * decode_padding_len, device=device, dtype=torch.int32, ), ), dim=0, ) self.d_logits_start.append( 0 if len(self.d_logits_start) == 0 else self.d_logits_start[-1] + decode_padding_len ) self.d_temperatures = torch.concat( ( self.d_temperatures, torch.tensor( [decode_query_info.temperature], device=device, dtype=torch.float32, ), ), dim=0, ) self.d_top_ps = torch.concat( ( self.d_top_ps, torch.tensor( [decode_query_info.top_p], device=device, dtype=torch.float32, ), ), dim=0, ) if len(decode_querys_info) > 1: self.d_position_ids[i].copy_(new_d_position_ids[i]) self.d_tokens[i].copy_(new_d_tokens[i]) self.d_block_tables[i].copy_(new_d_block_tables[i]) else: self.d_position_ids[:new_d_position_ids.size(0)].copy_(new_d_position_ids) self.d_tokens[:new_d_tokens.size(0)].copy_(new_d_tokens) self.d_block_tables[0].copy_(new_d_block_tables[0]) self.p_q_len = self.p_q_len.contiguous() self.p_kv_len = self.p_kv_len.contiguous() self.p_block_tables = self.p_block_tables.contiguous() self.d_q_len = self.d_q_len.contiguous() self.d_kv_len = self.d_kv_len.contiguous() self.d_kv_len_list = self.d_kv_len.flatten().tolist() self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32) def __str__(self): ret = '' ret += '=======Prefill forward info:\n' ret += f'batch: {self.prefill_batch}, qLen: {self.p_q_len}, kvLen: {self.p_kv_len}\n' ret += f'tokens: {self.p_tokens}, posIdx: {self.p_position_ids}, block_tables: {self.p_block_tables}\n' ret += '=======Decode forward info:\n' ret += f'batch: {self.decode_batch}, qLen: {self.d_q_len}, kvLen: {self.d_kv_len}\n' ret += f'tokens: {self.d_tokens}, posIdx: {self.d_position_ids}, block_tables: {self.d_block_tables}\n' ret += f'chunk_size={self.chunk_size}, is_last_prefill_chunk={self.is_last_prefill_chunk}\n' return ret class ForwardBatchInput: forward_minibatchs: list[Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine]] decode_mini_batches: list[Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine]] batch_size: int minibatch: Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine] def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None): if batch is None: return prefill_minibatches = batch.prefill_mini_batches decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist] prefill_querys_info = [] prefill_s = [] prefill_l = [] decode_querys_info = [] self.batch_size = 1 for (qid, s, l) in prefill_minibatches: prefill_querys_info.append(query_manager.query_map[qid]) prefill_s.append(s) prefill_l.append(l) for decode_qid in decode_mini_batches: qinfo = query_manager.query_map[decode_qid] if qinfo.decode_start_time is None: qinfo.decode_start_time = time.time() decode_querys_info.append(qinfo) if use_torch_npu: minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size) else: minibatch = ForwardMiniBatchCombine(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size) self.minibatch = minibatch @classmethod def gen_max_forward_batch( cls, device=None, tokens: torch.Tensor = None, num_mini_batches: int = 1, max_seq_length: int = 1024, # TODO: add to yaml prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, gen_prefill: bool = True, decode_batch_size: int = Config().max_decode_batch_size, decode_query_length: int = 1, decode_active_position: torch.Tensor = None, page_size = 256, cuda_lens = 1 ): instance = cls() instance.batch_size = num_mini_batches page_size = page_size prefill_query_info = [] offset = 0 if gen_prefill and prefill_query_length != 0: for i in range(Config().max_prefill_batch_size): prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset)) offset += max_seq_length // page_size decode_querys_info = [] for i in range(min(decode_batch_size, cuda_lens)): query_info = QueryInfo(i+Config().max_prefill_batch_size, decode_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset) offset += max_seq_length // page_size if tokens is not None: query_info.query_tokens[prefill_active_length:prefill_active_length + decode_query_length].copy_(tokens) if decode_active_position is None: query_info.active_position = prefill_active_length else: query_info.active_position = decode_active_position[i] decode_querys_info.append(query_info) if prefill_query_length * Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens: decode_querys_info.append(query_info) if use_torch_npu: instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size, decode_padding_len=decode_query_length) else: instance.minibatch = ForwardMiniBatchCombine(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size) return instance def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256): if batch is None: return prefill_minibatches = batch.prefill_mini_batches decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist] prefill_querys_info = [] prefill_s = [] prefill_l = [] decode_querys_info = [] self.batch_size = 1 for (id, s, l) in prefill_minibatches: prefill_querys_info.append(query_manager.query_map[id]) prefill_s.append(s) prefill_l.append(l) for decode_batch_idx in decode_mini_batches: if query_manager.query_map[decode_batch_idx].decode_start_time is None: query_manager.query_map[decode_batch_idx].decode_start_time =time.time() decode_querys_info.append(query_manager.query_map[decode_batch_idx]) self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size) class ForwardBatchOutput: logits: list[torch.Tensor] pre_hidden_states: list[torch.Tensor] num_batchs: int batch_sizes: list[int] generated_tokens_num: list[int] lm_start: list[int] temperatures: list[torch.Tensor] top_ps: list[torch.Tensor] def __init__(self): self.num_batchs = 0 self.lm_start = [] self.logits = [] self.batch_sizes = [] self.generated_tokens_num = [] self.top_ps = [] self.temperatures = [] self.pre_hidden_states = [] pass def merge(self, new_output): self.logits.extend(new_output.logits) self.num_batchs += new_output.num_batchs self.batch_sizes.extend(new_output.batch_sizes) self.generated_tokens_num.extend(new_output.generated_tokens_num) self.top_ps.extend(new_output.top_ps) self.temperatures.extend(new_output.temperatures) self.lm_start.extend(new_output.lm_start) self.pre_hidden_states.extend(new_output.pre_hidden_states) def __str__(self): logits_shape = [t.shape for t in self.logits] ret = '' ret += f'=======Combined output info:\n' ret += f'logits: {self.logits}\n' ret += f'logits(size): {logits_shape}, num_batchs: {self.num_batchs}, kvLen: {self.generated_tokens_num}\n' ret += f'top_ps: {self.top_ps}, temperatures: {self.temperatures}, pre_hidden_states num: {len(self.pre_hidden_states)}\n' if len(self.pre_hidden_states) != 0: for idx in range(len(self.pre_hidden_states)): ret += f'idx: {idx}, pre_hidden_states shape: {self.pre_hidden_states[idx].shape}\n' return ret ================================================ FILE: archive/ktransformers/server/balance_serve/inference/model_runner.py ================================================ """ Date: 2024-11-07 07:02:20 LastEditors: djw LastEditTime: 2024-12-10 08:48:32 """ import os.path import threading import torch from torch import nn import queue import signal import queue from typing import AsyncIterable from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager from pydantic import BaseModel, Field import asyncio import multiprocessing import time import torch.multiprocessing as mp import random import torch.distributed as dist import zmq import copy import tempfile from ktransformers.server.balance_serve.inference.forward_batch import ( ForwardBatchInput, ForwardBatchOutput, ForwardMiniBatchCombine, ForwardMiniBatchSplit) from ktransformers.util import utils from ktransformers.server.config.config import Config from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.settings import sched_ext try: import torch_npu use_torch_npu = torch_npu.npu.is_available() from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM from ktransformers.models.custom_cache import KVC2StaticCache, KVC2Qwen3Cache except: use_torch_npu = False def pad_num_tokens(num_tokens): return (num_tokens + 63) // 64 * 64 def deduplicate_and_sort(lst): return sorted(set(lst)) def generate_cuda_graphs(chunk_size: int) -> list: # 如果输入不符合要求,assert掉 assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must <= 1024 or a multiple of 1024" base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size] if chunk_size <= 1024: return deduplicate_and_sort(base_list) multiples = [i for i in range(1024, chunk_size + 1, 1024)] return deduplicate_and_sort(base_list + multiples) class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" if not use_torch_npu: model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM else: model: KNPUDeepseekV3ForCausalLM | KNPUQwen3MoeForCausalLM cache: KVC2StaticCache | KVC2Qwen3Cache input: ForwardBatchInput | list[ForwardBatchInput] output: ForwardBatchOutput def __init__(self, model = None, cache = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8): # 先注释掉 self.model = model # Compile and move model to the specified device if use_torch_npu: self.stream = torch.npu.Stream(device=device) self.stream_scope = torch.npu.stream self.input_decode = [] max_batch_size = 1 if Config().max_batch_size <= 1 else Config().max_batch_size self.npu_graphs = sorted(set([i for i in range(1, max_batch_size + 1)])) self.model.stream = self.stream # npu do not support multi stream like this if use_cuda_graph: torch_npu.npu._subscribe_report(self.stream) self.start_model_event = torch.npu.Event(enable_timing=True) self.end_model_event = torch.npu.Event(enable_timing=True) else: self.stream = torch.cuda.Stream(device=device) self.cuda_graphs = generate_cuda_graphs(Config().chunk_size) self.start_model_event = torch.cuda.Event(enable_timing=True) self.end_model_event = torch.cuda.Event(enable_timing=True) self.device = device self.input = None self.features_buf = None self.output = None self.graph_memory_pool = None self.cache = cache #TODO 删掉了一行 self.cuda_graphs = generate_cuda_graphs(Config().chunk_size) 是为何,这样下面不会影响GPU吗 self.use_cuda_graph = use_cuda_graph self.debug = False self.model_time = 0 self.page_size = page_size self.block_num = block_num if 'cuda' in device: self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))] self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] elif 'npu' in device: self.workspace = [None for _ in range(len(self.npu_graphs))] self.graphs = [torch.npu.NPUGraph() for _ in range(len(self.npu_graphs))] self.page_idx_buf = [torch.zeros((self.npu_graphs[i], 1), dtype=torch.int32, device = self.device) for i in range(len(self.npu_graphs))] self.page_offset_buf = [torch.zeros((self.npu_graphs[i], 1), dtype=torch.int32, device = self.device) for i in range(len(self.npu_graphs))] else: self.graphs, self.page_idx_buf, self.page_offset_buf = None, None, None self.num_mini_batches = num_mini_batches self.max_chunk_size = max_chunk_size self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) def model_attn_plan(self, batch, cuda_graph_idx=0): if isinstance(self.model, KDeepseekV3ForCausalLM): self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads, head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, page_size=self.model.cache.page_size, causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx) else: assert False, "model type not supported" def warmup(self): def capture_graphs(cuda_graph_idx): with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream): self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx) self.graph_memory_pool = self.graphs[cuda_graph_idx].pool() self.input = [] self.features_buf = [] self.outputs_buf = [] self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) for i in range(len(self.cuda_graphs)): prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i])) self.features_buf.append(self.model.batch_embeddings(self.input[i])) batch_size = self.input[i].minibatch.q_indptr.size(0)-1 num_tokens = self.features_buf[i][0].size(0) print("capturing cuda graph", batch_size, num_tokens) if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM): self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens) self.bsz_tensor_buf[0] = batch_size self.num_tokens_tensor_buf[0] = num_tokens self.model_attn_plan(self.input[i], i) page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf) self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens]) self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens]) self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) self.outputs_buf.append(None) torch.cuda.synchronize() for warm_up_iters in range(11): with torch.cuda.stream(self.stream): self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i) torch.cuda.synchronize() self.outputs_buf[i].num_batchs = batch_size capture_graphs(i) with torch.cuda.stream(self.stream): self.graphs[i].replay() self.sync(calc_time=False) print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.") def warmup_npu(self): # npu 当前使用PD分离 # 当前只支持 decode 阶段的图下沉 # 多batch 场景下只支持 1 2 3 4 5 6 7 8 def capture_graphs(npu_graph_idx): utils._USE_NPU_GRAPH = True print("self.features_buf[npu_graph_idx] is ", self.features_buf[npu_graph_idx]) with torch.npu.graph(self.graphs[npu_graph_idx], pool=self.graph_memory_pool, stream=self.stream, auto_dispatch_capture=True): self.outputs_buf[npu_graph_idx] = self.model( self.input_decode[npu_graph_idx], self.features_buf[npu_graph_idx], self.cache, None, None, self.page_idx_buf[npu_graph_idx], self.page_offset_buf[npu_graph_idx], self.position_ids_buf[npu_graph_idx], self.block_tables_buf[npu_graph_idx], cuda_graph_idx=npu_graph_idx, is_prefill=False ) self.graph_memory_pool = self.graphs[npu_graph_idx].pool() utils._USE_NPU_GRAPH = False self.features_buf = [] self.outputs_buf = [] self.position_ids_buf = [] self.block_tables_buf = [] self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) for i in range(len(self.npu_graphs)): prefill_query_length = (self.npu_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.npu_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch self.input_decode.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, decode_batch_size=self.npu_graphs[i], prefill_active_length=1, page_size=self.page_size, cuda_lens = self.npu_graphs[i])) self.features_buf.append(self.model.batch_embeddings(self.input_decode[i], device=self.device, is_prefill=False)) batch_size = self.npu_graphs[i] num_tokens = batch_size self.bsz_tensor_buf[0] = batch_size self.num_tokens_tensor_buf[0] = num_tokens page_idx, page_offset = self.cache.get_page_table(self.input_decode[i].minibatch, self.num_tokens_tensor_buf, is_prefill=False) self.position_ids_buf.append(self.input_decode[i].minibatch.d_position_ids.clone()) self.block_tables_buf.append(self.input_decode[i].minibatch.d_block_tables.clone()) self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens][0]) page_offset = page_offset.view(self.page_offset_buf[i].size()) self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens]) self.page_idx_buf[i][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size -1) self.outputs_buf.append(None) torch.npu.synchronize() for warm_up_iters in range(11): with torch.npu.stream(self.stream): self.outputs_buf[i] = self.model(self.input_decode[i], self.features_buf[i], self.cache, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], self.position_ids_buf[i], self.block_tables_buf[i], is_prefill=False) torch.npu.synchronize() capture_graphs(i) self.replay(i) self.sync(calc_time=False) print(f"npu_graph: {i+1}/{len(self.npu_graphs)}, warmup finished.") def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None): with torch.cuda.stream(self.stream): batch_size = len(batch.prefill_mini_batches) # TODO: calc this num_tokens = 0 for i in range(len(batch.decode_mini_batches)): batch_size += len(batch.decode_mini_batches[i]) num_tokens += len(batch.decode_mini_batches[i]) print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},') for i in range(len(batch.prefill_mini_batches)): num_tokens += batch.prefill_mini_batches[i][2] print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},') # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs)) if not self.use_cuda_graph: cuda_graph_idx = 0 if self.use_cuda_graph: self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size) else: self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)] if self.use_cuda_graph: self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device) self.bsz_tensor_buf.copy_(batch_size) self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device)) if self.use_cuda_graph: self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True) self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx) self.start_model_event.record(self.stream) if self.use_cuda_graph: self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) self.start_model_event.record(self.stream) if use_torch_npu: page_idx, page_offset = self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf) #TODO csx minibatch self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size - 1) else: page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf) self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens]) self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens]) self.replay(cuda_graph_idx) self.output = ForwardBatchOutput() self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps) self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures) self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone()) self.end_model_event.record(self.stream) else: self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, causal=True, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) self.start_model_event.record(self.stream) page_idx, page_offset = self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf) self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start] self.output.top_ps.append(self.input.minibatch.top_ps) self.output.temperatures.append(self.input.minibatch.temperatures) self.end_model_event.record(self.stream) if not self.use_cuda_graph: self.output.num_batchs = self.input.batch_size else: self.output.num_batchs = self.input[cuda_graph_idx].batch_size def run_split(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None): """running without flashinfer and prefill & decode split infer""" def _run_infer_stage(is_prefill=True): if "npu" in self.device: cuda_graph_idx = batch_size_decode if is_prefill == False: if cuda_graph_idx != -1 and self.use_cuda_graph: self.features = self.model.batch_embeddings(self.input_decode[cuda_graph_idx], device=self.device, is_prefill=is_prefill) else: self.features = self.model.batch_embeddings(self.input, device=self.device, is_prefill=is_prefill) self.bsz_tensor_buf.copy_(batch_size_decode) if self.use_cuda_graph: if cuda_graph_idx != -1: self.features_buf[cuda_graph_idx].copy_(self.features) else: self.features_buf.copy_(self.features) else: self.features = self.model.batch_embeddings(self.input, device=self.device, is_prefill=is_prefill) self.bsz_tensor_buf.copy_(batch_size_decode) if cuda_graph_idx != -1 and self.use_cuda_graph and is_prefill == False: num_tokens = batch_size_decode + 1 self.start_model_event.record(self.stream) if self.start_model_event else None page_idx, page_offset = self.cache.get_page_table(self.input_decode[cuda_graph_idx].minibatch, self.bsz_tensor_buf, is_prefill=is_prefill) self.position_ids_buf[cuda_graph_idx].copy_(self.input_tmp.minibatch.d_position_ids) self.block_tables_buf[cuda_graph_idx].copy_(self.input_tmp.minibatch.d_block_tables) self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens]) self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens]) self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size - 1) self.replay(cuda_graph_idx) new_output = ForwardBatchOutput() for i in range(num_tokens): new_output.top_ps.append(self.input_decode[cuda_graph_idx].minibatch.d_top_ps[i]) new_output.temperatures.append(self.input_decode[cuda_graph_idx].minibatch.d_temperatures[i]) new_output.logits.append(self.outputs_buf[cuda_graph_idx].logits[i].clone()) # TODO support MTP self.end_model_event.record(self.stream) if self.start_model_event else None if self.output is None: self.output = copy.deepcopy(new_output) else: self.output.merge(new_output) else: self.start_model_event.record(self.stream) if self.start_model_event else None page_idx, page_offset = self.cache.get_page_table(self.input.minibatch, self.num_tokens_tensor_buf, is_prefill=is_prefill) new_output = self.model(self.input, self.features, self.cache, None, None, page_idx, page_offset, None, None, is_prefill=is_prefill) bsz = len(new_output.logits) if is_prefill: for i in range(bsz): new_output.logits[i] = new_output.logits[i][-1:, :] # batched tensor do not need location new_output.top_ps.append(self.input.minibatch.p_top_ps[i]) new_output.temperatures.append(self.input.minibatch.p_temperatures[i]) else: for i in range(bsz): new_output.top_ps.append(self.input.minibatch.d_top_ps[i]) new_output.temperatures.append(self.input.minibatch.d_temperatures[i]) if self.output is None: self.output = copy.deepcopy(new_output) else: self.output.merge(new_output) self.end_model_event.record(self.stream) if self.end_model_event else None with self.stream_scope(self.stream): batch_size = len(batch.prefill_mini_batches) # TODO: calc this num_d_tokens, num_p_tokens = 0, 0 for i in range(len(batch.decode_mini_batches)): batch_size += len(batch.decode_mini_batches[i]) num_d_tokens += len(batch.decode_mini_batches[i]) if self.debug: print(f'decode_batch_i: {len(batch.decode_mini_batches[i])}, token_num: {len(batch.decode_mini_batches[i])} ,batch_size: {batch_size}') for i in range(len(batch.prefill_mini_batches)): num_p_tokens += batch.prefill_mini_batches[i][2] if self.debug: print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]}, token_num: {batch.prefill_mini_batches[i][2]}') # batch info holder both in graph mode & kernel mode self.input_tmp = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device) batch_size_decode = self.input_tmp.minibatch.decode_batch - 1 idx = self.input_tmp.minibatch.decode_batch - 1 cuda_graph_idx = batch_size_decode self.output = None # clear last step output if self.input_tmp.minibatch.decode_batch > 0: if self.use_cuda_graph and len(self.input_decode) > 0: self.input_decode[idx].fill(batch, query_manager, self.page_size) else: self.input = self.input_tmp assert isinstance(self.input.minibatch, ForwardMiniBatchSplit), 'split batch input type must be ForwardMiniBatchSplit' print(self.input.minibatch) if self.debug else None if self.input_tmp.minibatch.prefill_batch > 0: self.input = self.input_tmp assert isinstance(self.input.minibatch, ForwardMiniBatchSplit), 'split batch input type must be ForwardMiniBatchSplit' print(self.input.minibatch) if self.debug else None # ++++++++++++++++++++++++++++++++++++++++++ Prefill Stage ++++++++++++++++++++++++++++++++++++++++++++++++ if self.input_tmp.minibatch.prefill_batch > 0: _run_infer_stage(is_prefill=True) self.output.num_batchs = self.input.minibatch.batch_size # ++++++++++++++++++++++++++++++++++++++++++ Decode Stage ++++++++++++++++++++++++++++++++++++++++++++++++ if self.input_tmp.minibatch.decode_batch > 0: if self.use_cuda_graph: _run_infer_stage(is_prefill=False) self.output.num_batchs = self.input_decode[idx].minibatch.batch_size else: _run_infer_stage(is_prefill=False) self.output.num_batchs = self.input.minibatch.batch_size print(self.output) if self.debug else None def replay(self, cuda_graph_idx=-1): if use_torch_npu: thread = threading.Thread(target=self.graphs[cuda_graph_idx].update, kwargs={"cpu_update_input": [{"actual_seq_lengths_kv": self.input_decode[cuda_graph_idx].minibatch.d_kv_len_list}]}) thread.start() torch_npu.npu.synchronize() with torch.cuda.stream(self.stream): if cuda_graph_idx != -1: self.graphs[cuda_graph_idx].replay() else: self.graphs.replay() def sync(self, calc_time = True): self.stream.synchronize() if calc_time: self.model_time = self.start_model_event.elapsed_time(self.end_model_event) # In ms def get_or_create_model_runner(model=None, cache=None, device=None, use_cuda_graph=None, page_size=None): from ktransformers.server.balance_serve.inference.config import model_runner_dict runner = model_runner_dict.get(device) if runner is None: print("[WARN] the new ModelRunner and deviceId is ", device) runner = ModelRunner(model, cache, device, use_cuda_graph, page_size) model_runner_dict[device] = runner return runner ================================================ FILE: archive/ktransformers/server/balance_serve/inference/query_manager.py ================================================ ''' Date: 2024-11-14 12:23:45 LastEditors: djw LastEditTime: 2024-11-20 04:06:23 ''' import torch from ktransformers.server.balance_serve.settings import sched_ext import random import time from ktransformers.server.config.config import Config from ktransformers.server.utils.serve_profiling import PROF_TIME_STAT class QueryInfo: id: int active_position: int query_length: int is_prefill: int is_first_token: int block_index: torch.Tensor query_tokens: torch.Tensor stop_criteria: list[torch.Tensor] temperature: float top_p: float max_length: int pos_status: torch.Tensor probs: list[torch.Tensor] acc_position: int def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0): self.id = id self.is_prefill = is_prefill self.is_first_token = False self.active_position = active_position self.max_length = max_length - 1 self.query_tokens = torch.zeros((max_length + 2,), dtype=torch.int, device = device) self.stop_criteria = [] self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device) self.query_length = query_length self.enqueue_time = time.time() self.decode_start_time = None self.speculative_token = {} # {position: (accept, token)} self.pos_status = torch.zeros((max_length + 2,), dtype=torch.int, device = device) self.probs = [None] * (max_length + 2) self.acc_tokens_num = 0 self.rej_tokens_num = 0 self.round = 0 self.acc_length = 0 self.acc_position = 0 self.temperature = temperature self.top_p = top_p def check_stop(self): if self.active_position >= self.max_length - 2: if PROF_TIME_STAT.on: PROF_TIME_STAT.print_all() # PROF_TIME_STAT.reset_all() return True # 遍历每个停止条件 for stop_tensor in self.stop_criteria: stop_len = len(stop_tensor) # 如果停止条件比 query_tokens 长,跳过 if stop_len >= self.active_position: continue #print(f"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}") if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3: self.life_time = time.time() - self.enqueue_time self.decode_duration_time = time.time() - self.decode_start_time self.decode_tps = (self.active_position - self.query_length) / self.decode_duration_time print(f"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position - self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}") if self.acc_tokens_num + self.rej_tokens_num != 0: verify_counts = self.acc_tokens_num + self.rej_tokens_num print(f"mtp accept rate: {self.acc_tokens_num}/{verify_counts} = {self.acc_tokens_num * 100 / verify_counts} %") if PROF_TIME_STAT.on: PROF_TIME_STAT.print_all() # PROF_TIME_STAT.reset_all() return True # 找到匹配的停止条件 return False # 没有找到任何停止条件 def print(self): print(f"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}") print(f"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}") print(f"query_tokens_shape: {self.query_tokens}, is_first_token: {self.is_first_token}" ) print(f"pos_status: {self.pos_status}, acc_position: ", self.acc_position) print(f"probs: {self.probs}") class QueryManager: max_length: int = 65536 page_size: int = 256 device: torch.device query_map : dict[int, QueryInfo] def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')): self.max_length = max_length self.page_size = page_size self.device = device self.query_map = {} def print(self, hint: str = ""): print(hint," query_manager: ", self.query_map) for key in self.query_map: query_info = self.query_map[key] print(">>> query: ", key) print("query_info: ") query_info.print() def add_query(self, batch: sched_ext.BatchQueryTodo): for i in range(len(batch.query_ids)): id = batch.query_ids[i] if id not in self.query_map: print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, " f"batch_query_tokens: {batch.query_tokens[i].shape}, " f"batch.block_indexes: {batch.block_indexes[i]}") assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length" query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p) query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device)) for stop_token_list in batch.stop_criteria[i]: query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device)) block_num = batch.block_indexes[i].size(0) query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device)) self.query_map[id] = query_info prefill_mini_batches = batch.prefill_mini_batches for (prefill_id, s, l) in prefill_mini_batches: if prefill_id == id: self.query_map[prefill_id].active_position = s def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]: query_updates = [] prefill_mini_batches = batch.prefill_mini_batches for (id, s, l) in prefill_mini_batches: if id not in self.query_map: assert False, f"query id {id} not found in query_map" # update query_info query_info = self.query_map[id] query_info.active_position += l if query_info.active_position >= query_info.query_length and query_info.is_prefill: query_info.is_prefill = False query_info.is_first_token = True query_info.prefill_duration_time = time.time() - query_info.enqueue_time query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time # generate schedule query_update query_update = sched_ext.QueryUpdate() query_update.id = id query_update.ok = True query_update.is_prefill = query_info.is_prefill query_update.active_position = query_info.active_position # if(not query_info.is_prefill): query_updates.append(query_update) decode_mini_batches = batch.decode_mini_batches for ids in decode_mini_batches: for id in ids: if id not in self.query_map: assert False, f"query id {id} not found in query_map" query_info = self.query_map[id] query_info.is_first_token = False query_info.active_position += 1 query_update = sched_ext.QueryUpdate() query_update.id = id query_update.ok = True query_update.is_prefill = query_info.is_prefill query_update.decode_done = query_info.check_stop() query_update.active_position = query_info.active_position query_updates.append(query_update) return query_updates ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py ================================================ from .orchestrator import BatchedPenalizerOrchestrator from .penalizers.frequency_penalty import BatchedFrequencyPenalizer from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer from .penalizers.presence_penalty import BatchedPresencePenalizer from .penalizers.repetition_penalty import BatchedRepetitionPenalizer __all__ = [ "BatchedFrequencyPenalizer", "BatchedMinNewTokensPenalizer", "BatchedPresencePenalizer", "BatchedRepetitionPenalizer", "BatchedPenalizerOrchestrator", ] ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py ================================================ import abc import dataclasses import typing import torch @dataclasses.dataclass class _ReqLike: origin_input_ids: typing.Union[torch.Tensor, typing.List[int]] @dataclasses.dataclass class _BatchLike: reqs: typing.List[_ReqLike] def batch_size(self): return len(self.reqs) class BatchedPenalizerOrchestrator: batch: _BatchLike device: str vocab_size: int penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"] def __init__( self, vocab_size: int, batch: _BatchLike, device: str, Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]], ): self.vocab_size = vocab_size self.batch = batch self.device = device self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers} is_required = False for penalizer in self.penalizers.values(): pen_is_required = penalizer.prepare_if_required() is_required |= pen_is_required self.is_required = is_required if self.is_required: self.cumulate_input_tokens( input_ids=[req.origin_input_ids for req in self.reqs()] ) def reqs(self): return self.batch.reqs def batch_size(self): return self.batch.batch_size() def cumulate_input_tokens( self, input_ids: typing.Union[ typing.List[torch.Tensor], typing.List[typing.List[int]] ], ): """ Feed the input tokens to the penalizers. Args: input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens. """ token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids) for penalizer in self.penalizers.values(): penalizer.cumulate_input_tokens(input_ids=token_ids) def cumulate_output_tokens( self, output_ids: typing.Union[ typing.List[torch.Tensor], typing.List[typing.List[int]] ], ): """ Feed the output tokens to the penalizers. Args: output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens. """ if not self.is_required: return token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids) for penalizer in self.penalizers.values(): penalizer.cumulate_output_tokens(output_ids=token_ids) def apply(self, logits: torch.Tensor) -> torch.Tensor: """ Apply the penalizers to the logits. Note that it may apply the penalizers in-place. Args: logits (torch.Tensor): The logits to apply the penalizers to. Returns: torch.Tensor: The logits after applying the penalizers. """ if not self.is_required: return for penalizer in self.penalizers.values(): logits = penalizer.apply(logits) return logits def filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor = None, ): """ Filter the penalizers based on the indices to keep in the batch. Args: indices_to_keep (typing.List[int]): List of indices to keep in the batch. indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor. """ if not self.is_required: return empty_indices = len(indices_to_keep) == 0 is_required = False for penalizer in self.penalizers.values(): tmp_is_required = penalizer.is_required() is_required = is_required or tmp_is_required if not tmp_is_required or empty_indices: penalizer.teardown() else: # create tensor index only when it's needed if indices_tensor_to_keep is None: indices_tensor_to_keep = torch.tensor( indices_to_keep, dtype=torch.int32, device=self.device ) penalizer.filter( indices_to_keep=indices_to_keep, indices_tensor_to_keep=indices_tensor_to_keep, ) self.is_required = is_required def merge(self, their: "BatchedPenalizerOrchestrator"): """ Merge the penalizers of another orchestrator into this one. Note that this function **must** be called _before_ self.batch.reqs is updated (filtered). Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging. This step requires the original batch.reqs, before it gets merged with other batch.reqs. Args: their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. """ if not self.is_required and not their.is_required: return self.is_required |= their.is_required for Penalizer, their_penalizer in their.penalizers.items(): if Penalizer not in self.penalizers: raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers") self.penalizers[Penalizer].merge(their_penalizer) class _TokenIDs: """ A class that wraps token IDs to provide additional utility functions to penalizers. Attributes: orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to. token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs. cached_counts (torch.Tensor): The cached occurrence count tensor. """ orchestrator: BatchedPenalizerOrchestrator token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]] cached_counts: torch.Tensor = None def __init__( self, orchestrator: BatchedPenalizerOrchestrator, token_ids: typing.Union[ typing.List[torch.Tensor], typing.List[typing.List[int]] ], ): self.orchestrator = orchestrator if not isinstance(token_ids[0], torch.Tensor): token_ids = [ torch.tensor( data=ids, dtype=torch.int64, device=self.orchestrator.device ) for ids in token_ids ] self.token_ids = token_ids def occurrence_count(self) -> torch.Tensor: """ Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch. Returns: torch.Tensor: The occurrence count tensor. """ if self.cached_counts is not None: return self.cached_counts token_ids = self.token_ids if isinstance(token_ids, torch.Tensor): token_ids = token_ids.unsqueeze(1) # needs to be long to be used as index in scatter_add if token_ids.dtype != torch.int64: token_ids = token_ids.to(torch.int64) padded_token_ids = torch.nn.utils.rnn.pad_sequence( sequences=token_ids, batch_first=True, padding_value=self.orchestrator.vocab_size, ) self.cached_counts = torch.zeros( size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), dtype=torch.int64, device=self.orchestrator.device, ).scatter_add_( dim=1, index=padded_token_ids, src=torch.ones_like(padded_token_ids), )[ :, : self.orchestrator.vocab_size ] return self.cached_counts class _BatchedPenalizer(abc.ABC): """ An abstract class for a batched penalizer. """ orchestrator: BatchedPenalizerOrchestrator _is_prepared: bool = False def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self.orchestrator = orchestrator def is_prepared(self) -> bool: return self._is_prepared def is_required(self) -> bool: return self._is_required() def prepare(self): if not self.is_prepared(): self._prepare() self._is_prepared = True def prepare_if_required(self): if self.is_required(): self.prepare() return True else: return False def teardown(self): if self.is_prepared(): self._teardown() self._is_prepared = False def cumulate_input_tokens(self, input_ids: _TokenIDs): if not self.is_prepared(): return self._cumulate_input_tokens(input_ids=input_ids) def cumulate_output_tokens(self, output_ids: _TokenIDs): if not self.is_prepared(): return self._cumulate_output_tokens(output_ids=output_ids) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.is_prepared(): return logits return self._apply(logits=logits) def filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): if not self.is_prepared(): return self._filter( indices_to_keep=indices_to_keep, indices_tensor_to_keep=indices_tensor_to_keep, ) def merge(self, their: "_BatchedPenalizer"): if not self.is_prepared() and not their.is_prepared(): return self.prepare() their.prepare() self._merge(their) @abc.abstractmethod def _is_required(self) -> bool: """ Check if the penalizer is required to be prepared. """ pass @abc.abstractmethod def _prepare(self): """ Prepare the penalizer. Usually, this is where the penalizer initializes its tensors. """ pass @abc.abstractmethod def _teardown(self): """ Tear down the penalizer. Usually, this is where the penalizer frees its tensors. """ pass @abc.abstractmethod def _cumulate_input_tokens(self, input_ids: _TokenIDs): """ Cumulate the input tokens. Orchestrator will call this function to feed the input tokens to the penalizer. """ pass @abc.abstractmethod def _cumulate_output_tokens(self, output_ids: _TokenIDs): """ Cumulate the output tokens. Orchestrator will call this function to feed the output tokens to the penalizer. """ pass @abc.abstractmethod def _apply(self, logits: torch.Tensor) -> torch.Tensor: """ Apply the penalizer to the logits. Penalizers can modify the logits in-place if needed. """ pass @abc.abstractmethod def _filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): """ Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch. """ pass @abc.abstractmethod def _merge(self, their: "_BatchedPenalizer"): """ Merge the penalizer with another penalizer. """ pass ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py ================================================ import typing import torch from ..orchestrator import _BatchedPenalizer, _TokenIDs class BatchedFrequencyPenalizer(_BatchedPenalizer): """ Frequency penalizer penalizes tokens based on their frequency in the output. """ frequency_penalties: torch.Tensor = None cumulated_frequency_penalties: torch.Tensor = None def _is_required(self) -> bool: return any( req.sampling_params.frequency_penalty != 0.0 for req in self.orchestrator.reqs() ) def _prepare(self): self.cumulated_frequency_penalties = ( torch.tensor( data=[0.0 for _ in self.orchestrator.reqs()], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .repeat(1, self.orchestrator.vocab_size) ) self.frequency_penalties = ( torch.tensor( data=[ req.sampling_params.frequency_penalty for req in self.orchestrator.reqs() ], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .expand_as(self.cumulated_frequency_penalties) ) def _teardown(self): del self.frequency_penalties del self.cumulated_frequency_penalties self.frequency_penalties = None self.cumulated_frequency_penalties = None def _cumulate_input_tokens(self, input_ids: _TokenIDs): pass def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.cumulated_frequency_penalties += ( self.frequency_penalties * output_ids.occurrence_count() ) def _apply(self, logits: torch.Tensor) -> torch.Tensor: logits -= self.cumulated_frequency_penalties return logits def _filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep] self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[ indices_tensor_to_keep ] def _merge(self, their: "BatchedFrequencyPenalizer"): self.frequency_penalties = torch.cat( [self.frequency_penalties, their.frequency_penalties], dim=0 ) self.cumulated_frequency_penalties = torch.cat( [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties], dim=0, ) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py ================================================ import typing import torch from ..orchestrator import _BatchedPenalizer, _TokenIDs class BatchedMinNewTokensPenalizer(_BatchedPenalizer): """ Min new tokens penalizer penalizes tokens based on the length of the output. """ min_new_tokens: torch.Tensor = None stop_token_penalties: torch.Tensor = None len_output_tokens: torch.Tensor = None def _is_required(self) -> bool: return any( req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs() ) def _prepare(self): self.min_new_tokens = torch.tensor( data=[ req.sampling_params.min_new_tokens for req in self.orchestrator.reqs() ], dtype=torch.int32, device=self.orchestrator.device, ).unsqueeze_(1) padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence( sequences=[ torch.tensor( data=( list( (req.sampling_params.stop_token_ids or set()) | (req.tokenizer.additional_stop_token_ids or set()) | {req.tokenizer.eos_token_id} ) ), dtype=torch.int64, device=self.orchestrator.device, ) for req in self.orchestrator.reqs() ], batch_first=True, padding_value=self.orchestrator.vocab_size, ) self.stop_token_penalties = torch.zeros( size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1), dtype=torch.float32, device=self.orchestrator.device, ).scatter_add_( dim=1, index=padded_stop_token_ids, src=torch.full_like( input=padded_stop_token_ids, dtype=torch.float32, fill_value=float("-inf"), device=self.orchestrator.device, ), )[ :, : self.orchestrator.vocab_size ] self.len_output_tokens = torch.zeros( size=(self.orchestrator.batch_size(), 1), dtype=torch.int32, device=self.orchestrator.device, ) def _teardown(self): del self.min_new_tokens del self.stop_token_penalties del self.len_output_tokens self.min_new_tokens = None self.stop_token_penalties = None self.len_output_tokens = None def _cumulate_input_tokens(self, input_ids: _TokenIDs): pass def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.len_output_tokens += 1 def _apply(self, logits: torch.Tensor) -> torch.Tensor: mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits) logits[mask] += self.stop_token_penalties[mask] return logits def _filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep] self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep] self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep] def _merge(self, their: "BatchedMinNewTokensPenalizer"): self.min_new_tokens = torch.cat( [self.min_new_tokens, their.min_new_tokens], dim=0 ) self.stop_token_penalties = torch.cat( [self.stop_token_penalties, their.stop_token_penalties], dim=0 ) self.len_output_tokens = torch.cat( [self.len_output_tokens, their.len_output_tokens], dim=0 ) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py ================================================ import typing import torch from ..orchestrator import _BatchedPenalizer, _TokenIDs class BatchedPresencePenalizer(_BatchedPenalizer): """ Presence penalizer penalizes tokens based on their presence in the output. """ presence_penalties: torch.Tensor = None cumulated_presence_penalties: torch.Tensor = None def _is_required(self) -> bool: return any( req.sampling_params.presence_penalty != 0.0 for req in self.orchestrator.reqs() ) def _prepare(self): self.cumulated_presence_penalties = ( torch.tensor( data=[0.0 for _ in self.orchestrator.reqs()], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .repeat(1, self.orchestrator.vocab_size) ) self.presence_penalties = ( torch.tensor( data=[ req.sampling_params.presence_penalty for req in self.orchestrator.reqs() ], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .expand_as(self.cumulated_presence_penalties) ) def _teardown(self): del self.presence_penalties del self.cumulated_presence_penalties self.presence_penalties = None self.cumulated_presence_penalties = None def _cumulate_input_tokens(self, input_ids: _TokenIDs): pass def _cumulate_output_tokens(self, output_ids: _TokenIDs): mask = output_ids.occurrence_count() > 0 self.cumulated_presence_penalties[mask] = self.presence_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: logits -= self.cumulated_presence_penalties return logits def _filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): self.presence_penalties = self.presence_penalties[indices_tensor_to_keep] self.cumulated_presence_penalties = self.cumulated_presence_penalties[ indices_tensor_to_keep ] def _merge(self, their: "BatchedPresencePenalizer"): self.presence_penalties = torch.cat( [self.presence_penalties, their.presence_penalties], dim=0 ) self.cumulated_presence_penalties = torch.cat( [self.cumulated_presence_penalties, their.cumulated_presence_penalties], dim=0, ) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py ================================================ import typing import torch from ..orchestrator import _BatchedPenalizer, _TokenIDs class BatchedRepetitionPenalizer(_BatchedPenalizer): """ Repetition penalizer penalizes tokens based on their repetition in the input and output. """ repetition_penalties: torch.Tensor = None cumulated_repetition_penalties: torch.Tensor = None def _is_required(self) -> bool: return any( req.sampling_params.repetition_penalty != 1.0 for req in self.orchestrator.reqs() ) def _prepare(self): self.cumulated_repetition_penalties = ( torch.tensor( data=[1.0 for _ in self.orchestrator.reqs()], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .repeat(1, self.orchestrator.vocab_size) ) self.repetition_penalties = ( torch.tensor( data=[ req.sampling_params.repetition_penalty for req in self.orchestrator.reqs() ], dtype=torch.float32, device=self.orchestrator.device, ) .unsqueeze_(1) .expand_as(self.cumulated_repetition_penalties) ) def _teardown(self): del self.repetition_penalties del self.cumulated_repetition_penalties self.repetition_penalties = None self.cumulated_repetition_penalties = None def _cumulate_input_tokens(self, input_ids: _TokenIDs): mask = input_ids.occurrence_count() > 0 self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _cumulate_output_tokens(self, output_ids: _TokenIDs): mask = output_ids.occurrence_count() > 0 self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: return torch.where( logits > 0, logits / self.cumulated_repetition_penalties, logits * self.cumulated_repetition_penalties, ) def _filter( self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor ): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ indices_tensor_to_keep ] def _merge(self, their: "BatchedRepetitionPenalizer"): self.repetition_penalties = torch.cat( [self.repetition_penalties, their.repetition_penalties], dim=0 ) self.cumulated_repetition_penalties = torch.cat( [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], dim=0, ) ================================================ FILE: archive/ktransformers/server/balance_serve/inference/sampling/sampler.py ================================================ ''' Date: 2024-11-14 12:23:45 LastEditors: Xie Weiyu ervinxie@qq.com LastEditTime: 2024-11-25 08:59:23 ''' import logging import torch from torch import nn from transformers import GenerationConfig from flashinfer.sampling import ( min_p_sampling_from_probs, top_k_renorm_probs, top_k_top_p_sampling_from_logits, top_p_renorm_probs, ) try: import torch_npu use_torch_npu = torch.npu.is_available() except: use_torch_npu = False logger = logging.getLogger(__name__) class SamplingOptions(): # Batched sampling params temperatures: torch.Tensor top_ps: torch.Tensor top_ks: torch.Tensor min_ps: torch.Tensor # All requests use greedy sampling is_all_greedy: bool # Dispatch in CUDA graph need_min_p_sampling: bool def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None): if pretrained_config is None and temperatures is None: self.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32) self.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32) self.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32) self.need_min_p_sampling = False self.is_all_greedy = True else: if temperatures is not None: self.temperatures = temperatures.unsqueeze(-1) else: self.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32) if top_ps is not None: self.top_ps = top_ps.unsqueeze(-1) else: self.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32) self.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32) self.need_min_p_sampling = False self.is_all_greedy = False class Sampler(nn.Module): def __init__(self): super().__init__() def forward( self, logits: torch.Tensor, sampling_config: SamplingOptions = None, ): if sampling_config == None: sampling_config = SamplingOptions() logits = logits.contiguous() origin_logits = logits.clone() if sampling_config.is_all_greedy or use_torch_npu: # Use torch.argmax if all requests use greedy sampling probs = torch.softmax(logits, dim=-1) batch_next_token_ids = torch.argmax(logits, -1) else: # Post process logits logits.div_(sampling_config.temperatures) max_top_k_round, batch_size = 32, logits.shape[0] if sampling_config.need_min_p_sampling: probs = torch.softmax(logits, dim=-1) logits = None del logits probs = top_k_renorm_probs(probs, sampling_config.top_ks) probs = top_p_renorm_probs(probs, sampling_config.top_ps) batch_next_token_ids = min_p_sampling_from_probs( probs, sampling_config.min_ps ) temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0] batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32) else: # TODO: use different kernel when don't need top_k or top_p # @TODO get probs probs = logits batch_next_token_ids = top_k_top_p_sampling_from_logits( logits, sampling_config.top_ks, sampling_config.top_ps, filter_apply_order="joint", ) temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0] batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32) return batch_next_token_ids.to(torch.int32), probs ================================================ FILE: archive/ktransformers/server/balance_serve/sched_rpc.py ================================================ from datetime import datetime import os from typing import Optional import zmq import pickle import threading import torch.multiprocessing as mp import sys current_file_path = os.path.abspath(__file__) # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) import pickle import argparse import torch try: import torch_npu use_npu = torch.npu.is_available() except: use_npu = False from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_glm4moe, create_sched_settings_smallthinker, create_sched_settings_qwen3next if mp.get_start_method(allow_none=True) is None: print('set start method') mp.set_start_method('spawn') else: print(f'start method already set to {mp.get_start_method(allow_none=True)}') class SchedulerServer: def __init__(self, settings, main_args): # 创建 Scheduler 实例并初始化 if use_npu: for device_id in settings.gpu_device_id: torch_npu.npu.set_device(f'npu:{device_id}') self.sched = sched_ext.create_scheduler(settings) # 初始化 ZeroMQ 上下文和套接字 self.context = zmq.Context() self.frontend = self.context.socket(zmq.ROUTER) print(f"sched zmq rpc server on port {main_args.sched_port}") self.frontend.bind(f"tcp://*:{main_args.sched_port}") # 创建内部的 DEALER 套接字,用于与工作线程通信 self.backend = self.context.socket(zmq.DEALER) self.backend.bind("inproc://backend") # 启动调度器 def run_scheduler(self): self.sched.run() # 停止调度器 def stop_scheduler(self): self.sched.stop() # 处理客户端请求 def start_proxy(self): # 使用 ZMQ 的内置代理,将前端请求分发给后端工作线程 zmq.proxy(self.frontend, self.backend) # 工作线程处理请求 def worker_routine(self): worker = self.context.socket(zmq.REP) worker.connect("inproc://backend") while True: try: # 接收客户端请求 message = worker.recv() data = pickle.loads(message) method = data.get('method') params = data.get('params', {}) # print(f"Received request: {method}") if method == 'add_query': query_add = params.get('query') # 直接是一个 QueryAdd 对象 # 添加查询 query_id = self.sched.add_query(query_add) # 发送响应 response = {'status': 'ok', 'query_id': query_id} worker.send(pickle.dumps(response)) elif method == 'cancel_query': query_id = params.get('query_id') # 假设您的 Scheduler 类实现了 cancel 方法 self.sched.cancel(query_id) response = {'status': 'ok'} worker.send(pickle.dumps(response)) elif method == 'update_last_batch': updates = params.get('updates') # 直接是一个列表,包含 QueryUpdate 对象 # 更新最后一个批次 batch_todo = self.sched.update_last_batch(updates) # 直接发送 batch_todo 对象 response = {'status': 'ok', 'batch_todo': batch_todo} # print (batch_todo.query_lengths, batch_todo.query_ids) worker.send(pickle.dumps(response)) elif method == 'get_inference_context': inference_context = self.sched.get_inference_context() data = { "k_cache":inference_context.k_cache, "v_cache":inference_context.v_cache } print(f"Serializing KVCache") data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']] data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']] # print(data) response = {'status': 'ok', 'inference_context': data} worker.send(pickle.dumps(response)) # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1 # print("k_cache update") else: # 未知方法 response = {'status': 'error', 'message': 'Unknown method'} worker.send(pickle.dumps(response)) except Exception as e: # 处理异常并发送错误响应 response = {'status': 'error', 'message': str(e)} worker.send(pickle.dumps(response)) # 启动 RPC 服务 def start_rpc_service(self): try: print("Scheduler RPC service is running...") # 在单独的线程中运行调度器 threading.Thread(target=self.run_scheduler, daemon=True).start() # 启动工作线程 for _ in range(10): # 根据需要调整线程数 threading.Thread(target=self.worker_routine, daemon=True).start() # 启动代理,开始监听请求 self.start_proxy() except KeyboardInterrupt: print("Shutting down scheduler RPC service...") self.stop_rpc_service() # 停止 RPC 服务 def stop_rpc_service(self): self.stop_scheduler() self.frontend.close() self.backend.close() self.context.term() def start_server(settings, main_args): server = SchedulerServer(settings, main_args) server.start_rpc_service() # Add async client for webserver class SchedulerClient: def __init__(self, sched_port): address=f'tcp://localhost:{sched_port}' self.address = address self.context = zmq.Context() self.socket = self.context.socket(zmq.REQ) self.socket.connect(self.address) print(f"Connected to server at {self.address}") def __del__(self): self.socket.close() self.context.term() def send_request(self, method, params=None): if params is None: params = {} request = { 'method': method, 'params': params } # print(f'send request {request}') self.socket.send(pickle.dumps(request)) response = self.socket.recv() # print(response) response = pickle.loads(response) if response.get('status') == 'ok': return response else: raise Exception(f"Error from server: {response.get('message')}") def add_query(self, query): response = self.send_request('add_query', {'query': query}) return response.get('query_id') def cancel_query(self, query_id): self.send_request('cancel_query', {'query_id': query_id}) def update_last_batch(self, updates): response = self.send_request('update_last_batch', {'updates': updates}) # print(f"update_last_batch response {response}") return response.get('batch_todo') def rebuild_inferece_context(self,response): data = response.get('inference_context') inference_context = sched_ext.InferenceContext() print('Rebuilding kvcache') inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']] inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']] return inference_context def get_inference_context_raw(self): response = self.send_request('get_inference_context') return response if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True) args = parser.parse_args() with open(args.config, "rb") as f: main_args = pickle.load(f) if main_args.architectures == "Qwen2MoeForCausalLM": settings = create_sched_settings_qwen2moe(main_args) elif main_args.architectures == "Qwen3MoeForCausalLM": settings = create_sched_settings_qwen3moe(main_args) elif main_args.architectures == "Glm4MoeForCausalLM": settings = create_sched_settings_glm4moe(main_args) elif main_args.architectures == "SmallThinkerForCausalLM": settings = create_sched_settings_smallthinker(main_args) elif main_args.architectures == "Qwen3NextForCausalLM": settings = create_sched_settings_qwen3next(main_args) else: settings = create_sched_settings(main_args) start_server(settings, main_args) ================================================ FILE: archive/ktransformers/server/balance_serve/settings.py ================================================ ''' Date: 2024-11-13 09:43:39 LastEditors: djw LastEditTime: 2024-11-18 16:41:03 ''' import sys, os import yaml, json from time import sleep import sched_ext from transformers import AutoConfig from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig from ktransformers.models.configuration_smallthinker import SmallthinkerConfig from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig def create_sched_settings(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = 1 # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 576 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = args.tp # only full tp supported now settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = True settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = False settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings def create_sched_settings_qwen2moe(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 128 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = 1 # tp settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = False settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = True settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings def create_sched_settings_qwen3moe(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 128 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = 1 # tp settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = False settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = True settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings def create_sched_settings_glm4moe(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 128 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = 1 # tp settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = False settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = True settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings def create_sched_settings_smallthinker(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 128 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = 1 # tp settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = False settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = True settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings def create_sched_settings_qwen3next(args): default_sample_options = sched_ext.SampleOptions() model_name = os.path.basename(os.path.normpath(args.model_dir)) input_model_settings = sched_ext.ModelSettings() input_model_settings.model_path = args.model_dir input_model_settings.params_count = int(0) model_config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True) input_model_settings.layer_count = model_config.num_hidden_layers input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config["num_key_value_heads"] input_model_settings.k_head_dim = 256 input_model_settings.bytes_per_params = 2 input_model_settings.bytes_per_kv_cache_element = 2 settings = sched_ext.Settings() settings.model_name = model_name settings.quant_type = "BF16" settings.model_settings = input_model_settings settings.page_size = args.page_size settings.gpu_device_count = 1 # tp settings.gpu_device_id = [i for i in range(settings.gpu_device_count)] # settings.gpu_memory_size = args.cache_lens*576*2 settings.gpu_memory_size = args.gpu_memory_size settings.memory_utilization_percentage = args.utilization_percentage max_batch_size = args.max_batch_size chunk_size = args.chunk_size max_decode_batch_size = max_batch_size - 2 settings.max_batch_size = max_batch_size settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2 settings.sample_options = default_sample_options settings.sched_metrics_port = args.sched_metrics_port settings.gpu_only = args.memory_gpu_only settings.use_self_defined_head_dim = False settings.self_defined_head_dim = 576 settings.full_kv_cache_on_each_gpu = True settings.k_cache_on = True settings.v_cache_on = True settings.kvc2_root_path = args.kvc2_disk_path settings.kvc2_config_path = args.kvc2_config_dir settings.memory_pool_size_GB = args.cpu_memory_size_GB settings.evict_count = 40 settings.kvc2_metrics_port = args.kvc2_metrics_port settings.load_from_disk = False settings.save_to_disk = True settings.strategy_name = args.sched_strategy settings.auto_derive() return settings ================================================ FILE: archive/ktransformers/server/config/config.py ================================================ #!/usr/bin/env python # coding=utf-8 """ Description : Author : unicornchan Date : 2024-06-11 16:35:42 Version : 1.0.0 LastEditors : WuHao LastEditTime : 2024-08-12 06:31:14 """ import os import shutil import yaml import psutil from ktransformers.server.config.singleton import Singleton from typing import Optional class Config(metaclass=Singleton): """Singleton pattern Config class, used to get all configurations.""" CONFIG_FILE_NAME = "config.yaml" @staticmethod def load() -> dict: """load config file Returns: dict: all configs """ base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) config_yaml: str = os.path.join(base_path, "configs", Config.CONFIG_FILE_NAME) user_path: str = os.path.expanduser("~") localstore_path: str = os.path.join(user_path, ".ktransformers") kvc2_config_dir = os.path.join(localstore_path, "kvc2") config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME) if not os.path.exists(config_yaml): print(f"Can't find config file, {config_yaml}") exit(-1) if not os.path.exists(localstore_path): os.mkdir(localstore_path) if not os.path.exists(kvc2_config_dir): os.mkdir(kvc2_config_dir) if not os.path.exists(config_path): shutil.copyfile(config_yaml, config_path) with open(config_path, "r", encoding="utf-8") as fp: config = yaml.safe_load(fp) return config @staticmethod def to_path(path: str) -> str: """ process file path """ base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) real_path = path if os.path.isabs(path) else os.path.join(base_path, path) return real_path def __init__(self): cfg = Config.load() self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) self.user_path: str = os.path.expanduser("~") self.localstore_path: str = os.path.join(self.user_path, ".ktransformers") # log configs self.log_dir = os.path.join(self.localstore_path, cfg["log"]["dir"]) if not os.path.exists(self.log_dir): os.mkdir(self.log_dir) self.log_file = cfg["log"]["file"] self.log_level = cfg["log"]["level"] self.backup_count = cfg["log"]["backup_count"] self.kvc2_config_dir = os.path.join(self.localstore_path, "kvc2") # server configs self.server: dict = cfg.get("server", {}) self.server_ip = self.server.get("ip", "0.0.0.0") self.server_port = self.server.get("port", 9016) self.api_key = self.server.get("api_key", "") # db configs self.db_configs: dict = cfg.get("db", {}) self.db_type = self.db_configs.get("type", "") self.db_host = self.localstore_path self.db_port = self.db_configs.get("port", "") self.db_name = self.db_configs.get("database", "") self.db_pool_size = self.db_configs.get("pool_size") self.db_database = self.db_configs.get("database", "") # user config self.user_config: dict = cfg.get("user", {}) self.user_secret_key = self.user_config.get("secret_key", "") self.user_algorithm = self.user_config.get("algorithm", "") self.user_force_think = self.user_config.get("force_think", False) # model config self.model: dict = cfg.get("model", {}) self.backend_type: str = self.model.get("type", "transformers") self.model_dir: str = self.model.get("path", "") # to make sure it consistent with previous version self.model_path: str = self.model_dir self.model_name: str = self.model.get("name", "") self.architectures: str = self.model.get("name", "") self.model_device: str = self.model.get("device", "cuda:0") self.gguf_path: Optional[str] = self.model.get("gguf_path", None) self.use_cuda_graph = self.model.get("use_cuda_graph", True) self.trust_remote_code = self.model.get("trust_remote_code", True) # self.model_cache_lens = self.model.get("cache_lens") self.optimize_config_path: Optional[str] = self.model.get( "optimize_config_path", None ) self.max_new_tokens = self.model.get("max_new_tokens", 2000) self.json_mode = self.model.get("json_mode", False) self.healing = self.model.get("healing", False) self.ban_strings: Optional[list] = self.model.get("ban_strings", None) self.gpu_split: Optional[str] = self.model.get("gpu_split", None) self.length: Optional[int] = self.model.get("length", None) self.rope_scale: Optional[float] = self.model.get("rope_scale", None) self.rope_alpha: Optional[float] = self.model.get("rope_alpha", None) self.no_flash_attn = self.model.get("no_flash_attn", False) self.low_mem = self.model.get("low_mem", False) self.experts_per_token: Optional[int] = self.model.get("experts_per_token", None) self.load_q4 = self.model.get("load_q4", False) self.fast_safetensors = self.model.get("fast_safetensors", False) self.draft_model_dir: Optional[str] = self.model.get("draft_model_dir", None) self.no_draft_scale = self.model.get("no_draft_scale", False) self.modes = self.model.get("modes", False) self.mode = self.model.get("mode", "llama") self.username = self.model.get("username", "User") self.botname = self.model.get("botname", "Chatbort") self.system_prompt: Optional[str] = self.model.get("system_prompt", None) self.temperature = self.model.get("temperature", 0.95) self.smoothing_factor = self.model.get("smoothing_factor", 0.0) self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None) self.top_k = self.model.get("top_k", 50) self.top_p = self.model.get("top_p", 0.8) self.top_a = self.model.get("top_a", 0.0) self.skew = self.model.get("skew", 0.0) self.typical = self.model.get("typical", 0.0) self.repetition_penalty = self.model.get("repetition_penalty", 1.01) self.frequency_penalty = self.model.get("frequency_penalty", 0.0) self.presence_penalty = self.model.get("presence_penalty", 0.0) self.response_chunk = self.model.get("response_chunk", 250) self.no_code_formatting = self.model.get("no_code_formatting", False) self.cache_8bit = self.model.get("cache_8bit", False) self.cache_q4 = self.model.get("cache_q4", True) self.ngram_decoding = self.model.get("ngram_decoding", False) self.print_timings = self.model.get("print_timings", False) self.amnesia = self.model.get("amnesia", False) self.batch_size = self.model.get("batch_size", 1) self.cache_lens = self.model.get("cache_lens", 4096) self.device = self.model.get("device", "cuda:2") # web config self.web: dict = cfg.get("web", {}) self.web_cross_domain: bool = self.web.get("open_cross_domain", True) self.mount_web: bool = self.web.get("mount", False) # ext self.ext: dict = cfg.get("ext", {}) self.cpu_infer = psutil.cpu_count(logical=False) - 3 # file config self.local_store_configs: dict = cfg.get("local_store", {}) self.file_upload_dir: str = os.path.join( self.localstore_path, self.local_store_configs.get("file_upload_dir", "") ) self.assistant_store_dir: str = os.path.join( self.localstore_path, self.local_store_configs.get("assistant_store_dir", "") ) # long context config self.long_context_config: dict = cfg.get("long_context", {}) self.max_seq_len = self.long_context_config.get("max_seq_len", 32000) self.block_size = self.long_context_config.get("block_size", 128) self.local_windows_len = self.long_context_config.get("local_windows_len", 4096) self.second_select_num = self.long_context_config.get("second_select_num", 32) self.anchor_type = self.long_context_config.get("anchor_type", "DYNAMIC") self.kv_type = self.long_context_config.get("kv_type", "FP16") self.dense_layer_num = self.long_context_config.get("dense_layer_num", 2) self.anchor_num = self.long_context_config.get("anchor_num", 1) self.preselect_block = self.long_context_config.get("preselect_block", True) self.head_select_mode = self.long_context_config.get("head_select_mode", "SHARED") self.preselect_block_count = self.long_context_config.get("preselect_block_count", 32) self.layer_step = self.long_context_config.get("layer_step", 1) self.token_step = self.long_context_config.get("token_step", 100) # local chat self.local_chat_config: dict = cfg.get("local_chat", {}) self.prompt_file = self.local_chat_config.get("prompt_file", None) # asyncserver self.sched_strategy = cfg["async_server"]["sched_strategy"] self.sched_port = cfg["async_server"]["sched_port"] self.sched_metrics_port = cfg["async_server"]["sched_metrics_port"] self.kvc2_metrics_port = cfg["async_server"]["kvc2_metrics_port"] self.max_batch_size = cfg["async_server"]["max_batch_size"] self.page_size = cfg["attn"]["page_size"] self.chunk_size = cfg["attn"]["chunk_size"] self.memory_gpu_only = cfg["kvc2"]["gpu_only"] self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size self.gpu_memory_size = 2*576*61*self.cache_lens self.utilization_percentage = 1.0 #cfg["kvc2"]["utilization_percentage"] self.cpu_memory_size_GB = cfg["kvc2"]["cpu_memory_size_GB"] self.kvc2_disk_path = cfg["kvc2"]["disk_path"] # only support 2 prefill task self.max_prefill_batch_size = 2 self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size ================================================ FILE: archive/ktransformers/server/config/log.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : unicornchan Date : 2024-06-12 02:48:39 Version : 1.0.0 LastEditors : chenxl LastEditTime : 2024-07-27 01:55:50 ''' import codecs import logging import os import re import locale from pathlib import Path from logging.handlers import BaseRotatingHandler import time import colorlog from ktransformers.server.config.config import Config class DailyRotatingFileHandler(BaseRotatingHandler): """ such as 'logging.TimeRotatingFileHandler', Additional features: - support multiprocess - support rotating daily """ def __init__(self, filename, backupCount=0, encoding=None, delay=False, utc=False, **kwargs): # pylint: disable=unused-argument self.backup_count = backupCount self.utc = utc self.suffix = "%Y-%m-%d" self.base_log_path = Path(filename) if not os.path.exists(self.base_log_path.parent): os.makedirs(self.base_log_path.parent) self.base_filename = self.base_log_path.name self.current_filename = self._compute_fn() self.current_log_path = self.base_log_path.with_name( self.current_filename) BaseRotatingHandler.__init__(self, filename, 'a', encoding, delay) # pylint: disable=unused-argument, invalid-name def shouldRollover(self, record): """ Determine whether to rotate the log. If the log filename corresponding to the current time is not consistent with the currently opened log filename, then it is necessary to rotate the log Args: record: record is not used, as we are just comparing times, but it is needed so the method signatures are the same """ if self.current_filename != self._compute_fn(): return True return False def doRollover(self): """ roll over """ # close last log file if self.stream: self.stream.close() self.stream = None # type: ignore # gen new log file name self.current_filename = self._compute_fn() self.current_log_path = self.base_log_path.with_name( self.current_filename) if not self.delay: self.stream = self._open() # type: ignore self.delete_expired_files() def _compute_fn(self): """ gen log file name """ return self.base_filename + "." + time.strftime(self.suffix, time.localtime()) def _open(self): """ open a new log file, create soft link """ if self.encoding is None: stream = open(str(self.current_log_path), self.mode, encoding=locale.getpreferredencoding()) else: stream = codecs.open(str(self.current_log_path), self.mode, self.encoding) if self.base_log_path.exists(): try: if not self.base_log_path.is_symlink() or os.readlink(self.base_log_path) != self.current_filename: os.remove(self.base_log_path) except OSError: pass try: os.symlink(self.current_filename, str(self.base_log_path)) except OSError: pass return stream def delete_expired_files(self): """ delete expired files every day """ if self.backup_count <= 0: return file_names = os.listdir(str(self.base_log_path.parent)) result = [] prefix = self.base_filename + "." plen = len(prefix) for file_name in file_names: if file_name[:plen] == prefix: suffix = file_name[plen:] if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix): result.append(file_name) if len(result) < self.backup_count: result = [] else: result.sort() result = result[:len(result) - self.backup_count] for file_name in result: os.remove(str(self.base_log_path.with_name(file_name))) class Logger(object): """ logger class """ level_relations = { 'debug': logging.DEBUG, 'info': logging.INFO, 'warn': logging.WARNING, 'error': logging.ERROR, 'crit': logging.CRITICAL } def __init__(self, level: str = 'info'): fmt = '%(asctime)s %(levelname)s %(pathname)s[%(lineno)d] %(funcName)s: %(message)s' cfg: Config = Config() filename: str = os.path.join(cfg.log_dir, cfg.log_file) backup_count: int = cfg.backup_count th = DailyRotatingFileHandler(filename=filename, when='MIDNIGHT', backupCount=backup_count, encoding="utf-8") th.setFormatter(logging.Formatter(fmt)) color_fmt = ( '%(log_color)s%(asctime)s %(levelname)s %(pathname)s[%(lineno)d]: %(message)s' ) color_formatter = colorlog.ColoredFormatter( color_fmt, log_colors={ 'DEBUG': 'cyan', 'INFO': 'green', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'bold_red' } ) sh = logging.StreamHandler() sh.setFormatter(color_formatter) self.logger = logging.getLogger(filename) self.logger.setLevel(self.level_relations.get(level)) # type: ignore self.logger.addHandler(th) self.logger.addHandler(sh) logger = Logger(level=Config().log_level).logger ================================================ FILE: archive/ktransformers/server/config/singleton.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Implement singleton Author : unicornchan Date : 2024-06-11 17:08:36 Version : 1.0.0 LastEditors : chenxl LastEditTime : 2024-07-27 01:55:56 ''' import abc class Singleton(abc.ABCMeta, type): """_summary_ Args: abc.ABCMeta: Provide a mechanism for defining abstract methods and properties, enforcing subclasses to implement these methods and properties. type: Inherit from 'type' to make 'Singleton' a metaclass, enabling the implementation of the Singleton """ _instances = {} def __call__(cls, *args, **kwds): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwds) return cls._instances[cls] class AbstractSingleton(abc.ABC, metaclass=Singleton): """Provided an abstract Singleton base class, any class inheriting from this base class will automatically become a Singleton class. Args: abc.ABC: Abstract base class, it cannot be instantiated, only inherited. """ ================================================ FILE: archive/ktransformers/server/crud/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/crud/assistants/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/crud/assistants/assistants.py ================================================ from time import time from typing import Optional,List from uuid import uuid4 from ktransformers.server.models.assistants.assistants import Assistant from ktransformers.server.schemas.assistants.assistants import AssistantCreate,AssistantObject,AssistantModify from ktransformers.server.utils.sql_utils import SQLUtil from ktransformers.server.config.log import logger from ktransformers.server.schemas.base import Order class AssistantDatabaseManager: def __init__(self) -> None: self.sql_util = SQLUtil() def create_assistant_object(self, assistant: AssistantCreate) -> AssistantObject: assistant = AssistantObject( **assistant.model_dump(mode='json'), id=str(uuid4()), object='assistant', created_at=int(time()), ) return assistant def db_count_assistants(self) -> int: with self.sql_util.get_db() as db: return db.query(Assistant).count() def db_create_assistant(self, assistant: AssistantCreate): ass_obj = self.create_assistant_object(assistant) ass_obj.sync_db() return ass_obj def db_list_assistants(self, limit: Optional[int], order: Order) -> List[AssistantObject]: with self.sql_util.get_db() as db: query = db.query(Assistant).order_by( order.to_sqlalchemy_order()(Assistant.created_at)) if limit is not None: db_assistants = query.limit(limit) else: db_assistants = query.all() return [AssistantObject.model_validate(a.__dict__) for a in db_assistants] def db_get_assistant_by_id(self, assistant_id: str) -> Optional[AssistantObject]: with self.sql_util.get_db() as db: db_assistant = db.query(Assistant).filter( Assistant.id == assistant_id).first() if db_assistant is None: logger.debug(f"no assistant with id {str}") return None return AssistantObject.model_validate(db_assistant.__dict__) def db_update_assistant_by_id(self, assistant_id: str, assistant: AssistantModify): with self.sql_util.get_db() as db: db_assistant = db.query(Assistant).filter( Assistant.id == assistant_id).first() self.sql_util.db_update_commit_refresh(db, db_assistant, assistant) return AssistantObject.model_validate(db_assistant.__dict__) def db_delete_assistant_by_id(self, assistant_id: str): with self.sql_util.get_db() as db: db_assistant = db.query(Assistant).filter( Assistant.id == assistant_id).first() db.delete(db_assistant) db.commit() ================================================ FILE: archive/ktransformers/server/crud/assistants/messages.py ================================================ from time import time from typing import Optional from uuid import uuid4 from ktransformers.server.models.assistants.messages import Message from ktransformers.server.schemas.assistants.messages import MessageCore, MessageCreate, MessageObject from ktransformers.server.schemas.base import Order,ObjectID from ktransformers.server.utils.sql_utils import SQLUtil class MessageDatabaseManager: def __init__(self) -> None: self.sql_util = SQLUtil() @staticmethod def create_db_message_by_core(message: MessageCore): message_dict = message.model_dump(mode="json") return Message(**message_dict, id=str(uuid4()), created_at=int(time())) def create_db_message(self, message: MessageCreate): return MessageDatabaseManager.create_db_message_by_core(message.to_core()) def db_add_message(self, message: Message): with self.sql_util.get_db() as db: db.add(message) self.sql_util.db_add_commit_refresh(db, message) def db_create_message(self, thread_id: str, message: MessageCreate, status: MessageObject.Status): db_message = self.create_db_message(message) db_message.status = status.value db_message.thread_id = thread_id self.db_add_message(db_message) return MessageObject.model_validate(db_message.__dict__) @staticmethod def create_message_object(thread_id: ObjectID, run_id: ObjectID, message: MessageCreate): core = message.to_core() return MessageObject( **core.model_dump(mode='json'), id=str(uuid4()), object='thread.message', created_at=int(time()), thread_id=thread_id, run_id=run_id, status=MessageObject.Status.in_progress, ) def db_sync_message(self, message: MessageObject): db_message = Message( **message.model_dump(mode="json"), ) with self.sql_util.get_db() as db: self.sql_util.db_merge_commit(db, db_message) def db_list_messages_of_thread( self, thread_id: str, limit: Optional[int] = None, order: Order = Order.DESC): # logger.debug( # f"list messages of: {thread_id}, limit {limit}, order {order}") with self.sql_util.get_db() as db: query = ( db.query(Message) .filter(Message.thread_id == thread_id) .order_by(order.to_sqlalchemy_order()(Message.created_at)) ) if limit is not None: messages = query.limit(limit) else: messages = query.all() message_list = [MessageObject.model_validate(m.__dict__) for m in messages] return message_list def db_get_message_by_id(self, thread_id: ObjectID, message_id: ObjectID) -> MessageObject: with self.sql_util.get_db() as db: message = db.query(Message).filter( Message.id == message_id).first() assert message.thread_id == thread_id message_info = MessageObject.model_validate(message.__dict__) return message_info def db_delete_message_by_id(self, thread_id: ObjectID, message_id: ObjectID): with self.sql_util.get_db() as db: message = db.query(Message).filter( Message.id == message_id).first() assert message.thread_id == thread_id db.delete(message) db.commit() ================================================ FILE: archive/ktransformers/server/crud/assistants/runs.py ================================================ from time import time from uuid import uuid4 from ktransformers.server.models.assistants.runs import Run from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject from ktransformers.server.schemas.base import ObjectID from ktransformers.server.utils.sql_utils import SQLUtil class RunsDatabaseManager: def __init__(self) -> None: self.sql_util = SQLUtil() def create_run_object(self, thread_id: ObjectID, run: RunCreate) -> RunObject: run_obj = RunObject( **run.model_dump(mode='json', exclude={"stream"}), id=str(uuid4()), object='run', created_at=int(time()), thread_id=thread_id, status=RunObject.Status.queued, ) run_obj.set_compute_save(0) return run_obj def db_create_run(self, thread_id: str, run: RunCreate): db_run = Run( **run.model_dump(mode="json", exclude={"stream"}), id=str(uuid4()), created_at=int(time()), status="queued", thread_id=thread_id, ) with self.sql_util.get_db() as db: self.sql_util.db_add_commit_refresh(db, db_run) run_obj = RunObject.model_validate(db_run.__dict__) run_obj.set_compute_save(0) return run_obj def db_sync_run(self, run: RunObject) -> None: db_run = Run( **run.model_dump(mode='json'), ) with self.sql_util.get_db() as db: self.sql_util.db_merge_commit(db, db_run) def db_get_run(self, run_id: ObjectID) -> RunObject: with self.sql_util.get_db() as db: db_run = db.query(Run).filter(Run.id == run_id).first() return RunObject.model_validate(db_run.__dict__) ================================================ FILE: archive/ktransformers/server/crud/assistants/threads.py ================================================ from time import time from typing import Optional,List from uuid import uuid4 from ktransformers.server.models.assistants.messages import Message from ktransformers.server.models.assistants.threads import Thread from ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject from ktransformers.server.schemas.base import ObjectID, Order from ktransformers.server.schemas.conversation import ThreadPreview from ktransformers.server.utils.sql_utils import SQLUtil from ktransformers.server.crud.assistants.messages import MessageDatabaseManager from ktransformers.server.config.log import logger from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager class ThreadsDatabaseManager: def __init__(self) -> None: self.sql_util = SQLUtil() self.message_manager = MessageDatabaseManager() self.assistant_maanager = AssistantDatabaseManager() def db_create_thread(self, thread: ThreadCreate): thread_id = str(uuid4()) db_messages = [] with self.sql_util.get_db() as db: if thread.messages is not None: logger.debug("Creating messages first for thread") for message in thread.messages: db_message: Message = MessageDatabaseManager.create_db_message_by_core( message) db_message.role = "user" db_message.thread_id = thread_id db.add(db_message) db_messages.append(db_message) db_thread = Thread( **thread.model_dump(exclude="messages"), id=str(uuid4()), created_at=int(time()), messages=db_messages, ) self.sql_util.db_add_commit_refresh(db, db_thread) thread_obj = ThreadObject.model_validate(db_thread.__dict__) if 'assistant_id' in thread.meta_data: # assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db) assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id']) logger.info( f'Append this related thread to assistant {assistant.id}') assistant.append_related_threads([thread_obj.id]) assistant.sync_db(db) return thread_obj def db_get_thread_by_id(self, thread_id: ObjectID): with self.sql_util.get_db() as db: db_thread = db.query(Thread).filter(Thread.id == thread_id).first() return ThreadObject.model_validate(db_thread.__dict__) def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]: with self.sql_util.get_db() as db: query = db.query(Thread).order_by(order.to_sqlalchemy_order()( Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id')) if limit is not None: db_threads = query.limit(limit) else: db_threads = query.all() return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads] def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]: threads = self.db_list_threads(limit, order) previews = [] for thread in threads: messages = self.message_manager.db_list_messages_of_thread( thread.id, limit=2, order=Order.ASC) if len(messages) == 2: message = messages[0] assistant = self.assistant_maanager.db_get_assistant_by_id( messages[1].assistant_id) else: message = None assistant = None previews.append(ThreadPreview( assistant=assistant, thread=thread, first_message=message)) return previews def db_delete_thread_by_id(self, thread_id: ObjectID): with self.sql_util.get_db() as db: db_thread = db.query(Thread).filter(Thread.id == thread_id).first() db.delete(db_thread) # TODO delete related messages and runs and other stuff or just gc db.commit() ================================================ FILE: archive/ktransformers/server/exceptions.py ================================================ from fastapi import HTTPException, status def db_exception(): return HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="DB Error", ) def not_implemented(what): return HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=f"{what} not implemented", ) def internal_server_error(what): return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{what}") def request_error(what): return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"{what}") ================================================ FILE: archive/ktransformers/server/main.py ================================================ import asyncio import os import re from uuid import uuid4 import torch import torch.distributed from fastapi import FastAPI from fastapi.staticfiles import StaticFiles import uvicorn.logging import uvicorn import sys import atexit project_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) from fastapi.middleware.cors import CORSMiddleware from ktransformers.server.args import ArgumentParser from ktransformers.server.config.config import Config from ktransformers.util import utils from ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager from fastapi.openapi.utils import get_openapi from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from ktransformers.server.api import router, post_db_creation_operations from ktransformers.server.utils.sql_utils import Base, SQLUtil from ktransformers.server.config.log import logger import subprocess import tempfile def mount_app_routes(mount_app: FastAPI): sql_util = SQLUtil() logger.info("Creating SQL tables") Base.metadata.create_all(bind=sql_util.sqlalchemy_engine) post_db_creation_operations() mount_app.include_router(router) def create_app(): cfg = Config() if(hasattr(GlobalInterface.interface, "lifespan")): app = FastAPI(lifespan=GlobalInterface.interface.lifespan) else: app = FastAPI() if Config().web_cross_domain: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) mount_app_routes(app) if cfg.mount_web: mount_index_routes(app) return app def update_web_port(config_file: str): ip_port_pattern = ( r"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}" ) with open(config_file, "r", encoding="utf-8") as f_cfg: web_config = f_cfg.read() ip_port = "localhost:" + str(Config().server_port) new_web_config = re.sub(ip_port_pattern, ip_port, web_config) with open(config_file, "w", encoding="utf-8") as f_cfg: f_cfg.write(new_web_config) def mount_index_routes(app: FastAPI): project_dir = os.path.dirname(os.path.dirname(__file__)) web_dir = os.path.join(project_dir, "website/dist") web_config_file = os.path.join(web_dir, "config.js") update_web_port(web_config_file) if os.path.exists(web_dir): app.mount("/web", StaticFiles(directory=web_dir), name="static") else: err_str = f"No website resources in {web_dir}, please complile the website by npm first" logger.error(err_str) print(err_str) exit(1) def run_api(app, host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): uvicorn.run( app, host=host, port=port, ssl_keyfile=kwargs.get("ssl_keyfile"), ssl_certfile=kwargs.get("ssl_certfile"), ) else: uvicorn.run(app, host=host, port=port, log_level="debug") def custom_openapi(app): if app.openapi_schema: return app.openapi_schema openapi_schema = get_openapi( title="ktransformers server", version="1.0.0", summary="This is a server that provides a RESTful API for ktransformers.", description="We provided chat completion and openai assistant interfaces.", routes=app.routes, ) openapi_schema["info"]["x-logo"] = {"url": "https://kvcache.ai/media/icon_1.png"} app.openapi_schema = openapi_schema return app.openapi_schema def verify_arg(args): nproc_per_node = int(os.getenv('LOCAL_WORLD_SIZE')) if args.batch_size not in [1, 2, 3, 4]: raise ValueError(f'argument batch_size should be in [1, 2, 3, 4], got {args.batch_size}') if nproc_per_node not in [1, 2]: raise ValueError(f'argument nproc_per_node should be in [1, 2], got {nproc_per_node}') if args.tp not in [1, 2]: raise ValueError(f'argument tp should be in [1, 2], got {args.tp}') if nproc_per_node != args.tp: raise ValueError(f'argument nproc_per_node should be equal to tp, got nproc_per_node is {nproc_per_node}, tp is {args.tp}') def main(): try: import torch_npu use_npu = torch.npu.is_available() torch.npu.config.allow_internal_format = True except: use_npu = False cfg = Config() arg_parser = ArgumentParser(cfg) args = arg_parser.parse_args() if use_npu: verify_arg(args) rank_id = int(os.environ["RANK"]) args.device = args.device[:-1] + str(rank_id) create_interface(config=cfg, default_args=cfg, input_args=args) tp_size = args.tp world_size = int(os.getenv("WORLD_SIZE", '1')) if tp_size == world_size and tp_size > 1: if rank_id == 0: app = create_app() custom_openapi(app) run_api( app=app, host=args.host, port=args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, ) elif cfg.backend_type == 'ktransformers': while True: try: context = get_thread_context_manager() id = str(uuid4()) context.interface.sync_inference("", id, 1.0, 1.0) except Exception as e: print(f"An error occurred: {e}") finally: pass else: app = create_app() custom_openapi(app) run_api( app=app, host=args.host, port=args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, ) if __name__ == "__main__": main() ================================================ FILE: archive/ktransformers/server/models/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/models/assistants/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/models/assistants/assistants.py ================================================ from sqlalchemy import JSON, Column, Float, Integer, String, Text from sqlalchemy.orm import relationship from ktransformers.server.utils.sql_utils import Base class Assistant(Base): __tablename__ = "assistants" id = Column(String, primary_key=True, index=True) object = Column(String, default="assistant") created_at = Column(Integer) name = Column(String, nullable=True) description = Column(String, nullable=True) model = Column(String) instructions = Column(Text, nullable=True) tools = Column(JSON) tool_resources = Column(JSON) temperature = Column(Float, nullable=True) meta_data = Column(JSON, nullable=True) top_p = Column(Float, nullable=True) response_format = Column(JSON, default="auto") build_status = Column(JSON, nullable=True) runs = relationship("Run", back_populates="assistant") messages = relationship("Message", back_populates="assistant") ================================================ FILE: archive/ktransformers/server/models/assistants/messages.py ================================================ from sqlalchemy import JSON, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from ktransformers.server.utils.sql_utils import Base class Message(Base): __tablename__ = "messages" id = Column(String, primary_key=True, index=True) object = Column(String, default="thread.message") created_at = Column(Integer) thread_id = Column(String, ForeignKey("threads.id")) status = Column(String, default="in_progress") incomplete_details = Column(JSON, nullable=True) completed_at = Column(Integer, nullable=True) incomplete_at = Column(Integer, nullable=True) role = Column(JSON) content = Column(JSON) assistant_id = Column(String, ForeignKey("assistants.id"), nullable=True) run_id = Column(String, ForeignKey("runs.id"), nullable=True) attachments = Column(JSON, nullable=True) meta_data = Column(JSON, nullable=True) thread = relationship("Thread", back_populates="messages") assistant = relationship("Assistant", back_populates="messages") run = relationship("Run", back_populates="message") ================================================ FILE: archive/ktransformers/server/models/assistants/run_steps.py ================================================ from sqlalchemy import JSON, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from ktransformers.server.utils.sql_utils import Base class RunStep(Base): __tablename__ = "run_steps" # todo id = Column(String, primary_key=True, index=True) object = Column(String, default="thread.run.step") created_at = Column(Integer) assistant_id = Column(String, ForeignKey("assistants.id")) thread_id = Column(String, ForeignKey("threads.id")) run_id = Column(String, ForeignKey("runs.id")) type = Column(String) status = Column(String) step_details = Column(JSON) last_error = Column(JSON, nullable=True) expires_at = Column(Integer, nullable=True) cancelled_at = Column(Integer, nullable=True) failed_at = Column(Integer, nullable=True) completed_at = Column(Integer, nullable=True) meta_data = Column(JSON, nullable=True) usage = Column(JSON, nullable=True) assistant = relationship("Assistant", back_populates="run_steps") thread = relationship("Thread", back_populates="run_steps") run = relationship("Run", back_populates="run_steps") ================================================ FILE: archive/ktransformers/server/models/assistants/runs.py ================================================ from sqlalchemy import JSON, Column, Float, ForeignKey, Integer, String, Text from sqlalchemy.orm import relationship from ktransformers.server.utils.sql_utils import Base class Run(Base): __tablename__ = "runs" id = Column(String, primary_key=True, index=True) object = Column(String, default="thread.run") created_at = Column(Integer) thread_id = Column(String, ForeignKey("threads.id")) assistant_id = Column(String, ForeignKey("assistants.id")) status = Column(String) required_action = Column(JSON, nullable=True) last_error = Column(JSON, nullable=True) expires_at = Column(Integer, nullable=True) started_at = Column(Integer, nullable=True) cancelled_at = Column(Integer, nullable=True) failed_at = Column(Integer, nullable=True) completed_at = Column(Integer, nullable=True) incomplete_details = Column(JSON, nullable=True) # get from assistant model = Column(String) instructions = Column(Text, nullable=True) tools = Column(JSON) meta_data = Column(JSON, nullable=True) usage = Column(JSON, nullable=True) temperature = Column(Float, nullable=True) top_p = Column(Float, nullable=True) max_propmp_tokens = Column(Integer, nullable=True) truncation_strategy = Column(JSON) tool_choice = Column(JSON) response_format = Column(JSON, default="auto") thread = relationship("Thread", back_populates="runs") assistant = relationship("Assistant", back_populates="runs") message = relationship("Message", back_populates="run") ================================================ FILE: archive/ktransformers/server/models/assistants/threads.py ================================================ from sqlalchemy import JSON, Column, Integer, String from sqlalchemy.orm import relationship from ktransformers.server.utils.sql_utils import Base class Thread(Base): __tablename__ = "threads" id = Column(String, primary_key=True, index=True) object = Column(String, default="thread") created_at = Column(Integer) tool_resources = Column(JSON, nullable=True) meta_data = Column(JSON, nullable=True) runs = relationship("Run", back_populates="thread") messages = relationship("Message", back_populates="thread") ================================================ FILE: archive/ktransformers/server/requirements.txt ================================================ torch >= 2.3.0 transformers >= 4.51.3 fastapi >= 0.111.0 langchain >= 0.2.0 blessed >= 1.20.0 accelerate >= 0.31.0 sentencepiece >= 0.1.97 openai setuptools build ninja wheel colorlog fire zmq psutil ================================================ FILE: archive/ktransformers/server/schemas/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/schemas/assistants/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/schemas/assistants/assistants.py ================================================ from enum import Enum from time import time from typing import AsyncIterable, Callable, Dict, List, Optional, Union from asyncio import Lock, Queue from fastapi import logger from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator import torch from ktransformers.server.config.config import Config from ktransformers.server.models.assistants.assistants import Assistant from ktransformers.server.models.assistants.threads import Thread from ktransformers.server.schemas.assistants.messages import Role from ktransformers.server.schemas.assistants.runs import RunObject,RunStreamResponse,ObjectWithCreatedTime from ktransformers.server.schemas.assistants.threads import ThreadObject from ktransformers.server.schemas.base import Metadata,MetadataField,ObjectID from ktransformers.server.schemas.assistants.tool import Tool,CodeInterpreter,FileSearch,RelatedThreads,FuntionTool,ToolResource,CodeInterpreterResource,FileSearchResource,RelatedThreadsResource,ToolType from ktransformers.server.utils.sql_utils import SQLUtil class AssistantBase(BaseModel): name: Optional[str] = Field(None,description='The name of the assistant.') description: Optional[str] = Field(None,description='The description of the assistant.') instructions: Optional[str] = Field(None,description='Instructions which is added in front of the input of LLM') tools: List[Tool] = Field([], max_length=128) @field_validator('tools', mode='before') def validate_tools(cls, value): re = [] if not isinstance(value, list): raise ValueError('Invalid type for tools') for tool in value: if 'type' not in tool: raise ValueError('Invalid type for tools') if tool['type'] == 'code_interpreter': re.append(CodeInterpreter(**tool)) elif tool['type'] == 'file_search': re.append(FileSearch(**tool)) elif tool['type'] == 'related_threads': re.append(RelatedThreads(**tool)) elif tool['type'] == 'function': re.append(FuntionTool(**tool)) else: raise ValueError('Invalid type for tools') return re tool_resources: List[ToolResource] = Field([], max_length=128) @field_validator('tool_resources', mode='before') def validate_tool_resources(cls, value): re = [] if not isinstance(value, list): raise ValueError('Invalid type for tool resources') for tool_re in value: if 'file_ids' in tool_re: re.append(CodeInterpreterResource(**tool_re)) elif 'vector_stores' in tool_re: re.append(FileSearchResource(**tool_re)) elif 'thread_ids' in tool_re: re.append(RelatedThreadsResource(**tool_re)) else: raise ValueError('Invalid type for tool resources') return re meta_data: Metadata = MetadataField @model_validator(mode='before') def convert_meta_data(cls, values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1) top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1) response_format: Union[str, Dict[str, str]] = "auto" class AssistantCreate(AssistantBase): model: str class AssistantBuildStatus(BaseModel): class Status(Enum): not_build = "not_build" in_queue = "in_queue" parsing = "parsing" prefilling = "prefilling" dumping = "dumping" completed = "completed" paused = "paused" _lock: Lock = PrivateAttr(default_factory=Lock) _queue: Optional[Queue] = PrivateAttr(None) status: Status = Field(default=Status.not_build) total_file_count: int = Field(default=0) parsed_file_count: int = Field(default=0) prefilling_current: int = Field(default=0) prefilling_total: int = Field(default=0) build_started_time: Optional[int] = Field(default=None) build_completed_time: Optional[int] = Field(default=None) # in megabytes assistant_usage: int = Field(default=0, description='') assistant_total_usage: int = Field(default=0) disk_free_space: int = Field(default=0) disk_total_space: int = Field(default=0) def to_stream_reply(self) -> str: return f"event: assistant.build.status\ndata: {self.model_dump_json()}\n\n" class AssistantObject(AssistantBase, ObjectWithCreatedTime): model: Optional[str] = Field( default=Config().model_name) related_threads_objects: Optional[List] = Field(None, exclude=True) _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None) build_status: AssistantBuildStatus = Field(default=AssistantBuildStatus()) def as_api_response(self): return self.model_dump(exclude={'build_status'}) def get_related_threads_ids(self) -> List[ObjectID]: re = [] for tool, tool_re in zip(self.tools, self.tool_resources): if tool.type == ToolType.RELATED_THREADS: re += tool_re.thread_ids or [] return re def get_related_threads_objects(self) -> List: # raise NotImplementedError # should be replaced sql_utils = SQLUtil() if self.related_threads_objects is None: with sql_utils.get_db() as db: db_threads = db.query(Thread).all() self.related_threads_objects = [tool for tool in [ThreadObject.model_validate( tool.__dict__) for tool in db_threads] if tool.is_related_threads and tool.meta_data['assistant_id'] == self.id] # logger.debug( # f'Found {len(self.related_threads_objects)} related threads') return self.related_threads_objects def append_related_threads(self, thread_ids: List[ObjectID]): # logger.debug(f'{self.tools} {self.tool_resources}') for tool, tool_re in zip(self.tools, self.tool_resources): if tool.type == ToolType.RELATED_THREADS: tool_re.thread_ids += thread_ids return self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS)) self.tool_resources.append( RelatedThreadsResource(thread_ids=thread_ids)) async def update_build_status(self, events: AsyncIterable) -> AsyncIterable: async for event in events: # logger.debug(event) if isinstance(event, RunStreamResponse): if event.event == RunObject.Status.completed: self.build_status.status = AssistantBuildStatus.Status.completed self.build_status.build_completed_time = int(time()) self.sync_db() yield self.build_status.model_copy() elif isinstance(event, dict): # logger.debug('dict') if 'stage' in event: if event['stage'] == 'prefill': self.build_status.status = AssistantBuildStatus.Status.prefilling self.build_status.prefilling_current = event['curr_progress'] self.build_status.prefilling_total = event['max_progress'] if event['stage'] == 'parse': self.build_status.status = AssistantBuildStatus.Status.parsing self.build_status.parsed_file_count = event['curr_progress'] self.build_status.total_file_count = event['max_progress'] yield self.build_status.model_copy() def get_build_status(self) -> AssistantBuildStatus: return self.build_status def sync_db(self)->None: # raise NotImplementedError # should be replaced sql_utils = SQLUtil() db_assistant = Assistant( **self.model_dump(mode='json'), ) with sql_utils.get_db() as db: sql_utils.db_merge_commit(db, db_assistant) def get_encoded_instruction(self,encode_fn:Callable)->torch.Tensor: if self._encoded_instruction is None: logger.info(f'encoding assistant instruction: {self.instructions}') self._encoded_instruction = encode_fn(self.instructions, Role.user) return self._encoded_instruction class AssistantModify(AssistantBase): model: Optional[str] = None # Non API Backend ================================================ FILE: archive/ktransformers/server/schemas/assistants/messages.py ================================================ from enum import Enum from typing import ForwardRef, List, Optional, Union,Callable import torch from pydantic import BaseModel, PrivateAttr, model_validator from ktransformers.server.exceptions import not_implemented from ktransformers.server.config.log import logger from ktransformers.server.models.assistants.messages import Message from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime from ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch from ktransformers.server.utils.sql_utils import SQLUtil class IncompleteDetails(BaseModel): reason: str class ContentType(Enum): image_file = "image_file" image_url = "image_url" text = "text" class ContentObject(BaseModel): type: ContentType class ImageFile(BaseModel): file_id: str detail: str class ImageFileObject(ContentObject): image_file: ImageFile class ImageUrl(BaseModel): url: str detail: str class ImageUrlObject(ContentObject): image_url: ImageUrl class Annotation(BaseModel): todo: str class Text(BaseModel): value: str annotations: List[Annotation] = Field(default=[]) class TextObject(ContentObject): text: Text delta_index: int = Field(default=0,exclude=True) special_tokens_on: bool = Field(default=False,exclude=True) last_two: str= Field(default='',exclude=True) def filter_append(self,text:str): self.text.value+=text self.delta_index+=1 return True Content = Union[ImageFileObject, ImageUrlObject, TextObject] class Attachment(BaseModel): file_id: Optional[str] = Field(default=None) tools: Optional[List[Union[CodeInterpreter, FileSearch]]] = Field(default=None) class Role(Enum): user = "user" assistant = "assistant" def is_user(self)->bool: return self == Role.user class MessageCore(BaseModel): role: Role content: List[Content] attachments: Optional[List[Attachment]] meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values class MessageBase(MessageCore): class Status(Enum): created = "created" # only used for stream in_progress = "in_progress" incomplete = "incomplete" completed = "completed" thread_id: str status: Status incomplete_details: Optional[IncompleteDetails] = None completed_at: Optional[int] = None incomplete_at: Optional[int] = None assistant_id: Optional[str] = None run_id: Optional[str] MessageStreamResponse = ForwardRef('MessageStreamResponse') class MessageObject(MessageBase, ObjectWithCreatedTime): _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None) def get_text_content(self) -> str: text_content = "" for content in self.content: if content.type == ContentType.text: text_content += content.text.value else: raise not_implemented("Content other than text") return text_content async def get_encoded_content(self,encode_fn:Callable): if self._encoded_content is None: logger.info(f'encoding {self.role.value} message({self.status.value}): {self.get_text_content()}') self._encoded_content = encode_fn(self.get_text_content(),self.role) for f in self.get_attached_files(): logger.info(f'encoding file: {f.filename}') self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(),self.role)],dim=-1) yield None yield self._encoded_content def get_attached_files(self): raise NotImplementedError # should be replaced def append_message_delta(self,text:str): raise NotImplementedError # should be replaced def sync_db(self): # raise NotImplementedError # should be replaced sql_utils = SQLUtil() db_message = Message( **self.model_dump(mode="json"), ) with sql_utils.get_db() as db: sql_utils.db_merge_commit(db, db_message) def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse: match event: case MessageObject.Status.created: self.status = MessageObject.Status.in_progress case _: self.status = event return MessageStreamResponse(message=self, event=event) class MessageStreamResponse(BaseModel): message: MessageObject event: MessageObject.Status def to_stream_reply(self): return f"event: thread.message.{self.event.value}\ndata: {self.message.model_dump_json()}\n\n" class MessageCreate(BaseModel): role: Role = Field(default=Role.user) content: Union[str | List[Content]] attachments: Optional[List[Attachment]] = None meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values def to_core(self) -> MessageCore: # logger.debug(f"Converting message create to core {self.model_dump()}") core = MessageCore( role=self.role, content=[], attachments=self.attachments, meta_data=self.meta_data, ) if isinstance(self.content, str): core.content = [TextObject(type="text", text=Text(value=self.content, annotations=[]))] elif isinstance(self.content, list): core.content = self.content else: raise ValueError("Invalid content type") return core class MessageModify(BaseModel): meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values ================================================ FILE: archive/ktransformers/server/schemas/assistants/runs.py ================================================ from enum import Enum from typing import Dict, List, Optional, Union, ForwardRef from pydantic import BaseModel, Field, model_validator from ktransformers.server.models.assistants.runs import Run from ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime from ktransformers.server.schemas.assistants.threads import ThreadCreate from ktransformers.server.schemas.assistants.tool import Tool, ToolResource from ktransformers.server.utils.sql_utils import SQLUtil class ToolCall(BaseModel): id: str type: str function: TODO class SubmitToolOutputs(BaseModel): tool_calls: List[ToolCall] class RequiredAction(BaseModel): type: str submit_tool_outputs: TODO class LastError(BaseModel): code: str message: str class IncompleteDetails(BaseModel): reason: str class Usage(BaseModel): completion_tokens: int prompt_tokens: int total_tokens: int class TruncationStrategy(BaseModel): type: str = "auto" last_message: Optional[int] class ToolChoiceType(Enum): none = "none" auto = "auto" required = "required" class RunBase(BaseModel): class Status(Enum): created = "created" # only stream event will have this created status queued = "queued" in_progress = "in_progress" requires_action = "requires_action" cancelling = "cancelling" cancelled = "cancelled" failed = "failed" completed = "completed" expired = "expired" thread_id: str assistant_id: str status: Status = Status.queued required_action: Optional[RequiredAction] = Field(None) last_error: Optional[LastError] = Field(None) expires_at: Optional[int]= Field(None) started_at: Optional[int] = Field(None) cancelled_at: Optional[int] = Field(None) failed_at: Optional[int] = Field(None) completed_at: Optional[int] = Field(None) incomplete_details: Optional[IncompleteDetails] = Field(None) model: Optional[str] = Field(None) instructions: Optional[str] = Field(None) tools: Optional[List[Tool]] = Field([]) meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values def set_compute_save(self,save:int): self.meta_data['compute_save'] = str(save) usage: Optional[Usage] = Field(None) temperature: Optional[float] = Field(None) top_p: Optional[float]= Field(None) max_propmp_tokens: Optional[int]= Field(None) truncation_strategy: Optional[TruncationStrategy]= Field(None) tool_choice: Optional[Union[ToolChoiceType, dict]]= Field(None) response_format: Union[str, Dict[str, str]] = "auto" RunStreamResponse = ForwardRef('RunStreamResponse') class RunObject(RunBase, ObjectWithCreatedTime): def stream_response_with_event(self,event:RunBase.Status)->RunStreamResponse: match event: case RunBase.Status.created: self.status = RunBase.Status.queued case _: self.status = event return RunStreamResponse(run=self, event=event) def sync_db(self): # raise NotImplementedError # should be replaced in crud sql_utils = SQLUtil() db_run = Run( **self.model_dump(mode='json'), ) with sql_utils.get_db() as db: sql_utils.db_merge_commit(db, db_run) def create_message_creation_step(self): raise NotImplementedError # should be replaced class RunStreamResponse(BaseModel): run: RunObject event: RunObject.Status def to_stream_reply(self): return f"event: thread.run.{self.event.value}\ndata: {self.run.model_dump_json()}\n\n" class RunCreate(BaseModel): assistant_id: str model: Optional[str] = Field(default=None) instructions: Optional[str] = Field(default=None) # TODO: Add this # additional_instructions: Optional[str] # additional_messages: Optional[List[MessageCore]] tools: List[Tool] = Field(default=[]) meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values temperature: Optional[float] = Field(default=None) top_p: Optional[float] = Field(default=None) stream: Optional[bool] = Field(default=None) max_propmp_tokens: Optional[int] = Field(default=None) # TODO: Add this # max_completion_tokens: Optional[int] truncation_strategy: Optional[TruncationStrategy] = Field(default=None) tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None) response_format: Union[str, Dict[str, str]] = Field(default="auto") class RunThreadCreate(BaseModel): assistant_id: str thread: Optional[ThreadCreate] model: Optional[str] instructions: Optional[str] tools: List[Tool] tool_resources: List[ToolResource] meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values temperature: Optional[float] top_p: Optional[float] stream: Optional[bool] max_propmp_tokens: Optional[int] # TODO: Add this # max_completion_tokens: Optional[int] truncation_strategy: TruncationStrategy tool_choice: Union[ToolChoiceType, dict] response_format: Union[str, Dict[str, str]] = "auto" class RunModify(BaseModel): meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values class ToolOutput(BaseModel): tool_call_id: Optional[str] output: Optional[str] class RunSubmit(BaseModel): tool_outputs: List[ToolOutput] stream: Optional[bool] ================================================ FILE: archive/ktransformers/server/schemas/assistants/streaming.py ================================================ import asyncio from typing import AsyncIterable, List, Union from fastapi import Request from fastapi.responses import StreamingResponse from pydantic import BaseModel from ktransformers.server.schemas.assistants.runs import RunStreamResponse from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk from ktransformers.server.config.log import logger from ktransformers.server.schemas.base import Object from ktransformers.server.schemas.assistants.messages import ContentType, ImageFileObject, ImageUrlObject, MessageObject, Text, TextObject class TextObjectWithIndex(TextObject): index: int class ImageFileObjectWithIndex(ImageFileObject): index: int class ImageUrlObjectWithIndex(ImageUrlObject): index: int ContentWithIndex = Union[TextObjectWithIndex, ImageFileObjectWithIndex, ImageUrlObjectWithIndex] class MessageDeltaImpl(BaseModel): # role: Optional[str] content: List[ContentWithIndex] class MessageDelta(Object): delta: MessageDeltaImpl def to_stream_reply(self): return f"event: thread.message.delta\ndata: {self.model_dump_json()}\n\n" def text_delta(index: int, text: str): return MessageDeltaImpl(content=[TextObjectWithIndex(index=index, type=ContentType.text, text=Text(value=text))]) def append_message_delta(self: MessageObject, text: str): if len(self.content) == 0: self.content.append(TextObject(type=ContentType.text, text=Text(value=''), delta_index=0)) text_object: TextObject = self.content[0] if text_object.filter_append(text): return MessageDelta(id=self.id, object="thread.message.delta", delta=text_delta(text_object.delta_index, text)) else: return None MessageObject.append_message_delta = append_message_delta class RunStepDeltaImpl(BaseModel): pass class RunStepDelta(Object): delta: RunStepDeltaImpl def to_stream_reply(self): return f"event: thread.run.step.delta\ndata: {self.model_dump_json()}\n\n" class Done(): def to_stream_reply(self): return f"data: [DONE]\n\n" async def check_client_link(request: Request, async_events: AsyncIterable): async for event in async_events: if await request.is_disconnected(): break yield event async def add_done(async_events: AsyncIterable): async for event in async_events: yield event yield Done() async def to_stream_reply(async_events: AsyncIterable): async for event in async_events: if isinstance(event, str): yield event else: yield event.to_stream_reply() async def filter_api_event(async_events: AsyncIterable): async for event in async_events: if isinstance(event, MessageDelta) or isinstance(event, RunStepDelta) or isinstance(event, RunStreamResponse) or isinstance(event, Done): yield event async def filter_chat_chunk(async_events: AsyncIterable): async for event in async_events: if isinstance(event, ChatCompletionChunk): yield event async def filter_by_types(async_events: AsyncIterable, types: List): async for event in async_events: for type in types: if isinstance(event, type): yield event continue def api_stream_response(request: Request, async_events: AsyncIterable): return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_api_event(async_events)))), media_type="text/event-stream") def chat_stream_response(request: Request, async_events: AsyncIterable): return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_chat_chunk(async_events)))), media_type="text/event-stream") def stream_response(request: Request, async_events: AsyncIterable): return StreamingResponse(check_client_link(request, to_stream_reply(add_done(async_events))), media_type="text/event-stream") def check_link_response(request: Request, async_events: AsyncIterable): return StreamingResponse(check_client_link(request, async_events), media_type="text/event-stream") def wrap_async_generator_into_queue(async_events: AsyncIterable) -> asyncio.Queue: queue = asyncio.Queue() async def inner(): # logger.debug('run inner') async for event in async_events: # logger.debug(f'put: {event}') await queue.put(event) await asyncio.sleep(0) # logger.debug(f'put: None') await queue.put(None) asyncio.create_task(inner()) return queue async def unwrap_async_queue(queue: asyncio.Queue) -> AsyncIterable: while True: events = [await queue.get()] events.extend([queue.get_nowait() for _ in range(queue.qsize())]) logger.debug(f'getting {len(events)} events') for event in events: if event is None: break yield event async def unwrap_async_queue_slow(queue: asyncio.Queue) -> AsyncIterable: while True: event = await queue.get() # logger.debug(f'unwrap_async_queue {event}') if event is None: break yield event ================================================ FILE: archive/ktransformers/server/schemas/assistants/threads.py ================================================ from enum import Enum from typing import List from typing_extensions import Self from pydantic import BaseModel, Field, model_validator from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime from ktransformers.server.schemas.assistants.tool import ToolResource from ktransformers.server.schemas.assistants.messages import MessageCore class ThreadBase(BaseModel): meta_data: Metadata = MetadataField @model_validator(mode='before') @classmethod def convert_meta_data(cls,values): if 'meta_data' in values: values['metadata'] = values['meta_data'] return values tool_resources: List[ToolResource] = Field([], max_length=128) class ThreadObject(ThreadBase, ObjectWithCreatedTime): is_related_threads:bool = Field(False,exclude=True) @model_validator(mode='after') def check_is_related_threads(self)->Self: # logger.debug(f'check thread {self.id} is related thread? by {self}') if 'assistant_id' in self.meta_data: self.is_related_threads = True return self class StreamEvent(Enum): created = 'created' def to_stream_reply(self,event:StreamEvent): return f"event: thread.{event.value}\ndata: {self.model_dump_json()}\n\n" class ThreadCreate(ThreadBase): messages: List[MessageCore] = Field(default=[]) class ThreadModify(ThreadBase): pass # other than OpenAI API ================================================ FILE: archive/ktransformers/server/schemas/assistants/tool.py ================================================ from enum import Enum from typing import List, Optional, Union from pydantic import BaseModel, Field from ktransformers.server.schemas.base import ObjectID class ToolType(str, Enum): CODE_INTERPRETER = "code_interpreter" FILE_SEARCH = "file_search" RELATED_THREADS = "related_threads" FUNCTION = "function" class ToolBase(BaseModel): type: ToolType class CodeInterpreter(ToolBase): pass class FileSearch(ToolBase): pass class RelatedThreads(ToolBase): pass class FuntionTool(ToolBase): description: str name: str parameters: List[str] Tool = Union[CodeInterpreter, FileSearch, RelatedThreads, FuntionTool] class CodeInterpreterResource(BaseModel): file_ids: Optional[List[str]] = Field(default_factory=list, max_length=20) class FileSearchResource(BaseModel): vector_store_ids: Optional[List[str]] = Field(default_factory=list, max_length=1) vector_stores: Optional[List[str]] = Field(default_factory=list, max_length=1) class RelatedThreadsResource(BaseModel): thread_ids: List[ObjectID] = Field(default=[]) ToolResource = Union[CodeInterpreterResource,FileSearchResource,RelatedThreadsResource] ================================================ FILE: archive/ktransformers/server/schemas/base.py ================================================ from enum import Enum from typing import Dict import sqlalchemy from pydantic import BaseModel, ConfigDict, Field TODO = BaseModel ObjectID = str class Object(BaseModel): id: ObjectID object: str model_config = ConfigDict(from_attributes=True) # Pydantic Base Models class ObjectWithCreatedTime(Object): created_at: int class Order(str, Enum): ASC = "asc" DESC = "desc" def to_sqlalchemy_order(self): match self: case Order.ASC: return sqlalchemy.asc case Order.DESC: return sqlalchemy.desc Metadata = Dict[str, str] MetadataField: Metadata = Field({},max_length=16, alias="metadata") class DeleteResponse(Object): deleted: bool = True class OperationResponse(BaseModel): operation: str status: str ================================================ FILE: archive/ktransformers/server/schemas/conversation.py ================================================ from typing import Optional from pydantic import BaseModel from .assistants.assistants import AssistantObject from .assistants.threads import ThreadObject from .assistants.messages import MessageObject class ThreadPreview(BaseModel): assistant: Optional[AssistantObject] = None thread: ThreadObject first_message: Optional[MessageObject] = None ================================================ FILE: archive/ktransformers/server/schemas/endpoints/chat.py ================================================ from typing import List, Optional, Union, Dict, Any from typing_extensions import Literal from enum import Enum from pydantic import BaseModel, Field from ktransformers.server.config.config import Config from ktransformers.server.schemas.base import Object from openai.types.chat.chat_completion_chunk import Choice from uuid import uuid4 class CompletionUsage(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int prompt_tokens_details: Optional[Dict[str, Any]] = None completion_tokens_details: Optional[Dict[str, Any]] = None prefill_time: Optional[float] = None decode_time: Optional[float] = None class Role(Enum): system = 'system' user = 'user' assistant = 'assistant' tool = 'tool' function = 'function' class Message(BaseModel): content: Optional[str] = None role: Role name: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = {} tool_call_id: Optional[str] = None def to_tokenizer_message(self): message = {'role': self.role.value} if self.content is not None: message['content'] = self.content if self.name is not None: message['name'] = self.name if self.tool_calls is not {}: message['tool_calls'] = self.tool_calls if self.tool_call_id is not None: message['tool_call_id'] = self.tool_call_id return message class FunctionParameters(BaseModel): type: str = "object" properties: Dict[str, Any] = {} required: Optional[List[str]] = None class FunctionDefinition(BaseModel): name: str description: Optional[str] = None parameters: FunctionParameters = Field(default_factory=FunctionParameters) class ToolFunction(BaseModel): function: FunctionDefinition class Tool(BaseModel): type: Literal["function"] function: FunctionDefinition class ChatCompletionCreate(BaseModel): messages: List[Message] model: str stream: bool = False temperature: Optional[float] = Field(default=Config().temperature) top_p: Optional[float] = Field(default=Config().top_p) tools: Optional[List[Tool]] = None tool_choice: Optional[Union[str, Dict[str, Any]]] = None stream_options: Optional[Dict[str, Any]] = None frequency_penalty: float = 0 presence_penalty: float = 0 max_tokens: Optional[int] = Field(default=None) max_completion_tokens: Optional[int] = Field(default=None) return_speed: Optional[bool] = Field(default=False) def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] class ChatCompletionChunk(BaseModel): id: str choices: List[Choice] created: int model: str object: Literal["chat.completion.chunk"] service_tier: Optional[Literal["scale", "default"]] = None system_fingerprint: Optional[str] = None usage: Optional[CompletionUsage] = None def to_stream_reply(self): return f"data: {self.model_dump_json()}\n\n" class RawUsage(BaseModel): tokenize_time: float prefill_time: float decode_time: float prefill_count: int decode_count: int ================================================ FILE: archive/ktransformers/server/schemas/legacy/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/schemas/legacy/completions.py ================================================ from typing import List, Optional from enum import Enum from pydantic import BaseModel, Field from ktransformers.server.config.config import Config from ..base import Object class CompletionCreate(BaseModel): model: str prompt: str | List[str] stream: bool = False temperature: Optional[float] = Field(default=Config().temperature) top_p: Optional[float] = Field(default=Config().top_p) max_tokens: Optional[int] = Field(default=None) max_completion_tokens: Optional[int] = Field(default=None) def get_tokenizer_messages(self): if isinstance(self.prompt,List): self.get_tokenizer_messages('\n'.join(self.prompt)) return [{'content':self.prompt,'role':'user'}] class FinishReason(Enum): stop = 'stop' length = 'length' class Choice(BaseModel): index: int text: str logprobs: Optional[str] = None finish_reason: FinishReason = None class CompletionObject(Object): created:int choices: List[Choice] = [] model:str = 'not implmented' system_fingerprint:str = 'not implmented' usage: Optional[str] = None def set_token(self,token:str): if len(self.choices)==0: self.choices.append(Choice(index=0,text='')) self.choices[0].text = token def append_token(self,token:str): if len(self.choices)==0: self.choices.append(Choice(index=0,text='')) self.choices[0].text += token def to_stream_reply(self): return f"data:{self.model_dump_json()}\n\n" ================================================ FILE: archive/ktransformers/server/utils/__init__.py ================================================ ================================================ FILE: archive/ktransformers/server/utils/create_interface.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : qiyuxinlin Date : 2024-07-25 11:50:16 Version : 1.0.0 LastEditors : qiyuxinlin LastEditTime : 2024-07-25 12:54:48 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from ktransformers.server.config.config import Config from ktransformers.server.backend.args import ConfigArgs from ktransformers.server.backend.context_manager import ThreadContextManager from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface from ktransformers.server.backend.interfaces.transformers import TransformersInterface from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface def create_interface(config: Config, default_args: ConfigArgs, input_args=None): if config.backend_type=='transformers': from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface elif config.backend_type == 'exllamav2': from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface elif config.backend_type == 'ktransformers': from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface elif config.backend_type == 'balance_serve': from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface else: raise NotImplementedError(f'{config.backend_type} not implemented') if config.backend_type == 'ktransformers': GlobalInterface.interface = BackendInterface(default_args, input_args) elif config.backend_type == 'balance_serve': GlobalInterface.interface = BackendInterface(default_args, input_args) else: GlobalInterface.interface = BackendInterface(default_args) GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface) class GlobalContextManager: context_manager: ThreadContextManager class GlobalInterface: interface: TransformersInterface | KTransformersInterface | ExllamaInterface def get_thread_context_manager() -> GlobalContextManager: return GlobalContextManager.context_manager def get_interface() -> GlobalInterface: return GlobalInterface.interface ================================================ FILE: archive/ktransformers/server/utils/multi_timer.py ================================================ import time def format_time(seconds): units = [ ("hours", 3600), ("minutes", 60), ("seconds", 1), ("milliseconds", 1e-3), ("microseconds", 1e-6), ] for unit_name, unit_value in units: if seconds >= unit_value: time_value = seconds / unit_value return f"{time_value:.2f} {unit_name}" return "0 seconds" # Handle case for 0 seconds class Profiler: def __init__(self): self.timers = {} self.counters = {} def create_timer(self, name): self.timers[name] = { "start_time": None, "elapsed_time": 0, "running": False, } def start_timer(self, name): if name not in self.timers: raise ValueError(f"Timer '{name}' does not exist.") if self.timers[name]["running"]: raise ValueError(f"Timer '{name}' is already running.") self.timers[name]["start_time"] = time.time() self.timers[name]["running"] = True def pause_timer(self, name): if name not in self.timers: raise ValueError(f"Timer '{name}' does not exist.") if not self.timers[name]["running"]: raise ValueError(f"Timer '{name}' is not running.") self.timers[name]["elapsed_time"] += time.time() - self.timers[name]["start_time"] self.timers[name]["running"] = False def get_timer_sec(self, name): if name not in self.timers: raise ValueError(f"Timer '{name}' does not exist.") if self.timers[name]["running"]: current_time = self.timers[name]["elapsed_time"] + (time.time() - self.timers[name]["start_time"]) else: current_time = self.timers[name]["elapsed_time"] return current_time def get_all_timers(self): all_timers = {} for name in self.timers: all_timers[name] = self.get_timer_sec(name) return all_timers def report_timer_string(self, name): return f"{name} elapsed time: {format_time(self.get_timer_sec(name))}" def create_and_start_timer(self, name): self.create_timer(name) self.start_timer(name) # Counter def inc(self,key:str,delta:int=1): self.counters[key] = self.counters.get(key,0) + delta def set_counter(self,key:str,to=0): self.counters[key] = to def get_counter(self,key:str): return self.counters.get(key,0) ================================================ FILE: archive/ktransformers/server/utils/serve_profiling.py ================================================ import re import itertools import time import enum import math from enum import StrEnum class ProfStatKey(StrEnum): ExpertsSummitCurrLayer = "ExpertsSummitCurrLayer" ExpertsSummitNextLayer = "ExpertsSummitNextLayer" ExpertsCPUForwardOne = "ExpertsCPUForwardOne" ExpertsCPUForwardTwo = "ExpertsCPUForwardTwo" CPUMoEKExpertsCallback = "CPUMoEKExpertsCallback" class ProfTimeStat: def __init__(self): # open_status = os.environ["KT_PERF_STAT"] if "KT_PERF_STAT" in os.environ else "0" # if open_status == "0": # self.on = False # else: # self.on = True self.on = False self.prefill_stats = dict() self.decode_stats = dict() for key in ProfStatKey: self.prefill_stats[key] = ProfStatItem() self.decode_stats[key] = ProfStatItem() self.reset_all() def record_start_time(self): start_time = time.time_ns() return start_time def add_time_stat(self, key: ProfStatKey, time_ns, is_prefill): if not key: return # torch.cuda.synchronize() cost = time.time_ns() - time_ns if is_prefill: item = self.prefill_stats[key] else: item = self.decode_stats[key] item.add_item(cost) def print_all(self): # rank = f"[rank:{torch.distributed.get_rank()}]" rank = f"[rank:0]" msg = f"\n{rank} Prefill Time Stat\n" msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)", ">2ms", ">10ms") for key, value in self.prefill_stats.items(): msg += rank + f" {key.value:<25}:{value.get_stat()}\n" msg += f"\n{rank} Decode Time Stat\n" msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)", ">2ms", ">10ms") for key, value in self.decode_stats.items(): msg += rank + f" {key.value:<25}:{value.get_stat()}\n" print(msg) def reset_all(self): for _, value in self.prefill_stats.items(): value.reset() for _, value in self.decode_stats.items(): value.reset() class ProfStatItem: def __init__(self): self.min_time = 100000000 self.max_time = 0 self.total_time_ns = 0 self.count = 0 self.err_time = [] self.ms_count2 = 0 self.ms_count10 = 0 def add_item(self, cost_time_ns): self.count += 1 self.total_time_ns += cost_time_ns self.min_time = min(self.min_time, cost_time_ns) self.max_time = max(self.max_time, cost_time_ns) if (cost_time_ns > 2000000): # self.err_time.append(round(cost_time_ns / 1000 / 1000, 2)) self.ms_count2 += 1 if (cost_time_ns > 10000000): # self.err_time.append(round(cost_time_ns / 1000 / 1000, 2)) self.ms_count10 += 1 # self.err_time.append(round(cost_time_ns / 1000 / 1000, 2)) def reset(self): self.min_time = 100000000 self.max_time = 0 self.total_time_ns = 0 self.count = 0 def get_stat(self): min_time = self.min_time / 1000 / 1000 max_time = self.max_time / 1000 / 1000 if self.count != 0: avg_time = self.total_time_ns / self.count / 1000 / 1000 else: avg_time = 0 total = self.total_time_ns / 1000 / 1000 # tmpstr = str(self.err_time) # print(f"\r\n err_time: {tmpstr} \r\n ") return f"{min_time:15.2f}{max_time:15.2f}{avg_time:15.2f}{self.count:15}{total:15.2f}{self.ms_count2:>15}{self.ms_count10:>15}" PROF_TIME_STAT = ProfTimeStat() ================================================ FILE: archive/ktransformers/server/utils/sql_utils.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenxl Date : 2024-06-12 09:12:58 Version : 1.0.0 LastEditors : chenxl LastEditTime : 2024-07-27 01:56:04 ''' from urllib.parse import urlparse import os from contextlib import contextmanager from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker, declarative_base from ktransformers.server.config.config import Config from ktransformers.server.config.singleton import Singleton from ktransformers.server.config.log import logger from ktransformers.server.exceptions import db_exception Base = declarative_base() class SQLUtil(metaclass=Singleton): """ database connections init and management """ sqlalchemy_engine = None session_local = None def __init__(self) -> None: self.cfg: Config = Config() if not self.sqlalchemy_engine: SQLUtil.init_engine(self.cfg) @contextmanager def get_db(self): """ After you finish using the session, it's crucial to close it. """ if not SQLUtil.sqlalchemy_engine: SQLUtil.init_engine(self.cfg) session = self.session_local() # type: ignore pylint: disable=not-callable try: yield session finally: session.close() @staticmethod def init_engine(cfg: Config): """ initial engine and session maker Factory """ pool_size = cfg.db_pool_size if SQLUtil.sqlalchemy_engine is None: if cfg.db_type == "sqllite": db_url = SQLUtil.create_sqllite_url(cfg) else: logger.error("Unsupported database type %s", cfg.db_type) exit(-1) SQLUtil.sqlalchemy_engine = create_engine( db_url, connect_args={"check_same_thread": False}, pool_size=pool_size) SQLUtil.session_local = sessionmaker( autocommit=False, autoflush=False, bind=SQLUtil.sqlalchemy_engine) @staticmethod def create_sqllite_url(cfg): """ create and validate SQLLite url """ path: str = cfg.db_host database: str = cfg.db_database absolute_path: str = os.path.join(path, database) url = 'sqlite:///' + absolute_path try: result = urlparse(url) if all([result.scheme, result.path, result.scheme == 'sqlite']): return url else: logger.error("invalid sqllite url: %s", url) exit(-1) except ValueError: logger.error("invalid sqllite url: %s", url) exit(-1) def db_add_commit_refresh(self, session: Session, what): """ add data to database """ try: session.add(what) session.commit() session.refresh(what) except Exception as e: logger.exception("db commit error with data %s", str(what.__dict__)) ex = db_exception() ex.detail = str(e) session.rollback() raise ex from e def db_merge_commit(self, session: Session, what): try: session.merge(what) session.commit() except Exception as e: ex = db_exception() ex.detail = str(e) logger.exception("db merge commit error with data %s", str(what.__dict__)) session.rollback() raise ex from e def db_update_commit_refresh(self, session: Session, existing, what): what = what.model_dump(mode="json") try: for key in what.keys(): if what[key] is not None: # 检查b中的字段是否为None setattr(existing, key, what[key]) # 更新a的字段 session.commit() session.refresh(existing) except Exception as e: ex = db_exception() ex.detail = str(e) logger.exception("db update commit refresh error with data %s", str(what.__dict__)) session.rollback() raise ex from e ================================================ FILE: archive/ktransformers/tests/.gitignore ================================================ results/ ================================================ FILE: archive/ktransformers/tests/AIME_2024/eval_api.py ================================================ # adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file import argparse import json import os import time import requests import tqdm from evaluation import filter_answer from prompts import instruct_prompt import pandas as pd from datasets import load_dataset os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' def generate_text(api_url,question , model_name, stream=False, auth_token=None): headers = { 'accept': 'application/json', 'Content-Type': 'application/json', # 添加 API Key 'Authorization' : 'Bearer ' + auth_token if auth_token else '' } question = instruct_prompt(question) data = { "messages": [{"content": question, "role": "user"}], "model": model_name, "stream": stream, "temperature": 0.6, "max_tokens": 10240, } print(f"content: {question}") response = requests.post(api_url, headers=headers, json=data,verify=False) if response.status_code == 200: result = response.json() results = result.get('choices', [{}])[0].get('message', {}).get('content', '') return filter_answer(results) else: print(f"API Request failed with status code {response.status_code}") return None def load_data(file_path): """ Load data from a Parquet file into a list. Each record in the Parquet file should represent an individual record. """ # 读取 Parquet 文件 # dataset = load_dataset('parquet', data_files=file_path) data = [] ds = load_dataset(file_path) df = pd.DataFrame(ds['train']) for _, row in df.iterrows(): data.append(row.to_dict()) return data def get_score(pred, answer): """ Calculate scores between the prediction and the answer. Uses ROUGE scores as the evaluation metric. :param pred: The predicted string. :param answer: The reference answer string. :return: A dictionary containing ROUGE scores. """ if pred == answer: return 1 # if we need to compare str with number, convert teh str to number try: pred = float(pred) answer = float(answer) except: pass if pred == answer: return 1 return 0 def run_eval_api( api_url: str, model_name: str, out_path: str, format_tabs: bool = False, auth_token: str = None, problem_file: str = None, append: bool = False, skip: int = 0 ): data = load_data(problem_file) pbar = tqdm.tqdm(total=len(data) * 1) pbar.update(skip) for i in range(len(data)): i = i+skip data_item = data[i] question = data_item['Problem'] # Start the timer for this evaluation start_time = time.time() try: completion = generate_text(api_url, question, model_name, auth_token=auth_token) if completion is None: raise Exception(f"Failed to get prediction for {question}") answer = data_item['Answer'] score = get_score(completion, answer) elapsed_time = time.time() - start_time result = { "index": i, "question_id": data_item["ID"], "answer": answer, "prediction": completion, "score": score, "time": elapsed_time } with open(out_path, "a" if append else "w") as f: f.write(json.dumps(result) + "\n") except Exception as e: print(f"Failed to get prediction for {question}") print(e) continue pbar.update(1) def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip): os.makedirs(os.path.dirname(output_path), exist_ok=True) run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip) if __name__ == "__main__": parser = argparse.ArgumentParser(description="API Generate Tester") parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-R1", help="Model Name") parser.add_argument("--out_path", type=str, default="results/api/eval_aime.jsonl", help="Output Path") parser.add_argument("--auth_token", type=str, default=None, help="Auth Token") parser.add_argument("--format_tabs", action="store_true", help="Format Tabs") parser.add_argument("--problem_file", type=str, default="Maxwell-Jia/AIME_2024", help="Evalset File") parser.add_argument("--no_append", action="store_false", help="Append to existing file") parser.add_argument("--skip", type=int, default=0, help="Skip some tasks") args = parser.parse_args() # api_url = "https://api.siliconflow.cn/v1/chat/completions" main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip) ================================================ FILE: archive/ktransformers/tests/AIME_2024/evaluation.py ================================================ # reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35 def filter_answer(completion: str) -> str: # the answer is the last part of the completion, it's a int64 number # get the last line completion = completion.strip().split("\n")[-1] # handle the $\\boxed{...}$ format if "$\\boxed{" in completion: return completion.split("}")[0].split("{")[-1] return completion.split()[-1] ================================================ FILE: archive/ktransformers/tests/AIME_2024/prompts.py ================================================ def instruct_prompt(prompt: str) -> str: return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nSolve the following math problem without any tests or explanation only one answer surrounede by '$\\boxed{{}}$'\n{prompt}\n\n### Response:""" ================================================ FILE: archive/ktransformers/tests/UT/test_kdeepseek_attention_w8a8a2serve_npu.py ================================================ import sys import types import torch import torch.nn as nn import pytest torch_npu = pytest.importorskip("torch_npu") from ktransformers.operators.ascend.ascend_attention import ( KDeepseekV2AttentionW8A8A2Serve, ) import ktransformers.operators.ascend.ascend_attention as attn_mod class DummyConfig: def __init__(self, hidden_size=4, num_attention_heads=1): self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads class DummyOrigAttn(nn.Module): def __init__(self, config=None, layer_idx=0): super().__init__() self.config = config self.layer_idx = layer_idx hidden_dim = config.hidden_size if config is not None else 4 self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False) self.kv_a_proj_with_mqa = None self.kv_a_layernorm = nn.LayerNorm(2) self.o_proj = None class DummyDynamicQuantOps: def execute(self, inputs): x = inputs[0] return [x] class DummyMatMulOps: def execute(self, inputs): x = inputs[0] return [x] class DummyQuantProj(nn.Module): def __init__(self, dim): super().__init__() self.input_scale = torch.tensor(1.0, dtype=torch.float16) self.input_offset = torch.tensor(0.0, dtype=torch.float16) self.weight = nn.Parameter(torch.zeros(dim, dim, dtype=torch.float16)) self.quant_bias = torch.zeros(dim, dtype=torch.float16) self.deq_scale = torch.tensor(1.0, dtype=torch.float16) class DummyStaticCache: def __init__(self, page_size=16): self.page_size = page_size def get_usable_length(self, kv_seq_len, layer_idx): return 0 def update(self, combined, layer_idx, cache_kwargs): return combined, None class DummyNpuFusedAttention: def __call__(self, q, k, v, **kwargs): bsz, max_q_len, num_heads, dim = q.shape out = torch.zeros( bsz, max_q_len, num_heads, dim, dtype=q.dtype, device=q.device ) softmax_lse = torch.zeros(1, dtype=q.dtype, device=q.device) return out, softmax_lse def out(self, q, k, v, workspace=None, query_rope=None, key_rope=None, num_heads=None, num_key_value_heads=None, input_layout=None, scale=None, antiquant_mode=None, antiquant_scale=None, block_table=None, block_size=None, actual_seq_lengths_kv=None, sparse_mode=None, out=None): attn_output, softmax_lse = out attn_output.zero_() softmax_lse.zero_() return attn_output, softmax_lse class DummyOpsNpu: def npu_fused_infer_attention_score(self, q, k, v, **kwargs): bsz, num_heads, q_len, dim = q.shape out = torch.zeros( bsz, num_heads, q_len, dim, dtype=q.dtype, device=q.device ) softmax_lse = torch.zeros(1, dtype=q.dtype, device=q.device) return out, softmax_lse def fake_apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin): return q_pe, k_pe def build_attention_module(q_lora_rank=None): if hasattr(attn_mod, "get_tensor_parallel_size"): attn_mod.get_tensor_parallel_size = lambda: 1 # type: ignore config = DummyConfig(hidden_size=4, num_attention_heads=1) orig = DummyOrigAttn(config=config, layer_idx=0) attn = KDeepseekV2AttentionW8A8A2Serve( key="test", gguf_loader=None, config=config, orig_module=orig, prefill_device="npu", generate_device="npu", ) hidden_dim = 4 num_heads = 1 qk_nope_head_dim = 2 qk_rope_head_dim = 2 q_head_dim = qk_nope_head_dim + qk_rope_head_dim # 4 kv_lora_rank = 2 v_head_dim = 2 attn.num_heads = num_heads attn.q_head_dim = q_head_dim attn.qk_nope_head_dim = qk_nope_head_dim attn.qk_rope_head_dim = qk_rope_head_dim attn.kv_lora_rank = kv_lora_rank attn.v_head_dim = v_head_dim attn.softmax_scale = 1.0 attn.layer_idx = 0 attn.sparse_mode = 0 attn.q_lora_rank = q_lora_rank attn.elewise_quant = DummyDynamicQuantOps() attn.matmulDequant_operation = DummyMatMulOps() attn.matmulDequant_operation_aclnn = DummyMatMulOps() orig_mod = attn.orig_module if q_lora_rank is None: orig_mod.q_proj = nn.Linear(hidden_dim, num_heads * q_head_dim, bias=False) orig_mod.q_proj = orig_mod.q_proj.to(dtype=torch.float16) else: orig_mod.q_a_proj = DummyQuantProj(hidden_dim) orig_mod.q_b_proj = DummyQuantProj(hidden_dim) orig_mod.q_a_layernorm = nn.LayerNorm(hidden_dim) orig_mod.kv_a_proj_with_mqa = DummyQuantProj(hidden_dim) orig_mod.kv_a_layernorm = nn.LayerNorm(kv_lora_rank) orig_mod.o_proj = DummyQuantProj(num_heads * v_head_dim) attn.q_absorb = torch.randn( num_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.float16 ) attn.out_absorb = torch.randn( num_heads, kv_lora_rank, v_head_dim, dtype=torch.float16 ) def fake_rotary_emb(q_pe, position_ids): bsz, n_heads, q_len, dim = q_pe.shape cos = torch.ones(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device) sin = torch.zeros(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device) return cos, sin attn.rotary_emb = fake_rotary_emb return attn @pytest.fixture(autouse=True) def _patch_env(monkeypatch): if hasattr(attn_mod, "apply_rotary_pos_emb_fusion"): monkeypatch.setattr( attn_mod, "apply_rotary_pos_emb_fusion", fake_apply_rotary_pos_emb_fusion ) if hasattr(attn_mod, "get_use_npu_graph"): monkeypatch.setattr(attn_mod, "get_use_npu_graph", lambda: False) if hasattr(attn_mod, "get_tensor_parallel_size"): monkeypatch.setattr(attn_mod, "get_tensor_parallel_size", lambda: 1) if hasattr(attn_mod, "get_tensor_parallel_group"): monkeypatch.setattr(attn_mod, "get_tensor_parallel_group", lambda: None) if hasattr(attn_mod, "get_current_device"): monkeypatch.setattr(attn_mod, "get_current_device", lambda: "cpu") # torch.distributed.barrier -> no-op if hasattr(torch, "distributed") and hasattr(torch.distributed, "barrier"): monkeypatch.setattr( torch.distributed, "barrier", lambda *args, **kwargs: None, raising=False, ) dummy_op = DummyNpuFusedAttention() monkeypatch.setattr( torch_npu, "npu_fused_infer_attention_score", dummy_op, raising=False ) def fake_get_workspace(q, k, v, **kwargs): return torch.empty(1, dtype=q.dtype, device=q.device) monkeypatch.setattr( torch_npu, "_npu_fused_infer_attention_score_get_max_workspace", fake_get_workspace, raising=False ) monkeypatch.setattr(torch.ops, "npu", DummyOpsNpu(), raising=False) yield # ========================== # 测试用例 # ========================== def test_print_callback_smoke(): attn = build_attention_module() bsz, q_len, hidden_dim = 1, 3, 4 hidden_states = torch.randn(bsz, q_len, hidden_dim) position_ids = torch.arange(q_len).unsqueeze(0) cache_position = torch.arange(q_len).unsqueeze(0) page_idx = torch.zeros(bsz, dtype=torch.int32) page_offset = torch.zeros(bsz, dtype=torch.int32) block_table = torch.zeros(bsz, 1, dtype=torch.int32) attn.print_callback( (hidden_states, position_ids, cache_position, page_idx, page_offset, block_table) ) def _common_inputs_prefill(): bsz, q_len, hidden_dim = 1, 3, 4 hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16) attention_mask = torch.zeros(bsz, 1, q_len, q_len, dtype=torch.float32) position_ids = torch.arange(q_len).unsqueeze(0) cache_position = torch.arange(q_len).unsqueeze(0) page_idx = torch.zeros(bsz, dtype=torch.int32) page_offset = torch.zeros(bsz, dtype=torch.int32) block_table = torch.zeros(bsz, 1, dtype=torch.int32) past_key_value = DummyStaticCache(page_size=16) q_len_raw = torch.tensor([q_len], dtype=torch.int32) kv_len_raw = torch.tensor([q_len], dtype=torch.int32) return ( hidden_states, attention_mask, position_ids, cache_position, page_idx, page_offset, block_table, past_key_value, q_len_raw, kv_len_raw ) def test_forward_prefill_with_mask(): """ is_prefill=True + attention_mask 不为 None + past_key_value 不为 None """ attn = build_attention_module(q_lora_rank=None) (hidden_states, attention_mask, position_ids, cache_position, page_idx, page_offset, block_table, past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill() outputs = attn.forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=True, page_idx=page_idx, page_offset=page_offset, block_table=block_table, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) attn_output, attn_weights, new_cache = outputs assert attn_output.shape == ( 1, # bsz 3, # q_len attn.num_heads * attn.v_head_dim, ) assert attn_weights is None assert new_cache is past_key_value def test_forward_prefill_without_mask_and_q_lora(): """ is_prefill=True + attention_mask=None + q_lora_rank 非 None 分支 """ attn = build_attention_module(q_lora_rank=1) (hidden_states, attention_mask, position_ids, cache_position, page_idx, page_offset, block_table, past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill() outputs = attn.forward( hidden_states=hidden_states, attention_mask=None, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=True, page_idx=None, page_offset=None, block_table=None, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) attn_output, attn_weights, new_cache = outputs assert attn_output.shape == ( 1, 3, attn.num_heads * attn.v_head_dim, ) assert attn_weights is None assert new_cache is past_key_value def test_forward_decode_paged_path(): """ is_prefill=False + get_use_npu_graph=False => 走 forward_paged + torch.ops.npu.npu_fused_infer_attention_score 分支 """ attn = build_attention_module(q_lora_rank=None) bsz, q_len, hidden_dim = 1, 1, 4 hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16) position_ids = torch.arange(q_len).unsqueeze(0) cache_position = torch.arange(q_len).unsqueeze(0) past_key_value = DummyStaticCache(page_size=16) q_len_raw = torch.tensor([q_len], dtype=torch.int32) kv_len_raw = torch.tensor([q_len], dtype=torch.int32) block_table = torch.zeros(bsz, 1, dtype=torch.int32) outputs = attn.forward( hidden_states=hidden_states, attention_mask=None, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=False, page_idx=None, page_offset=None, block_table=block_table, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) attn_output, attn_weights, new_cache = outputs assert attn_output.shape == ( bsz, q_len, attn.num_heads * attn.v_head_dim, ) assert attn_weights is None assert new_cache is past_key_value def test_forward_prefill_layer_idx_none_raises(): """ 覆盖: past_key_value 不为 None 且 layer_idx 为 None 的异常分支。 """ attn = build_attention_module(q_lora_rank=None) attn.layer_idx = None # 手动破坏 layer_idx (hidden_states, attention_mask, position_ids, cache_position, page_idx, page_offset, block_table, past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill() with pytest.raises(ValueError): attn.forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=True, page_idx=page_idx, page_offset=page_offset, block_table=block_table, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) def test_forward_prefill_attn_output_shape_mismatch_raises(monkeypatch): """ 覆盖: attn_output 形状不符合期望时的 ValueError 分支。 """ attn = build_attention_module(q_lora_rank=None) def bad_fused(q, k, v, **kwargs): bsz, max_q_len, num_heads, dim = q.shape # 刻意制造 num_heads+1,触发 size 检查不通过 out = torch.zeros( bsz, max_q_len, num_heads + 1, attn.v_head_dim, dtype=q.dtype, device=q.device ) lse = torch.zeros(1, dtype=q.dtype, device=q.device) return out, lse monkeypatch.setattr( torch_npu, "npu_fused_infer_attention_score", bad_fused, raising=False ) (hidden_states, attention_mask, position_ids, cache_position, page_idx, page_offset, block_table, past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill() with pytest.raises(ValueError): attn.forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=True, page_idx=page_idx, page_offset=page_offset, block_table=block_table, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) def test_forward_paged_use_npu_graph(monkeypatch): """ 覆盖: get_use_npu_graph() == True 的 graph 路径。 """ # 让 ascend_attention.get_use_npu_graph 返回 True monkeypatch.setattr(attn_mod, "get_use_npu_graph", lambda: True) # 伪造 model_runner 模块,满足 import ktransformers.server.balance_serve.inference.model_runner dummy_runner = type( "DummyRunner", (), {"__init__": lambda self: setattr(self, "workspace", [None] * 4)} ) dummy_mr = types.SimpleNamespace( ModelRunner=dummy_runner, get_or_create_model_runner=lambda device=None: dummy_runner(), ) sys.modules[ "ktransformers.server.balance_serve.inference.model_runner" ] = dummy_mr attn = build_attention_module(q_lora_rank=None) bsz, q_len, hidden_dim = 1, 1, 4 hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16) position_ids = torch.arange(q_len).unsqueeze(0) cache_position = torch.arange(q_len).unsqueeze(0) past_key_value = DummyStaticCache(page_size=16) q_len_raw = torch.tensor([q_len], dtype=torch.int32) kv_len_raw = torch.tensor([q_len], dtype=torch.int32) block_table = torch.zeros(bsz, 1, dtype=torch.int32) outputs = attn.forward( hidden_states=hidden_states, attention_mask=None, position_ids=position_ids, past_key_value=past_key_value, output_attentions=False, use_cache=True, cache_position=cache_position, is_prefill=False, page_idx=None, page_offset=None, block_table=block_table, q_len_raw=q_len_raw, kv_len_raw=kv_len_raw, stream=None, ) attn_output, attn_weights, new_cache = outputs assert attn_output.shape == ( bsz, q_len, attn.num_heads * attn.v_head_dim, ) assert attn_weights is None assert new_cache is past_key_value ================================================ FILE: archive/ktransformers/tests/UT/test_kdeepseek_ln_npu.py ================================================ import torch import torch.nn as nn import pytest # 按你实际代码位置改路径: from ktransformers.operators.ascend.ascend_layernorm import KDeepseekV3RMSNormW8A8 import ktransformers.util.utils as utils_mod torch_npu = pytest.importorskip("torch_npu") # ========================== # Dummy 依赖 # ========================== class DummyOrigModule(nn.Module): def __init__(self, hidden_size=4, variance_epsilon=1e-5): super().__init__() self.hidden_size = hidden_size self.variance_epsilon = variance_epsilon class DummySafeTensorLoader: def __init__(self): self.tensors = {} self.load_calls = [] def load_tensor(self, name: str): self.load_calls.append(name) return self.tensors[name] class DummyGGUFLoader: def __init__(self, safetensor_loader: DummySafeTensorLoader): self.safetensor_loader = safetensor_loader class DummyConfig: pass class FakeRMSNorm: def __init__(self): self.last_args = None def __call__(self, hidden_states, weight, eps): self.last_args = (hidden_states, weight, eps) out = hidden_states * weight return (out,) def build_rms_module(hidden_size=4, eps=1e-5, safetensor_loader=None): orig = DummyOrigModule(hidden_size=hidden_size, variance_epsilon=eps) if safetensor_loader is None: safetensor_loader = DummySafeTensorLoader() gguf_loader = DummyGGUFLoader(safetensor_loader) config = DummyConfig() module = KDeepseekV3RMSNormW8A8( key="rms", gguf_loader=gguf_loader, config=config, orig_module=orig, prefill_device="npu", generate_device="npu", ) return module, safetensor_loader, orig @pytest.fixture(autouse=True) def patch_utils_and_npu(monkeypatch): monkeypatch.setattr(utils_mod, "get_current_device", lambda: "cpu", raising=False) fake = FakeRMSNorm() monkeypatch.setattr(torch_npu, "npu_rms_norm", fake, raising=False) import sys sys.modules[__name__]._fake_rms = fake yield def get_fake_rms(): import sys return sys.modules[__name__]._fake_rms def test_forward_preserves_shape_and_dtype(): hidden_size = 4 module, _, orig = build_rms_module(hidden_size=hidden_size, eps=1e-6) x = torch.randn(2, 3, hidden_size, dtype=torch.float16) out = module(x) assert out.shape == x.shape assert out.dtype == x.dtype fake_rms = get_fake_rms() hs_arg, w_arg, eps_arg = fake_rms.last_args assert hs_arg is x assert w_arg is module.weight assert eps_arg == orig.variance_epsilon def test_forward_with_bfloat16_dtype(): hidden_size = 4 module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6) x = torch.randn(1, 2, hidden_size, dtype=torch.bfloat16) out = module(x) assert out.shape == x.shape assert out.dtype == torch.bfloat16 def test_forward_uses_bias(): hidden_size = 4 module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6) module.weight.data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) module.bias.data = torch.tensor([-1.0, 0.5, 0.0, 2.0], dtype=torch.float32) x = torch.arange(2 * 3 * hidden_size, dtype=torch.float16).view(2, 3, hidden_size) out = module(x) expected_rms = x.to(torch.float32) * module.weight expected = expected_rms + module.bias assert torch.allclose(out, expected.to(out.dtype)) def test_load_from_safetensor_loader(): hidden_size = 4 module, safe_loader, _ = build_rms_module(hidden_size=hidden_size, eps=1e-5) w_loaded = torch.arange(hidden_size, dtype=torch.float32) b_loaded = torch.full((hidden_size,), 3.0, dtype=torch.float32) safe_loader.tensors["rms.weight"] = w_loaded safe_loader.tensors["rms.bias"] = b_loaded module.load() assert torch.allclose(module.weight, w_loaded) assert torch.allclose(module.bias, b_loaded) assert safe_loader.load_calls == ["rms.weight", "rms.bias"] def test_unload_sets_weight_and_bias_to_none_idempotent(): module, _, _ = build_rms_module(hidden_size=4, eps=1e-5) assert module.weight is not None assert module.bias is not None module.unload() assert module.weight is None assert module.bias is None module.unload() assert module.weight is None assert module.bias is None ================================================ FILE: archive/ktransformers/tests/dequant_gpu.py ================================================ import os # os.environ["CUDA_VISIBLE_DEVICES"]="1,2" # add path import sys current_path = os.path.abspath(os.path.dirname(__file__)) sys.path.append(current_path+"/../..") import numpy as np # from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin # from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch from ktransformers.util.custom_gguf import GGUFLoader import torch import KTransformersOps torch.set_default_dtype(torch.bfloat16) import time from transformers import ( AutoConfig, ) import os # CUDA_LAUNCH_BLOCKING=1 os.environ["CUDA_LAUNCH_BLOCKING"]="1" gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m") model_name = "/data/Qwen2-57B-A14B-Instruct" # Q4k key = "blk.1." target = "attn_q.weight" t1 = time.time() q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") # q_weight_cpu = torch.from_numpy(q_weight_cpu) t2 = time.time() q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0") t3 = time.time() print() allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6) print(f"Q4k {key+target}") print("load gguf tensor from cpu cost: ", t2-t1) print("load gguf tensor from gpu cost: ", t3-t2) print("allclose: ", allclose) # Q6k key = "blk.0." target = "ffn_down_exps.weight" t1 = time.time() q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu") t2 = time.time() q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda:0") t3 = time.time() print() allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6) print(f"Q6k {key+target}") print("load gguf tensor from cpu cost: ", t2-t1) print("load gguf tensor from gpu cost: ", t3-t2) print("allclose: ", allclose) ================================================ FILE: archive/ktransformers/tests/dequant_gpu_t.py ================================================ import os os.environ["CUDA_VISIBLE_DEVICES"]="1" # add path import sys sys.path.append("../..") import pycuda.autoinit import pycuda.driver as cuda from pycuda.compiler import SourceModule import numpy as np from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch from ktransformers.util.custom_loader import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k import torch import KTransformersOps torch.set_default_dtype(torch.bfloat16) import time from transformers import ( AutoConfig, ) gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m") model_name = "/data/Qwen2-57B-A14B-Instruct" key = "blk.0." target = "ffn_up_exps.weight" data = gguf_config.get_mmap_tensor(key + target) _, factors, offsets, qs1, qs2= dequantize_q4_k(data) factors_cpu = torch.from_numpy(factors) offsets_cpu = torch.from_numpy(offsets) qs1_cpu = torch.from_numpy(qs1) qs2_cpu = torch.from_numpy(qs2) _, factors, offsets, qs1, qs2 = dequantize_q4_k_gpu(data) print(torch.allclose(factors.cpu(), factors_cpu)) print(torch.allclose(offsets.cpu(), offsets_cpu)) print(torch.allclose(qs1.cpu(), qs1_cpu)) print(torch.allclose(qs2.cpu(), qs2_cpu)) ================================================ FILE: archive/ktransformers/tests/function_call_test.py ================================================ from openai import OpenAI def send_messages(messages): response = client.chat.completions.create( model="deepseek-chat", messages=messages, tools=tools ) return response.choices[0].message client = OpenAI( api_key="placeholder", base_url="http://0.0.0.0:10002/v1", ) tools = [ { "type": "function", "function": { "name": "get_weather", "description": "Get weather of an location, the user shoud supply a location first", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA", } }, "required": ["location"] }, } }, ] messages = [{"role": "user", "content": "How's the weather in Hangzhou?"}] message = send_messages(messages) print(f"User>\t {messages[0]['content']}") print(message) tool = message.tool_calls[0] messages.append(message) messages.append({"role": "tool", "tool_call_id": tool.id, "content": "24℃"}) message = send_messages(messages) print(f"Model>\t {message.content}") ================================================ FILE: archive/ktransformers/tests/humaneval/eval_api.py ================================================ # adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file import argparse import os import requests from human_eval.data import write_jsonl, read_problems import tqdm from evaluation import filter_code, fix_indents from prompts import instruct_prompt def generate_text(api_url,question , model_name, stream=False, auth_token=None): headers = { 'accept': 'application/json', 'Content-Type': 'application/json', # 添加 API Key 'Authorization' : 'Bearer ' + auth_token if auth_token else '' } question = instruct_prompt(question) data = { "messages": [{"content": question, "role": "user"}], "model": model_name, "stream": stream, "temperature": 0.6 } print(f"content: {question}") response = requests.post(api_url, headers=headers, json=data,verify=False) if response.status_code == 200: result = response.json() results = result.get('choices', [{}])[0].get('message', {}).get('content', '') return [filter_code(fix_indents(results))] else: print(f"API Request failed with status code {response.status_code}") return None def run_eval_api( api_url: str, model_name: str, out_path: str, format_tabs: bool = False, auth_token: str = None, problem_file: str = None, append: bool = False, skip: int = 0 ): if(problem_file is None): problems = read_problems() else: problems = read_problems(problem_file) samples = [] pbar = tqdm.tqdm(total=len(problems) * 1) pbar.update(skip) try: for task_id in problems: # skip some tasks if skip > 0: skip -= 1 continue if format_tabs: prompt = problems[task_id]["prompt"].replace(" ", "\t") else: prompt = problems[task_id]["prompt"] completion = generate_text(api_url, prompt, model_name, auth_token=auth_token) # samples.append({"task_id": task_id, "completion": completion}) for sample in completion: result = dict( task_id=task_id, completion=sample, ) samples += [result] if append: write_jsonl(out_path, [result],append=append) pbar.update(1) if not append: write_jsonl(out_path, samples,append=append) except Exception as e: if not append: write_jsonl(out_path, samples,append=append) print(f"Error: {e}") def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip): os.makedirs(os.path.dirname(output_path), exist_ok=True) run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip) if __name__ == "__main__": parser = argparse.ArgumentParser(description="API Generate Tester") #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name") parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path") parser.add_argument("--auth_token", type=str, default=None, help="Auth Token") parser.add_argument("--format_tabs", action="store_true", help="Format Tabs") parser.add_argument("--problem_file", type=str, default=None, help="Evalset File") parser.add_argument("--no_append", action="store_false", help="Append to existing file") parser.add_argument("--skip", type=int, default=0, help="Skip first n problems") args = parser.parse_args() # api_url = "https://api.siliconflow.cn/v1/chat/completions" main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip) ================================================ FILE: archive/ktransformers/tests/humaneval/evaluation.py ================================================ # reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35 def filter_code(completion: str) -> str: # The program tends to overwrite, we only take the first function completion = completion.lstrip("\n") # we also remove ```python\n and ``` completion = completion.replace("```python\n", "").replace("```", "") if 'if __name__ == "__main__":' in completion: completion = completion.split('if __name__ == "__main__":')[0] if "# Example usage" in completion: completion = completion.split("# Example usage")[0] return completion def fix_indents(text: str) -> str: return text.replace("\t", " ") ================================================ FILE: archive/ktransformers/tests/humaneval/prompts.py ================================================ def instruct_prompt(prompt: str) -> str: return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nComplete the following Python code without any tests or explanation\n{prompt}\n\n### Response:""" def standard_prompt(prompt: str) -> str: return f"""Complete the following Python code without any tests or explanation\n{prompt}""" def write_prompt(prompt: str) -> str: return f"""Write a python program to complete the following code:\n{prompt}""" def replit_glaive_prompt(prompt: str) -> str: return f"""Below is an instruction that describes a task, paired with an input that provides further context.\n Write a response that appropriately completes the request.\n\n ### Instruction:\nWrite a program to perform the given task.\n\n Input:\n{prompt}\n\n### Response:""" ================================================ FILE: archive/ktransformers/tests/mmlu_pro_test.py ================================================ import argparse import random import time import json import requests import pandas as pd from datasets import load_dataset import os os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' os.environ['https_proxy'] = '' os.environ['http_proxy'] = '' hint = 'There is a single choice question. Answer the question by replying A, B, C, D, E, F, G, H, I, J. No other answers are accepted. Just the letter.' class DataEvaluator: def __init__(self): # self.template_prompt = template_prompt self.data = [] def load_data(self, file_path): """ Load data from a Parquet file into a list. Each record in the Parquet file should represent an individual record. """ # 读取 Parquet 文件 # dataset = load_dataset('parquet', data_files=file_path) ds = load_dataset("TIGER-Lab/MMLU-Pro") df = pd.DataFrame(ds['test']) # print(ds) # # ds_1 = ds['train'] # ds_2 = ds['validation'] # ds_3 = ds['test'] # # 将数据集转换为 Pandas DataFrame # df_test = pd.DataFrame(ds['test']) # df_val = pd.DataFrame(ds['validation']) # for _, row in df.iterrows(): # self.data.append(row.to_dict()) # df = pd.read_parquet(file_path) for _, row in df.iterrows(): self.data.append(row.to_dict()) def get_prompt(self, record): """ Combine fields from a record with the template prompt to create a full prompt. :param record: Dictionary containing fields to populate the template. :return: A formatted prompt string. """ # 查看ABCD。。。的选项 options_str = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(record['options'])]) prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '" return prompt def post_processing(self, text): """ Perform post-processing on the prediction string. :param text: The raw prediction string. :return: Processed prediction string. """ text = text.lstrip('\n').split('\n')[-1] return text[-1:] def score(self, pred, answers): """ Calculate scores between the prediction and the answer. Uses ROUGE scores as the evaluation metric. :param pred: The predicted string. :param answer: The reference answer string. :return: A dictionary containing ROUGE scores. """ for answer in answers: if pred == answer: return 1 return 0 # Function to generate text using API def generate_text(api_url, question, model_name, stream=False): headers = { 'accept': 'application/json', 'Content-Type': 'application/json', # 添加 API Key 'Authorization' : 'Bearer ' } data = { "messages": [{"content": question, "role": "user"}], "model": model_name, "stream": stream, # "temperature": 0.0 } print("POST data:", data) response = requests.post(api_url, headers=headers, json=data) if response.status_code == 200: result = response.json() return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip() else: print(f"API Request failed with status code {response.status_code}") return None # Main function to handle multiple evaluations def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name): start_total_time = time.time() total_score = 0 results = [] # 设置随机数种子 random.seed(42) random.shuffle(data_evaluator.data) for i in range(min(concurrent_requests, len(data_evaluator.data))): # Randomly select a data item from data for each request data_item = data_evaluator.data[i] question = data_evaluator.get_prompt(data_item) # print(question) # Start the timer for this evaluation start_time = time.time() try: # Generate prediction using the API prediction = generate_text(api_url, question, model_name) if prediction is None: raise Exception(f"Failed to get prediction for {question}") answer = data_item['answer'] # Compute score score = data_evaluator.score(data_evaluator.post_processing(prediction), answer) # Calculate the time taken elapsed_time = time.time() - start_time # Collect the result data result_data = { "question_id": data_item['question_id'], "answer": answer, "prediction": data_evaluator.post_processing(prediction), "score": score, "time": elapsed_time } # Write results to result.json with each field on a new line with open(result_file, 'a', encoding='utf-8') as f: json.dump(result_data, f, ensure_ascii=False, indent=4) f.write("\n") # Ensure each JSON object is on a new line results.append(result_data) # Aggregate scores total_score += score except Exception as e: print(f"Error processing request {i}: {e}") # Calculate total time and throughput total_time = time.time() - start_total_time throughput = concurrent_requests / total_time # Log the total time, throughput, and average ROUGE scores with open(log_file, 'a', encoding='utf-8') as log_f: log_f.write(f"Total Time: {total_time:.2f} seconds\n") log_f.write(f"Throughput: {throughput:.2f} requests per second\n") log_f.write(f"Average Scores: {total_score / concurrent_requests}\n") log_f.write('-' * 40 + '\n') print(f"Results saved to {result_file}") print(f"Log saved to {log_file}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="API Generate Tester") parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations") parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file") parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file") parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file") parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path") parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL") # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") args = parser.parse_args() # Load the data from the provided file # template_prompt = hint + "\nQuestion: {question}\nA. {options}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '" # template_prompt_pro = hint + "\nQuestion: {question}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nE. {options[4]}\nF. {options[5]}\nG. \ # {options[6]}\nH. {options[7]}\nI. {options[8]}\nJ. {options[9]}\nAnswer: '" # Load the data from the provided file data_evaluator = DataEvaluator() data_evaluator.load_data(args.file) # Run the main function with the specified number of concurrent evaluations main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) ================================================ FILE: archive/ktransformers/tests/mmlu_test.py ================================================ import argparse import random import time import json import requests import pandas as pd from datasets import load_dataset import os os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' os.environ['https_proxy'] = '' os.environ['http_proxy'] = '' hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.' class DataEvaluator: def __init__(self): # self.template_prompt = template_prompt self.data = [] def load_data(self, file_path): """ Load data from a Parquet file into a list. Each record in the Parquet file should represent an individual record. """ # 读取 Parquet 文件 # dataset = load_dataset('parquet', data_files=file_path) splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet', 'dev': 'all/dev-00000-of-00001.parquet', 'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'} df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"]) for _, row in df.iterrows(): self.data.append(row.to_dict()) def get_prompt(self, record): """ Combine fields from a record with the template prompt to create a full prompt. :param record: Dictionary containing fields to populate the template. :return: A formatted prompt string. """ # 查看ABCD。。。的选项 options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])]) prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '" return prompt def post_processing(self, text): """ Perform post-processing on the prediction string. :param text: The raw prediction string. :return: Processed prediction string. """ text = text.lstrip('\n').split('\n')[-1] return text[-1:] def score(self, pred, answers): """ Calculate scores between the prediction and the answer. Uses ROUGE scores as the evaluation metric. :param pred: The predicted string. :param answer: The reference answer string. :return: A dictionary containing ROUGE scores. """ for answer in answers: if pred == answer: return 1 return 0 # Function to generate text using API def generate_text(api_url, question, model_name, stream=False): headers = { 'accept': 'application/json', 'Content-Type': 'application/json', # 添加 API Key 'Authorization' : 'Bearer ' } data = { "messages": [{"content": question, "role": "user"}], "model": model_name, "stream": stream, # "temperature": 0.0 } print("POST data:", data) response = requests.post(api_url, headers=headers, json=data) if response.status_code == 200: result = response.json() return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip() else: print(f"API Request failed with status code {response.status_code}") return None # Main function to handle multiple evaluations def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name): start_total_time = time.time() total_score = 0 results = [] # 设置随机数种子 random.seed(42) random.shuffle(data_evaluator.data) for i in range(min(concurrent_requests, len(data_evaluator.data))): # Randomly select a data item from data for each request data_item = data_evaluator.data[i] question = data_evaluator.get_prompt(data_item) # print(question) # Start the timer for this evaluation start_time = time.time() try: # Generate prediction using the API prediction = generate_text(api_url, question, model_name) if prediction is None: raise Exception(f"Failed to get prediction for {question}") answer = chr(data_item['answer'] + 65) # Compute score score = data_evaluator.score(data_evaluator.post_processing(prediction), answer) # Calculate the time taken elapsed_time = time.time() - start_time # Collect the result data result_data = { "question_id": i, "answer": answer, "prediction": data_evaluator.post_processing(prediction), "score": score, "time": elapsed_time } # Write results to result.json with each field on a new line with open(result_file, 'a', encoding='utf-8') as f: json.dump(result_data, f, ensure_ascii=False, indent=4) f.write("\n") # Ensure each JSON object is on a new line results.append(result_data) # Aggregate scores total_score += score except Exception as e: print(f"Error processing request {i}: {e}") # Calculate total time and throughput total_time = time.time() - start_total_time throughput = concurrent_requests / total_time # Log the total time, throughput, and average ROUGE scores with open(log_file, 'a', encoding='utf-8') as log_f: log_f.write(f"Total Time: {total_time:.2f} seconds\n") log_f.write(f"Throughput: {throughput:.2f} requests per second\n") log_f.write(f"Average Scores: {total_score / concurrent_requests}\n") log_f.write('-' * 40 + '\n') print(f"Results saved to {result_file}") print(f"Log saved to {log_file}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="API Generate Tester") parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations") parser.add_argument("--file", type=str, default="cais/mmlu", help="Path to the mmlu.jsonl file") parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="Path to save the result JSON file") parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="Path to save the log file") parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path") parser.add_argument("--api_url", type=str, default="http://localhost:10003/v1/chat/completions", help="API URL") # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL") args = parser.parse_args() # Load the data from the provided file # template_prompt = hint + "\nQuestion: {question}\nA. {options}\nB. {option_b}\nC. {option_c}\nD. {option_d}\nAnswer: '" # template_prompt_pro = hint + "\nQuestion: {question}\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nE. {options[4]}\nF. {options[5]}\nG. \ # {options[6]}\nH. {options[7]}\nI. {options[8]}\nJ. {options[9]}\nAnswer: '" # Load the data from the provided file data_evaluator = DataEvaluator() data_evaluator.load_data(args.file) # Run the main function with the specified number of concurrent evaluations main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) ================================================ FILE: archive/ktransformers/tests/mmlu_test_multi.py ================================================ import argparse import random import time import json import requests import pandas as pd from datasets import load_dataset import os import concurrent.futures import threading import re os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' os.environ['https_proxy'] = '' os.environ['http_proxy'] = '' hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.' def extract_final_answer(text): """ 提取模型预测的最终选项(如 A/B/C/D) 支持自然语言、多行、markdown、高亮、非末尾结论等格式 """ text = text.strip() # 1. 显式语句匹配(优先) explicit_patterns = [ r'Answer:\s*([A-D])\b', r'Correct answer:\s*([A-D])\b', r'The correct answer is\s*\*?\*?\s*([A-D])\b', r'Answer is\s*([A-D])\b', r'Therefore,\s*answer is\s*([A-D])\b', r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b', r'The answer should be\s*(?:Option\s*)?([A-D])\b', r'Option\s+([A-D])\s+is correct', ] for pat in explicit_patterns: match = re.search(pat, text, re.IGNORECASE) if match: return match.group(1).upper() # 2. markdown 强调 **C**, **C. something** markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text) if markdown_match: return markdown_match[-1].upper() # 3. 查找单引号中的 'C' 或 "C" quote_match = re.findall(r"['\"]([A-D])['\"]", text) if quote_match: return quote_match[-1].upper() # 4. 倒数几行是否以 "C." 或 "C" 开头 lines = text.splitlines() for line in reversed(lines[-5:]): line = line.strip() match = re.match(r'^([A-D])([.\s]|$)', line) if match: return match.group(1).upper() # 再不行就返回 None return None class DataEvaluator: def __init__(self): self.data = [] def load_data(self, file_path): """ 从数据文件中加载数据,每条记录对应一个实例 """ splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet', 'dev': 'all/dev-00000-of-00001.parquet', 'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'} df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"]) for _, row in df.iterrows(): self.data.append(row.to_dict()) def get_prompt(self, record): """ 结合提示信息和记录数据生成完整的题目 """ options_str = "\n".join([f"{chr(65 + i)}. {opt}" for i, opt in enumerate(record['choices'])]) prompt = hint + "\nQuestion: " + record['question'] + "\n" + options_str + "\nAnswer: '" return prompt def post_processing(self, text): """ 对生成的文本进行后处理,提取最终答案(只返回最后一个字符) """ text = text.lstrip('\n').split('\n')[-1] return text[-1:] def score(self, pred, answer): """ 对比预测答案和正确答案,返回得分 """ if pred == answer: return 1 return 0 def generate_text(api_url, question, model_name, stream=False): headers = { 'accept': 'application/json', 'Content-Type': 'application/json', 'Authorization': 'Bearer ' # 如有需要,请填入 API Key } data = { "messages": [{"content": question, "role": "user"}], "model": model_name, "stream": stream, } print("POST data:", data) response = requests.post(api_url, headers=headers, json=data, timeout=5000000) if response.status_code == 200: result = response.json() return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip() else: print(f"API Request failed with status code {response.status_code}") return None def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name): start_total_time = time.time() total_score = 0 total_exact_score = 0 results = [] file_lock = threading.Lock() # 打乱数据顺序,并选择需要测试的实例数 random.seed(42) random.shuffle(data_evaluator.data) data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))] batch_size = 10 # 每批次最多 10 个实例 def worker(index, data_item): nonlocal total_score nonlocal total_exact_score question = data_evaluator.get_prompt(data_item) start_time = time.time() try: prediction = generate_text(api_url, question, model_name) if prediction is None: raise Exception(f"Failed to get prediction for question: {question}") # 正确答案:将数字转换成字母(0->A, 1->B, 2->C, 3->D) answer = chr(data_item['answer'] + 65) processed_prediction = data_evaluator.post_processing(prediction) score = data_evaluator.score(processed_prediction, answer) exact_score = data_evaluator.score(extract_final_answer(prediction), answer) elapsed_time = time.time() - start_time result_data = { "question_id": index, "answer": answer, "prediction": processed_prediction, "full_prediction": prediction, "score": score, "exact_score": exact_score, "time": elapsed_time } # 写入结果时加锁保证线程安全 with file_lock: with open(result_file, 'a', encoding='utf-8') as f: json.dump(result_data, f, ensure_ascii=False, indent=4) f.write("\n") return result_data except Exception as e: print(f"Error processing request {index}: {e}") return None # 按批次处理,每批最多 10 个任务 for batch_start in range(0, len(data_subset), batch_size): batch = data_subset[batch_start: batch_start + batch_size] with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor: futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)] for future in concurrent.futures.as_completed(futures): res = future.result() if res is not None: results.append(res) total_score += res['score'] total_exact_score += res['exact_score'] total_time = time.time() - start_total_time throughput = len(data_subset) / total_time if total_time > 0 else 0 with open(log_file, 'a', encoding='utf-8') as log_f: log_f.write(f"Total Time: {total_time:.2f} seconds\n") log_f.write(f"Throughput: {throughput:.2f} requests per second\n") average_score = total_score / len(data_subset) if data_subset else 0 log_f.write(f"Average Score: {average_score}\n") average_exact_score = total_exact_score / len(data_subset) if data_subset else 0 log_f.write(f"Average Exact Score: {average_exact_score}\n") log_f.write('-' * 40 + '\n') print(f"Results saved to {result_file}") print(f"Log saved to {log_file}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="API Generate Tester") parser.add_argument("--concurrent", type=int, default=1000, help="需要测试的实例总数") parser.add_argument("--file", type=str, default="cais/mmlu", help="数据文件路径") parser.add_argument("--result", type=str, default="./mmlu_result_silicon.json", help="结果文件保存路径") parser.add_argument("--log", type=str, default="./mmlu_result_silicon.log", help="日志文件保存路径") parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="模型名称或路径") parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL") args = parser.parse_args() data_evaluator = DataEvaluator() data_evaluator.load_data(args.file) main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) ================================================ FILE: archive/ktransformers/tests/parse_cover_info.py ================================================ import os import ast import argparse from coverage import Coverage def main(): parser = argparse.ArgumentParser( description="统计某个类在 .coverage 数据中的行覆盖率" ) parser.add_argument( "--data-file", default=".coverage", help="coverage 数据文件路径(默认 ./.coverage)", ) parser.add_argument( "--file", dest="file_pattern", default="ktransformers/operators/ascend/ascend_attention.py", help=( "要统计的源码文件路径(可用结尾匹配,默认 " "ktransformers/operators/ascend/ascend_attention.py)" ), ) parser.add_argument( "--class", dest="class_name", default="KDeepseekV2AttentionW8A8A2Serve", help="要统计的类名(默认 KDeepseekV2AttentionW8A8A2Serve)", ) args = parser.parse_args() if not os.path.exists(args.data_file): print(f"找不到 coverage 数据文件: {args.data_file}") raise SystemExit(1) cov = Coverage(data_file=args.data_file) cov.load() data = cov.get_data() file_pattern_norm = os.path.normpath(args.file_pattern) target_file = None for f in data.measured_files(): f_norm = os.path.normpath(f) if f_norm.endswith(file_pattern_norm) or file_pattern_norm in f_norm: target_file = f break if not target_file: print( f"没有在 coverage 数据里找到匹配文件: {args.file_pattern}\n" f"实际记录的文件有:" ) for f in data.measured_files(): print(" ", f) raise SystemExit(1) print("使用的源码文件:", target_file) executed_lines = set(data.lines(target_file) or []) try: with open(target_file, "r", encoding="utf-8") as f: source_text = f.read() except OSError as e: print(f"无法打开源码文件 {target_file}: {e}") raise SystemExit(1) source_lines = source_text.splitlines() tree = ast.parse(source_text) class_start = None class_end = None for node in tree.body: if isinstance(node, ast.ClassDef) and node.name == args.class_name: class_start = node.lineno max_lineno = node.lineno for sub in ast.walk(node): ln = getattr(sub, "end_lineno", getattr(sub, "lineno", None)) if ln is not None and ln > max_lineno: max_lineno = ln class_end = max_lineno break if class_start is None: print(f"在源码 {target_file} 中没有找到类 {args.class_name}") raise SystemExit(1) print( f"类 {args.class_name} 行范围: {class_start} ~ {class_end}" ) total = 0 covered = 0 missed_lines = [] for lineno in range(class_start, class_end + 1): line = source_lines[lineno - 1].strip() # 跳过空行和纯注释 if not line or line.startswith("#"): continue total += 1 if lineno in executed_lines: covered += 1 else: missed_lines.append(lineno) percent = (covered / total * 100) if total > 0 else 0.0 print( f"类 {args.class_name} 覆盖: {covered}/{total} 行, 覆盖率 = {percent:.1f}%" ) if missed_lines: print("未覆盖行号:", missed_lines) else: print("该类所有有效代码行均被覆盖") if __name__ == "__main__": main() ================================================ FILE: archive/ktransformers/tests/score.py ================================================ import subprocess import time import requests import sys import os def wait_for_server(base_url: str, timeout: int = None) -> None: start_time = time.time() while True: try: response = requests.get( f"{base_url}/v1/models", headers={"Authorization": "Bearer None"}, ) if response.status_code == 200: print("Server is ready.") break except requests.exceptions.RequestException: time.sleep(1) if timeout and time.time() - start_time > timeout: raise TimeoutError("Server did not become ready within timeout period") server_cmd = [ "numactl", "-N", "1", "-m", "1", "/home/qujing3/anaconda3/envs/ktransformers-dev/bin/ktransformers", "--model_path", "/home/qujing3/models/DeepSeek-R1-Q4_K_M/config", "--gguf_path", "/home/qujing3/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M", "--port", "10002", "--cpu_infer", "48", "--optimize_config_path", "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml", "--max_new_tokens", "3000", "--cache_lens", "6000" ] print("Starting ktransformers server...") print(" ".join(server_cmd)) with open("/tmp/server_log.txt", "w") as f: server_process = subprocess.Popen(server_cmd, stdout=f, stderr=f, text=True) try: wait_for_server("http://localhost:10002", timeout=600) eval_cmd = ["python", "ktransformers/tests/humaneval/eval_api.py"] print("Running eval_api.py...") print(f"Command: {' '.join(eval_cmd)}") env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" eval_process = subprocess.Popen( eval_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env, universal_newlines=True ) import threading import queue def enqueue_output(out, queue): for line in iter(out.readline, ''): queue.put(line) out.close() stdout_queue = queue.Queue() stderr_queue = queue.Queue() stdout_thread = threading.Thread(target=enqueue_output, args=(eval_process.stdout, stdout_queue)) stderr_thread = threading.Thread(target=enqueue_output, args=(eval_process.stderr, stderr_queue)) stdout_thread.daemon = True stderr_thread.daemon = True stdout_thread.start() stderr_thread.start() while eval_process.poll() is None: try: line = stdout_queue.get_nowait() print(line, end='', flush=True) except queue.Empty: pass try: line = stderr_queue.get_nowait() print(line, end='', file=sys.stderr, flush=True) except queue.Empty: pass time.sleep(1) while not stdout_queue.empty(): print(stdout_queue.get(), end='', flush=True) while not stderr_queue.empty(): print(stderr_queue.get(), end='', file=sys.stderr, flush=True) eval_process.wait() print(f"eval_api.py completed with exit code: {eval_process.returncode}") evaluate_cmd = [ "evaluate_functional_correctness", "ktransformers/tests/humaneval/results/api/eval_b.jsonl" ] print("Running evaluate_functional_correctness...") print(f"Command: {' '.join(evaluate_cmd)}") evaluate_process = subprocess.Popen( evaluate_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, universal_newlines=True ) for line in evaluate_process.stdout: print(line, end='', flush=True) for line in evaluate_process.stderr: print(line, end='', file=sys.stderr, flush=True) evaluate_process.wait() print(f"evaluate_functional_correctness completed with exit code: {evaluate_process.returncode}") if evaluate_process.returncode != 0: print(f"evaluate_functional_correctness exited with code {evaluate_process.returncode}") sys.exit(evaluate_process.returncode) finally: print("Stopping ktransformers server...") server_process.terminate() try: server_process.wait(timeout=30) except subprocess.TimeoutExpired: print("Server did not terminate gracefully, forcing...") server_process.kill() ================================================ FILE: archive/ktransformers/tests/test_client.py ================================================ import asyncio import json import sys import aiohttp import argparse prompt_list = [ 'Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ' ] async def fetch_event_stream(session, payload, request_id, stream): try: headers = { 'accept': 'application/json', 'Content-Type': 'application/json' } async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response: print(f"Request {request_id}: Connected, status {response.status}") if response.status != 200: print(f"Request {request_id}: Error, status {response.status}") return output_text = "" if stream: async for line in response.content: try: decoded_line = line.decode("utf-8").strip() if not decoded_line or not decoded_line.startswith("data: "): continue decoded_line = decoded_line[6:].strip() if not decoded_line: continue response_data = json.loads(decoded_line) choices = response_data.get("choices", []) if not choices: continue delta = choices[0].get("delta", {}) token = delta.get("content", "") if token: output_text += token sys.stdout.write(token) sys.stdout.flush() finish_reason = choices[0].get("finish_reason", None) if finish_reason: break except json.JSONDecodeError as e: print(f"\nRequest {request_id}: JSON Decode Error - {e}") except IndexError: print(f"\nRequest {request_id}: List Index Error - choices is empty") except Exception as e: print(f"\nRequest {request_id}: Error parsing stream - {e}") else: # 非 stream 模式下,一次性接收完整 json response_data = await response.json() choices = response_data.get("choices", []) if choices: content = choices[0].get("message", {}).get("content", "") print(f"Request {request_id} Output:\n{content}") output_text += content except Exception as e: print(f"\nRequest {request_id}: Exception - {e}") async def main(prompt_id, model, stream, max_tokens, temperature, top_p): async with aiohttp.ClientSession() as session: payload = { "messages": [ {"role": "system", "content": ""}, {"role": "user", "content": prompt_list[prompt_id]} ], "model": model, "stream": stream, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p } tasks = [fetch_event_stream(session, payload, prompt_id, stream)] await asyncio.gather(*tasks) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--question_id", type=int, default=0) parser.add_argument("--model", type=str, default="DeepSeek-V3") parser.add_argument("--stream", type=bool, default=True) parser.add_argument("--max_tokens", type=int, default=500) parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top_p", type=float, default=1) parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") args = parser.parse_args() SERVER_URL = args.api_url asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p)) ================================================ FILE: archive/ktransformers/tests/test_prefix.py ================================================ import asyncio import json import sys import aiohttp import random import argparse import yaml import os import time from time import sleep decodesz = 128 # Server URL (replace with your server URL) decodesz_list = [128] prefill_speeds = [] decode_speeds = [] async def fetch_message_once(session, request_id, messages, max_tokens, model): try: payload = { "messages": messages, "model": model, "temperature": 0.3, "top_p": 1.0, "stream": True, "return_speed": True, "max_tokens": max_tokens, } headers = { 'accept': 'application/json', 'Content-Type': 'application/json' } async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response: if response.status != 200: print(f"[Request {request_id}] Error: Status {response.status}") return None, None, None buffer = "" usage_info = None answer = "" async for line in response.content: decoded_line = line.decode("utf-8").strip() if not decoded_line or not decoded_line.startswith("data: "): continue decoded_line = decoded_line[6:].strip() if not decoded_line: continue response_data = json.loads(decoded_line) if "usage" in response_data: usage_info = response_data["usage"] choices = response_data.get("choices", []) if not choices: continue delta = choices[0].get("delta", {}) token = delta.get("content", "") if token: buffer += token answer += token finish_reason = choices[0].get("finish_reason", None) if finish_reason: break return answer.strip(), usage_info, buffer.strip() except Exception as e: print(f"[Request {request_id}] Exception: {e}") return None, None, None async def multi_turn_conversation(session, request_id, rounds, max_tokens, model): prompt = ["介绍一下秦始皇", "秦始皇的成就有哪些", "秦始皇的历史影响", "介绍一下秦始皇的陵墓", "秦始皇的统一措施", "秦始皇的政治制度", "秦始皇的文化政策", "秦始皇的军事行动"] messages = [{"role": "system", "content": ""}] global prefill_speeds, decode_speeds for i in range(rounds): user_msg = f"这是第{i + 1}轮对话,请回答以下问题:{prompt[i % len(prompt)]}" messages.append({"role": "user", "content": user_msg}) print(f"\n[Request {request_id}] >> User: {user_msg}") answer, usage_info, _ = await fetch_message_once(session, request_id, messages, max_tokens, model) if answer: messages.append({"role": "user", "content": answer}) print(f"[Request {request_id}] << Assistant: {answer}") if usage_info: prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] prefill_speeds.append(prefill_speed) decode_speeds.append(decode_speed) print(f'[Request {request_id}] prefill speed: {prefill_speed}') print(f'[Request {request_id}] decode speed: {decode_speed}') async def main(concurrent_requests, rounds, max_tokens, model): async with aiohttp.ClientSession() as session: tasks = [multi_turn_conversation(session, i, rounds, max_tokens, model) for i in range(concurrent_requests)] await asyncio.gather(*tasks) if prefill_speeds: import numpy as np print(f"\n=== Summary ===") print(f"Total concurrency: {concurrent_requests}") print(f"Avg prefill speed: {np.mean(prefill_speeds)}") print(f"Avg decode speed: {np.mean(decode_speeds)}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") parser.add_argument("--rounds", type=int, default=8, help="Number of multi-turn rounds (before final query)") args = parser.parse_args() SERVER_URL = args.api_url max_tokens = args.max_tokens model = args.model asyncio.run(main(args.concurrent, args.rounds, max_tokens, model)) ================================================ FILE: archive/ktransformers/tests/test_pytorch_q8.py ================================================ import torch # 定义一个包含线性层的浮点模型 class LinearModel(torch.nn.Module): def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) def forward(self, x): return self.linear(x) # 创建浮点模型实例 in_features = 64 out_features = 128 model_fp32 = LinearModel(in_features, out_features) # 创建量化模型实例 model_int8 = torch.ao.quantization.quantize_dynamic( model_fp32, # 原始浮点模型 {torch.nn.Linear}, # 要量化的层类型集合 dtype=torch.qint8 # 量化的目标数据类型 ) # 测试模型 batch_size = 32 input_fp32 = torch.randn(1, batch_size, in_features) # 生成随机输入数据 output_int8 = model_int8(input_fp32) # 通过量化模型运行数据 # 打印输出形状验证 print(f"输入形状: {input_fp32.shape}") print(f"输出形状: {output_int8.shape}") # 比较原始模型和量化模型的输出 with torch.no_grad(): output_fp32 = model_fp32(input_fp32) print(f"FP32输出的前几个值: {output_fp32[0, :5]}") print(f"INT8输出的前几个值: {output_int8[0, :5]}") # 计算平均误差 error = torch.abs(output_fp32 - output_int8).mean().item() print(f"平均绝对误差: {error}") # 打印模型类型信息 print(f"量化前模型类型: {type(model_fp32.linear)}") print(f"量化后模型类型: {type(model_int8.linear)}") ================================================ FILE: archive/ktransformers/tests/test_speed.py ================================================ import asyncio import json import sys import aiohttp import random import argparse import yaml import os import time from time import sleep decodesz = 128 # Server URL (replace with your server URL) decodesz_list = [128] prefill_speeds = [] decode_speeds = [] ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. The Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere. The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. They didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. Dursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. The Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. The Dursleys knew that the Potters had a small son, too, but they had never even seen him. This boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. Mr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair.None of them noticed a large, tawny owl flutter past the window. At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls. “Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive. It was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. For a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. There was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. What could he have been thinking of? It must have been a trick of the light. Mr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. It was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. Mr. Dursley gave himself a little shake and put the cat out of his mind. As he drove toward town he thought of nothing except a large order of drills he was hoping to get that day. But on the edge of town, drills were driven out of his mind by something else. As he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. People in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! He supposed this was some stupid new fashion. He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. They were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. Dursley always sat with his back to the window in his office on the ninth floor.""" async def fetch_event_stream(session, request_id, prompt, max_tokens, model): try: payload = { "messages": [ {"role": "system", "content": ""}, {"role": "user", "content": prompt} ], "model": model, "temperature": 0.3, "top_p": 1.0, "stream": True, "return_speed": True, "max_tokens": max_tokens, } headers = { 'accept': 'application/json', 'Content-Type': 'application/json' } async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response: if response.status != 200: print(f"[Request {request_id}] Error: Status {response.status}") return buffer = "" total_tokens = 0 decode_start_time = None decode_end_time = None usage_info = None async for line in response.content: try: decoded_line = line.decode("utf-8").strip() if not decoded_line or not decoded_line.startswith("data: "): continue decoded_line = decoded_line[6:].strip() if not decoded_line: continue response_data = json.loads(decoded_line) if "usage" in response_data: usage_info = response_data["usage"] choices = response_data.get("choices", []) if not choices: continue delta = choices[0].get("delta", {}) token = delta.get("content", "") if token: if decode_start_time is None: decode_start_time = time.time() buffer += token total_tokens += 1 decode_end_time = time.time() while "\n" in buffer: line, buffer = buffer.split("\n", 1) print(f"[Request {request_id}] {line}") finish_reason = choices[0].get("finish_reason", None) if finish_reason: break except Exception as e: print(f"[Request {request_id}] Stream Error: {e}") if buffer.strip(): print(f"[Request {request_id}] {buffer.strip()}") if usage_info: if "prefill_time" in usage_info: # print(f"[Request {request_id}] Usage:") # for key, value in usage_info.items(): # print(f" {key}: {value}") prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] prefill_speeds.append(prefill_speed) decode_speeds.append(decode_speed) print(f'[Request {request_id}] prefill speed: {prefill_speed}') print(f'[Request {request_id}] decode speed: {decode_speed}') except Exception as e: print(f"[Request {request_id}] Exception: {e}") async def main(concurrent_requests , prompt, max_tokens, model): async with aiohttp.ClientSession() as session: tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)] await asyncio.gather(*tasks) if len(prefill_speeds) != 0: import numpy as np print(f"concurrency: {len(prefill_speeds)}") print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--max_tokens", type=int, default=500, help="max decode tokens") args = parser.parse_args() SERVER_URL = args.api_url max_tokens = args.max_tokens model = args.model if args.prompt_lens == 1024: prompt = ktansformer_prompt1024 elif args.prompt_lens == 2048: prompt = ktansformer_prompt1024 * 2 elif args.prompt_lens == 4096: prompt = ktansformer_prompt1024 * 4 asyncio.run(main(args.concurrent, prompt, max_tokens, model)) ================================================ FILE: archive/ktransformers/tests/triton_fp8gemm_test.py ================================================ import torch import torch.nn.functional as F from typing import Optional import pytest from typing import Tuple, Optional, Literal import time # use dir path import os import sys sys.path.insert(0, "/home/azure/ktransformers") print(sys.path) from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant from safetensors import safe_open world_size = 1 rank = 0 block_size = 128 gemm_impl: Literal["bf16", "fp8"] = "bf16" # Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined def test_fp8_gemm_vs_torch_matmul(): # Test case 1: Create random matrices of size (M, K) and (K, N) M, K, N = 64, 128, 256 # Matrix dimensions x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda') weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda') # Apply act_quant to both matrices x_quantized, scale_x = act_quant(x, block_size) weight_quantized, scale_w = act_quant(weight, block_size) # mk continous x_quantized = x_quantized.contiguous() weight_quantized = weight_quantized.contiguous() scale_x = scale_x.contiguous() scale_w = scale_w.contiguous() # Perform fp8_gemm using the quantized tensors result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w) # Perform torch.matmul using the original floating point tensors result_torch_matmul = torch.matmul(x, weight.T) print(f'result_torch_matmul: {result_torch_matmul.shape}') print(f'result_fp8_gemm: {result_fp8_gemm.shape}') print(f"result_fp8_gemm:\n {result_fp8_gemm}") print(f"result_torch_matmul:\n {result_torch_matmul}") def test_fp8_gemm_vs_torch_matmul_load(): file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" with safe_open(file_path, framework="pt", device=0) as f: weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") # weight_dequant weight_dequantized = weight_dequant(weight, scale) print(f"weight_dequantized: {weight_dequantized.shape}") N, K = weight_dequantized.shape M = 64 x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') x_quantized, scale_x = act_quant(x, block_size) # Test case 1: quantized x matmal with undequantized weight result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) print(f"result_fp8_gemm:\n {result_fp8_gemm}") print(f"dtype {result_fp8_gemm.dtype}") # Perform torch.matmul using the original floating point tensors result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T) print(f"result_torch_matmul:\n {result_torch_matmul}") def test_fp8_gemm_tplops(): file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors" with safe_open(file_path, framework="pt", device=0) as f: weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv") # weight_dequant weight_dequantized = weight_dequant(weight, scale) print(f"weight_dequantized: {weight_dequantized.shape}") N, K = weight_dequantized.shape M = 6400 x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda') # x_quantized, scale_x = act_quant(x, block_size) # Calculate time for 1000 fp8_gemm i = 10 flops_per_gemm = 2 * M * N * K total_flops = i * flops_per_gemm x_quantized, scale_x = act_quant(x, block_size) result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) x_quantized, scale_x = act_quant(x, block_size) result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) t0 = time.time() torch.cuda.synchronize() for i in range(i): x_quantized, scale_x = act_quant(x, block_size) result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale) torch.cuda.synchronize() t1 = time.time() total_time = t1 - t0 tflops = total_flops / total_time / 1e12 print(f"total_time: {total_time}") print(f"tflops: {tflops}") if __name__ == "__main__": test_fp8_gemm_vs_torch_matmul() test_fp8_gemm_vs_torch_matmul_load() test_fp8_gemm_tplops() ================================================ FILE: archive/ktransformers/util/ascend/ascend_utils.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from datetime import timedelta import torch import torch_npu import torch.distributed as dist _DATA_PARALLEL_SIZE = 0 _TENSOR_PARALLEL_SIZE = 0 _DATA_PARALLEL_GROUP = None _TENSOR_PARALLEL_RANKS = None _TENSOR_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP_GLOO = None _DATA_PARALLEL_RANKS = None def setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1): global _DATA_PARALLEL_SIZE global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_RANKS global _TENSOR_PARALLEL_SIZE global _TENSOR_PARALLEL_RANKS global _TENSOR_PARALLEL_GROUP # os.environ["MASTER_ADDR"] = "localhost" # os.environ["MASTER_PORT"] = "12345" local_rank = int(os.getenv("LOCAL_RANK", '0')) world_size = int(os.getenv("WORLD_SIZE", '1')) torch_npu.npu.set_device(local_rank) tp_size = tp dp_size = world_size // tp_size _DATA_PARALLEL_SIZE = dp_size _TENSOR_PARALLEL_SIZE = tp_size torch.set_num_threads(8) timeout = timedelta(minutes=distributed_timeout_minutes) print(f"start to init process group ------rank is {local_rank}, world_size is {world_size}") torch.distributed.init_process_group( backend='hccl', world_size=world_size, rank=local_rank ) print(f"init process group success ------rank is {local_rank}, world_size is {world_size}") rank = torch.distributed.get_rank() nccl_comm_cfgs = {} # DP 组由每隔 tp_size 的进程组成 for dp_group_id in range(tp_size): ranks = list(range(dp_group_id, world_size, tp_size)) dp_group = torch.distributed.new_group( ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs) ) if rank in ranks: global _DATA_PARALLEL_GROUP _DATA_PARALLEL_GROUP = dp_group _DATA_PARALLEL_RANKS = ranks # TP 组由连续的 dp_size 个进程组成 for tp_group_id in range(dp_size): start_rank = tp_group_id * tp_size end_rank = (tp_group_id + 1) * tp_size ranks = list(range(start_rank, end_rank)) tp_group = torch.distributed.new_group( ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs) ) if rank in ranks: global _TENSOR_PARALLEL_GROUP _TENSOR_PARALLEL_GROUP = tp_group _TENSOR_PARALLEL_RANKS = ranks # seed must be the same in all processes torch.manual_seed(1) return local_rank, world_size def get_tensor_parallel_size(): assert _TENSOR_PARALLEL_SIZE is not None, "tensor parallel size is not set" return _TENSOR_PARALLEL_SIZE def get_tensor_parallel_group(): assert _TENSOR_PARALLEL_GROUP is not None, "tensor parallel group is not initialized" return _TENSOR_PARALLEL_GROUP def get_tensor_parallel_rank(): assert _TENSOR_PARALLEL_RANKS is not None, "tensor parallel rank is not initialized" return _TENSOR_PARALLEL_RANKS def get_data_parallel_size(): assert _DATA_PARALLEL_SIZE is not None, "data parallel size is not initialized" return _DATA_PARALLEL_SIZE def get_data_parallel_gloo(): assert _DATA_PARALLEL_GROUP_GLOO is not None, "data parallel gloo group is not initialized" return _DATA_PARALLEL_GROUP_GLOO def get_data_parallel_group(): assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP def get_data_parallel_rank(): assert _DATA_PARALLEL_RANKS is not None, "data parallel rank is not initialized" return _DATA_PARALLEL_RANKS def get_nccl_options(pg_name, nccl_comm_cfgs): if pg_name in nccl_comm_cfgs: nccl_options = torch.distributed.ProcessGroupNCCL.Options() nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4) nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32) nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1) return nccl_options else: return None def get_safetensors_cut_weight(name: str, weights: torch.Tensor): translate_col_cut_tensors = ["ffn_down", "attn_output"] # "kv_b_proj" translate_row_cut_tensors = ["ffn_gate", "ffn_up", "attn_q_b"] tp = get_tensor_parallel_size() if tp == 1 or weights.shape == torch.Size([1]): return weights rank = torch.distributed.get_rank() rank %= tp assert 0 <= rank < tp and tp > 0, f"unexpected {rank=}, {tp=}" if any(t in name for t in translate_col_cut_tensors): if weights.dim() == 1: return weights dim = weights.shape[-1] assert dim % tp == 0, f"unexpected division {dim=}, {tp=}" chunk_size = dim // tp output_weights = weights[:, rank * chunk_size:(rank + 1) * chunk_size] # print(f"col cut weights {name=} from {weights.shape=} to {output_weights.shape=}") return output_weights elif any(t in name for t in translate_row_cut_tensors): dim = weights.shape[0] assert dim % tp == 0, f"unexpected division {dim=}, {tp=}" chunk_size = dim // tp output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:] # print(f"row cut weights {name=} from {weights.shape=} to {output_weights.shape=}") return output_weights else: return weights def get_absort_weight(model, config): if not dist.is_initialized(): return local_rank = dist.get_rank() tp = get_tensor_parallel_size() local_rank %= tp tp_heads = config.num_attention_heads // tp for i in range(config.num_hidden_layers): attn = model.model.layers[i].self_attn if hasattr(attn, "q_absorb") and hasattr(attn, "out_absorb"): continue if not (hasattr(attn, "kv_b_proj") and hasattr(attn, "kv_lora_rank") and hasattr(attn, "qk_nope_head_dim")): continue kv_b_proj = attn.kv_b_proj.weight.view(config.num_attention_heads, -1, attn.kv_lora_rank) q_absorb = kv_b_proj[:, :attn.qk_nope_head_dim, :].clone() out_absorb = kv_b_proj[:, attn.qk_nope_head_dim:, :].clone() q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous() out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous() out_absorb = out_absorb.transpose(1, 2).contiguous() setattr(attn, "q_absorb", q_absorb) setattr(attn, "out_absorb", out_absorb) if hasattr(attn, "orig_module") and hasattr(attn.orig_module, "kv_b_proj"): del attn.orig_module.kv_b_proj dist.barrier(get_tensor_parallel_group()) def allredeuce_warpper(func): def wrapper(*args, **kwargs): orig_output = func(*args, **kwargs) if isinstance(orig_output, tuple): if get_tensor_parallel_size() > 1: org_dtype = orig_output[0].dtype if org_dtype == torch.bfloat16: dist.all_reduce(orig_output[0].to(dtype=torch.float16), op=dist.ReduceOp.SUM, group=get_tensor_parallel_group()) else: dist.all_reduce(orig_output[0], op=dist.ReduceOp.SUM, group=get_tensor_parallel_group()) if org_dtype == torch.bfloat16: bf_orig_output = orig_output[0].to(dtype=org_dtype) else: bf_orig_output = orig_output[0] else: bf_orig_output = orig_output[0] return (bf_orig_output,) + orig_output[1:] else: if get_tensor_parallel_size() > 1: org_dtype = orig_output.dtype if org_dtype == torch.bfloat16: orig_output = orig_output.to(dtype=torch.float16) dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group()) if org_dtype == torch.bfloat16: orig_output = orig_output.to(dtype=org_dtype) return orig_output return wrapper ================================================ FILE: archive/ktransformers/util/cuda_graph_runner.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import torch from typing import Dict class CUDAGraphRunner: def __init__(self): self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} def capture( self, model, cur_token, position_ids, cache_position, past_key_values, main_device, **kwargs, ) -> None: assert self.graph is None # Capture the graph. torch.cuda.synchronize() self.graph = torch.cuda.CUDAGraph() #self.graph.enable_debug_mode() self.model = model inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device) # torch.cuda.set_device can't set "cuda", must have a index if main_device == "cuda": main_device = "cuda:0" torch.cuda.set_device(main_device) self.main_device = main_device capture_stream = torch.cuda.Stream() with torch.cuda.graph(self.graph, stream = capture_stream): logits=model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, **kwargs)[0] capture_stream.wait_stream(torch.cuda.current_stream()) torch.cuda.set_device(main_device) torch.cuda.set_stream(capture_stream) if past_key_values != None: past_key_values.change_seq_length(-1) torch.cuda.synchronize(self.main_device) #self.graph.debug_dump("cuda_graph_hooked.dot") # Save the input and output buffers. self.input_buffers = { "inputs_embeds": inputs_embeds, "position_ids": position_ids, "cache_position": cache_position, } self.output_buffers = {"logits": logits} return def forward( self, cur_token, position_ids, cache_position, ) -> torch.Tensor: # Copy the input tensors to the input buffers. inputs_embeds = self.model.model.embed_tokens(cur_token.to("cpu")) self.input_buffers["inputs_embeds"].copy_(inputs_embeds) self.input_buffers["position_ids"].copy_(position_ids) self.input_buffers["cache_position"].copy_(cache_position) # Run the graph. #print("begin replay") #time.sleep(1) self.graph.replay() torch.cuda.synchronize(self.main_device) # Return the output tensor. return self.output_buffers["logits"] def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) ================================================ FILE: archive/ktransformers/util/custom_gguf.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : Azure-Tang, Boxin Zhang, chenht2022 Date : 2024-07-26 08:48:54 Version : 1.0.0 LastEditors : kkk1nak0 LastEditTime : 2024-08-14 08:20:45 Adapted from https://github.com/99991/pygguf/blob/main/gguf.py Copyright (c) 2023-2024 The ggml authors Copyright (c) 2024 Thomas Germer Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' # copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf # GGUF specification # https://github.com/ggerganov/ggml/blob/master/docs/gguf.md import struct import warnings import numpy as np import re import numpy.typing as npt from typing import Sequence import os from enum import IntEnum import torch try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False if not torch.xpu.is_available() and not use_torch_npu: import KTransformersOps import ctypes import math class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 Q4_0 = 2 Q4_1 = 3 Q5_0 = 6 Q5_1 = 7 Q8_0 = 8 Q8_1 = 9 Q2_K = 10 Q3_K = 11 Q4_K = 12 Q5_K = 13 Q6_K = 14 Q8_K = 15 IQ2_XXS = 16 IQ2_XS = 17 IQ3_XXS = 18 IQ1_S = 19 IQ4_NL = 20 IQ3_S = 21 IQ2_S = 22 IQ4_XS = 23 I8 = 24 I16 = 25 I32 = 26 I64 = 27 F64 = 28 IQ1_M = 29 BF16 = 30 QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), GGMLQuantizationType.Q8_0: (32, 2 + 32), GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), GGMLQuantizationType.IQ4_NL: (32, 2 + 16), GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), GGMLQuantizationType.I8: (1, 1), GGMLQuantizationType.I16: (1, 2), GGMLQuantizationType.I32: (1, 4), GGMLQuantizationType.I64: (1, 8), GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.BF16: (1, 2), } # copied from llama.cpp/gguf-py/gguf/quants.py to avoid dependence of gguf def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType): block_size, type_size = GGML_QUANT_SIZES[quant_type] if shape[-1] % block_size != 0: raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})") return (*shape[:-1], shape[-1] // block_size * type_size) GGML_TYPES = { "F32": 0, "F16": 1, "Q4_0": 2, "Q5_0": 6, "Q8_0": 8, "Q2_K": 10, "Q3_K": 11, "Q4_K": 12, "Q5_K": 13, "Q6_K": 14, "IQ4_XS": 23, "BF16": 30, } GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()} GGML_BLOCK_SIZES = { "F32": 4, "F16": 2, "BF16": 2, "Q4_0": 2 + 16, "Q5_0": 2 + 4 + 16, "Q8_0": 2 + 32, "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, "Q4_K": 2 + 2 + 12 + 256 // 2, "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2, "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64, "FP8": 1, } GGML_ELEMENTS_PER_BLOCK = { "F32": 1, "F16": 1, "BF16": 1, "Q4_0": 32, "Q5_0": 32, "Q8_0": 32, "Q2_K": 256, "Q3_K": 256, "Q4_K": 256, "Q5_K": 256, "Q6_K": 256, "IQ4_XS": 256, "FP8": 1, } DATA_TYPES = { "uint8": 0, "int8": 1, "uint16": 2, "int16": 3, "uint32": 4, "int32": 5, "float32": 6, "bool": 7, "string": 8, "array": 9, "uint64": 10, "int64": 11, "float64": 12, "FP8": 13, } def read_value(f, data_type): if data_type == DATA_TYPES["string"]: length = struct.unpack("> 0, qs[:, 16:32] >> 0, qs[:, 00:16] >> 2, qs[:, 16:32] >> 2, qs[:, 00:16] >> 4, qs[:, 16:32] >> 4, qs[:, 00:16] >> 6, qs[:, 16:32] >> 6, qs[:, 32:48] >> 0, qs[:, 48:64] >> 0, qs[:, 32:48] >> 2, qs[:, 48:64] >> 2, qs[:, 32:48] >> 4, qs[:, 48:64] >> 4, qs[:, 32:48] >> 6, qs[:, 48:64] >> 6, ], axis=1) return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q2_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q3_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95 block_size = GGML_BLOCK_SIZES["Q3_K"] num_blocks = len(data) // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") bits = 4 ^ (bits << 2) qs = data_u8[:, 32:32 + 64].astype(np.int16) a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) scales[:, 0] = (a & 15) | ((c & 3) << 4) scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4) scales[:, 3] = (b >> 4) | ((c >> 6) << 4) scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) return d * (scales - 32) * np.stack([ (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) ], axis=1) def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q3_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 block_size = GGML_BLOCK_SIZES["Q4_K"] num_blocks = len(data) // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) # Casting to float32 because float16 is very slow on CPU scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) # Dequantize scales and offsets (6 bits and 4 + 2 bits) factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1) offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1) # Interleave low and high quantized bits qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) # Dequantize final weights using scales and offsets return factors * qs2 - offsets def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q4_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q5_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138 block_size = GGML_BLOCK_SIZES["Q5_K"] num_blocks = len(data) // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1) qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32) bits = np.unpackbits(qh, axis=-1, bitorder="little") qs_hi_4 = qs >> 4 qs_lo_4 = qs & 15 scales_lo_6 = scales[:, :8] & 63 scales_hi_6 = scales[:, :8] >> 6 scales_lo_4 = scales[:, 8:] & 15 scales_hi_4 = scales[:, 8:] >> 4 m1 = dmin * scales_lo_6[:, 4] m2 = dmin * scales_lo_6[:, 5] m3 = dmin * scales_lo_6[:, 6] m4 = dmin * scales_lo_6[:, 7] m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4)) m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4)) m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4)) m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4)) d1 = d * scales_lo_6[:, 0] d2 = d * scales_lo_6[:, 1] d3 = d * scales_lo_6[:, 2] d4 = d * scales_lo_6[:, 3] d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4)) d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4)) d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) return np.concatenate([ d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, ], axis=1) def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q5_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"] data = np.frombuffer(data, dtype=data.dtype) device = torch.device(device) # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor. c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q6_k(data): # C implementation # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275 # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152 block_size = GGML_BLOCK_SIZES["Q6_K"] num_blocks = len(data) // block_size data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size) scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32) # TODO use uint8 and cast later? ql = data_u8[:, :128].astype(np.int16) qh = data_u8[:, 128:192].astype(np.int16) sc = data_i8[:, 192:208, np.newaxis].astype(np.float32) # Unpack bits, subtraction requires signed data type q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32 q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32 q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32 q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32 q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32 q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32 q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32 q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32 # Dequantize return scales * np.concatenate([ sc[:, 0] * q1[:, :16], sc[:, 1] * q1[:, 16:], sc[:, 2] * q2[:, :16], sc[:, 3] * q2[:, 16:], sc[:, 4] * q3[:, :16], sc[:, 5] * q3[:, 16:], sc[:, 6] * q4[:, :16], sc[:, 7] * q4[:, 16:], sc[:, 8] * q5[:, :16], sc[:, 9] * q5[:, 16:], sc[:, 10] * q6[:, :16], sc[:, 11] * q6[:, 16:], sc[:, 12] * q7[:, :16], sc[:, 13] * q7[:, 16:], sc[:, 14] * q8[:, :16], sc[:, 15] * q8[:, 16:], ], axis=1) # @torch.jit.script def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["Q6_K"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) def dequantize_iq4_xs(data): # C implementation # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568 # C struct definition # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393 block_size = GGML_BLOCK_SIZES["IQ4_XS"] num_blocks = len(data) // block_size d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1) scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1) data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:] scales_l = data_u8[:, :4].reshape(num_blocks, 4) qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8) ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8) for ib in range(QK_K // 32): ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4) dl = (d * (ls - 32)).reshape(num_blocks, -1, 1) qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4 y = np.zeros((num_blocks, QK_K), dtype=np.float32) for ib in range(QK_K // 32): y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]] y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]] return y.flatten() def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()): block_size = GGML_BLOCK_SIZES["IQ4_XS"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"] device = torch.device(device) num_blocks = len(data) // block_size data = np.frombuffer(data, dtype=data.dtype) c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_q4_0(data): # C implementation # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515 # C struct definition # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141 num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"] scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32) qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:] return np.concatenate([ scales * ((qs & 0xf).astype(np.int8) - 8), scales * ((qs >> 4).astype(np.int8) - 8), ], axis=1) def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q5_0(data): # C implementation # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556 # C struct definition # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161 num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"] scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32) qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4] qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:] bits = np.unpackbits(qh, axis=-1, bitorder="little") x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16 x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16 return np.concatenate([ scales * x0, scales * x1, ], axis=1) def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): raise NotImplementedError() def dequantize_q8_0(data): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] return scales * qs def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()): # C struct definition # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 block_size = GGML_BLOCK_SIZES["Q8_0"] ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"] device = torch.device(device) data = np.frombuffer(data, dtype=data.dtype) c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents) return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype) def dequantize_f32(data): return np.frombuffer(data, dtype=np.float32) def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float32) res = torch.from_numpy(data.copy()) res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) res_gpu.copy_(res) return res_gpu def dequantize_f16(data): return np.frombuffer(data, dtype=np.float16) def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float16) res = torch.from_numpy(data.copy()) res_gpu = torch.empty_like(res, device=device, dtype=target_dtype) res_gpu.copy_(res) return res_gpu def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()): data = np.frombuffer(data, dtype=np.float16) res = torch.from_numpy(data.copy()) res_gpu = torch.empty_like(res, device=device) res_gpu.copy_(res) return res_gpu GGML_DEQUANTIZE = { "F32": dequantize_f32, "F16": dequantize_f16, "BF16": dequantize_f16, "Q4_0": dequantize_q4_0, "Q5_0": dequantize_q5_0, "Q8_0": dequantize_q8_0, "Q2_K": dequantize_q2_k, "Q3_K": dequantize_q3_k, "Q4_K": dequantize_q4_k, "Q5_K": dequantize_q5_k, "Q6_K": dequantize_q6_k, "IQ4_XS": dequantize_iq4_xs, } GGML_DEQUANTIZE_GPU = { "F32": dequantize_f32_gpu, "F16": dequantize_f16_gpu, "BF16": dequantize_bf16_gpu, "Q4_0": dequantize_q4_0_gpu, "Q5_0": dequantize_q5_0_gpu, "Q8_0": dequantize_q8_0_gpu, "Q2_K": dequantize_q2_k_gpu, "Q3_K": dequantize_q3_k_gpu, "Q4_K": dequantize_q4_k_gpu, "Q5_K": dequantize_q5_k_gpu, "Q6_K": dequantize_q6_k_gpu, "IQ4_XS": dequantize_iq4_xs_gpu, } def translate_name_to_gguf_mixtral(name): replacement_template = { "w1.weight": "ffn_gate", "w2.weight": "ffn_down", "w3.weight": "ffn_up" } pattern = re.compile(r"model.layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.(w\d\.weight)") def replace_match(match): blk_id = match.group(1) expert_id = match.group(2) weight_type = match.group(3) if weight_type in replacement_template: return f"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight" else: return match.group(0) new_name = re.sub(pattern, replace_match, name) return new_name def translate_name_to_gguf(name): name = translate_name_to_gguf_mixtral(name) if ".ffn_gate_exp." in name: name = name.replace(".ffn_gate_exp.", ".ffn_gate_exps.") if ".ffn_up_exp." in name: name = name.replace(".ffn_up_exp.", ".ffn_up_exps.") if ".ffn_down_exp." in name: name = name.replace(".ffn_down_exp.", ".ffn_down_exps.") m = re.match(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name) if m: layer, expert, proj = m.groups() if proj == "gate_proj": return f"blk.{layer}.{expert}.ffn_gate_exps" elif proj == "up_proj": return f"blk.{layer}.{expert}.ffn_up_exps" else: return f"blk.{layer}.{expert}.ffn_down_exps" m = re.match(r"blk\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)", name) if m: layer, expert, proj = m.groups() if proj == "gate_proj": return f"blk.{layer}.{expert}.ffn_gate_exps" elif proj == "up_proj": return f"blk.{layer}.{expert}.ffn_up_exps" else: return f"blk.{layer}.{expert}.ffn_down_exps" name = name.replace("lm_head.", "output.") name = name.replace("model.embed_tokens.", "token_embd.") name = name.replace("model.norm.", "output_norm.") name = name.replace("model.layers.", "blk.") name = name.replace(".input_layernorm", ".attn_norm") name = name.replace(".mlp.down_proj", ".ffn_down") name = name.replace(".mlp.gate_proj", ".ffn_gate") name = name.replace(".mlp.up_proj", ".ffn_up") name = name.replace(".post_attention_layernorm", ".ffn_norm") name = name.replace(".self_attn.q_proj", ".attn_q") name = name.replace(".self_attn.k_proj", ".attn_k") name = name.replace(".self_attn.v_proj", ".attn_v") name = name.replace(".self_attn.o_proj", ".attn_output") name = name.replace(".self_attn.qkv_proj", ".attn_qkv") name = name.replace(".self_attn.kv_a_proj_with_mqa", ".attn_kv_a_mqa") name = name.replace(".self_attn.kv_a_layernorm", ".attn_kv_a_norm") name = name.replace(".self_attn.kv_b_proj", ".attn_kv_b") name = name.replace(".self_attn.q_a_proj", ".attn_q_a") name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm") name = name.replace(".self_attn.q_b_proj", ".attn_q_b") name = name.replace(".self_attn.q_norm", ".attn_q_norm") name = name.replace(".self_attn.k_norm", ".attn_k_norm") name = name.replace(".shared_expert.", ".shared_experts.") name = name.replace(".shared_expert_", ".shared_experts_") name = name.replace(".gate_up_proj.", ".up_proj") name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp") name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias") name = name.replace(".mlp.gate", ".ffn_gate_inp") name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp") name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp") name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp") name = name.replace(".mlp.experts", "") name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps") name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps") name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps") name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.") name = name.replace(".block_sparse_moe.experts", "") name = name.replace(".feed_forward.experts", "") name = name.replace(".feed_forward.router", ".ffn_gate_inp") name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp") name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp") name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp") return name if __name__ == '__main__': gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH' loader = GGUFLoader(gguf_path) loader.load_gguf_tensor('token_embd.weight') ================================================ FILE: archive/ktransformers/util/custom_loader.py ================================================ import struct import warnings import numpy as np import re import numpy.typing as npt from typing import Sequence import os from enum import IntEnum import torch try: import torch_npu use_torch_npu = torch_npu.npu.is_available() except: use_torch_npu = False if not torch.xpu.is_available() and not use_torch_npu: import KTransformersOps from safetensors import safe_open if not use_torch_npu: from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant from ktransformers.util.custom_gguf import * from safetensors.torch import save_file from abc import ABC, abstractmethod from typing import Dict, Any, Optional, Union class ModelLoader(ABC): """ Abstract base class for model loaders. Defines the interface that all model loaders must implement. """ tensor_file_map = {} @abstractmethod def has_tensor(cls, name: str): """ Check if the tensor exists in the loader. Args: name: Name of the tensor to check Returns: bool: True if the tensor exists, False otherwise """ pass class SafeTensorLoader(ModelLoader): tensor_file_map: dict tensor_type_map: dict file_handle_map: dict tensor_device_map: dict def __init__(self, file_path: str): self.__load_tensor_file_map(file_path) def __load_tensor_file_map(self, file_path: str): # 处理传入路径,确保是文件夹路径 if not os.path.exists(file_path): raise FileNotFoundError(f"Path not found: {file_path}") if os.path.isfile(file_path): folder_path = os.path.dirname(file_path) else: folder_path = file_path self.file_handle_map = {} self.tensor_file_map = {} self.tensor_type_map = {} self.tensor_device_map = {} found_safetensor = False for root, _, files in os.walk(folder_path): files = sorted(files) for file in files: if file.endswith(".safetensors"): found_safetensor = True file_path = os.path.join(root, file) if file not in self.file_handle_map: try: handle = safe_open(file_path, framework="pt") self.file_handle_map[file] = handle except Exception as e: print(f"Error opening Safetensor file {file_path}: {e}") continue f = self.file_handle_map.get(file) if f is None: continue try: for key in f.keys(): self.tensor_file_map[key] = file except Exception as e: print(f"Error reading Safetensor file {file_path}: {e}") # if not found_safetensor: # raise FileNotFoundError(f"No Safetensor files found in {folder_path}") def load_tensor(self, key: str, device: str = "cpu"): if translate_name_to_gguf(key) in self.tensor_file_map: key = translate_name_to_gguf(key) elif key in self.tensor_file_map: pass else: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") if use_torch_npu: tensor = f.get_tensor(key).to(torch.float16) else: tensor = f.get_tensor(key) return tensor.to(device) def load_experts(self, key: str, device: str="cpu"): ''' Load experts from safetensor key: the name of the experts device: the device to load the experts to return: dict, {up: tensor, down: tensor, gate: tensor, up_type: int, down_type: int, gate_type: int} {xxx}_type: the type of the up tensor, corresponding to the ggml type ''' if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"): # legacy branch for loading hybrid model base_key = translate_name_to_gguf(key) # Load experts from safetensor gate_key = f"{base_key}.ffn_gate_exps.weight" gate_type_key = f"{base_key}.ffn_gate_exps.ggml_type" up_key = f"{base_key}.ffn_up_exps.weight" up_type_key = f"{base_key}.ffn_up_exps.ggml_type" down_key = f"{base_key}.ffn_down_exps.weight" down_type_key = f"{base_key}.ffn_down_exps.ggml_type" gate_tensor = self.load_tensor(gate_key, device).numpy() up_tensor = self.load_tensor(up_key, device).numpy() down_tensor = self.load_tensor(down_key, device).numpy() gate_type = self.load_tensor(gate_type_key, device).item() up_type = self.load_tensor(up_type_key, device).item() down_type = self.load_tensor(down_type_key, device).item() return { "up": up_tensor, "gate": gate_tensor, "down": down_tensor, "up_type": up_type, "gate_type": gate_type, "down_type": down_type } else: # Load experts from safetensor base_key = key # e.g. "model.layers.3.mlp.experts" experts_count = 0 key_no_proj = False if self.has_tensor(f"{base_key}.{experts_count}.up.weight"): key_no_proj = True # First, count how many experts we have by checking for expert 0's up_proj while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"): experts_count += 1 if experts_count == 0: raise ValueError(f"No experts found for key {base_key}") # Initialize empty lists to store tensors for each projection type up_projs = [] gate_projs = [] down_projs = [] # Load all expert weights for expert_id in range(experts_count): if key_no_proj: up_key = f"{base_key}.{expert_id}.up.weight" gate_key = f"{base_key}.{expert_id}.gate.weight" down_key = f"{base_key}.{expert_id}.down.weight" else: up_key = f"{base_key}.{expert_id}.up_proj.weight" gate_key = f"{base_key}.{expert_id}.gate_proj.weight" down_key = f"{base_key}.{expert_id}.down_proj.weight" up_tensor = self.load_tensor(up_key, device) gate_tensor = self.load_tensor(gate_key, device) down_tensor = self.load_tensor(down_key, device) up_projs.append(up_tensor) gate_projs.append(gate_tensor) down_projs.append(down_tensor) # Stack the tensors along a new dimension up_tensor = torch.stack(up_projs, dim=0) gate_tensor = torch.stack(gate_projs, dim=0) down_tensor = torch.stack(down_projs, dim=0) # Get original dtype for GGML type determination orig_up_dtype = up_tensor.dtype orig_gate_dtype = gate_tensor.dtype orig_down_dtype = down_tensor.dtype # Convert to numpy with proper bfloat16 support up_numpy = up_tensor.view(torch.uint16).numpy() gate_numpy = gate_tensor.view(torch.uint16).numpy() down_numpy = down_tensor.view(torch.uint16).numpy() # Determine tensor data types for GGML conversion def get_ggml_type(dtype): if dtype == torch.float32: return GGMLQuantizationType.F32 elif dtype == torch.float16: return GGMLQuantizationType.F16 elif dtype == torch.bfloat16: return GGMLQuantizationType.BF16 else: raise ValueError(f"Unsupported tensor dtype: {dtype}") return { "up": up_numpy, "gate": gate_numpy, "down": down_numpy, "up_type": get_ggml_type(orig_up_dtype), "gate_type": get_ggml_type(orig_gate_dtype), "down_type": get_ggml_type(orig_down_dtype) } def load_gate(self, key: str, device: str="cpu"): ''' Load gate from safetensor key: the name of the gate device: the device to load the gate to return: dict, {'weight': tensor, 'e_score_correction_bias': tensor} ''' target = ["weight", "e_score_correction_bias"] res = {'weight': None, 'e_score_correction_bias': None} if self.has_tensor(translate_name_to_gguf(key)+".ffn_gate_exps.weight"): # legacy branch for loading hybrid model base_key = key for k in target: translated_key = translate_name_to_gguf(f"{base_key}.{k}") if self.has_tensor(translated_key): tensor = self.load_tensor(translated_key, device) res[k] = tensor else: # Load gate from safetensor base_key = key for k in target: if self.has_tensor(f"{base_key}.{k}"): tensor = self.load_tensor(f"{base_key}.{k}", device) res[k] = tensor return res def close_all_handles(self): for handle in self.file_handle_map.values(): handle.close() self.file_handle_map.clear() def load_dequantized_tensor(self, key: str, device: str = "cpu"): if key in self.tensor_file_map and translate_name_to_gguf(key): pass elif translate_name_to_gguf(key) in self.tensor_file_map: key = translate_name_to_gguf(key) else: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") tensor = f.get_tensor(key).to(device) if key.endswith(".weight"): if key[:-7] + ".weight_scale_inv" in self.tensor_file_map: weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device) tensor = weight_dequant(tensor, weight_scale_inv) return tensor.to(device) def has_tensor(self, name: str): return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map class GGUFLoader(ModelLoader): tensor_info: dict gguf_path: str tensor_file_map: dict # {tensor_name: tensor_file_path} gguf_file_meta: dict safetensor_loader: SafeTensorLoader def __init__(self, gguf_path: str, quantize: str = None): # Check dir exist if not os.path.exists(gguf_path): raise FileNotFoundError(f"GGUF dir not found: {gguf_path}") if os.path.isfile(gguf_path): gguf_path = os.path.dirname(gguf_path) self.safetensor_loader = None self.tensor_info = {} self.gguf_path = gguf_path self.tensor_file_map = {} self.file_data_map = {} self.gguf_file_meta = {} self.tensor_device_map = {} if use_torch_npu: if quantize == "w8a8_dynamic": safetensor_loader = W8A8SafeTensorLoader(gguf_path) else: safetensor_loader = SafeTensorLoader(gguf_path) if safetensor_loader.tensor_file_map: self.safetensor_loader = safetensor_loader return # Walk through all the .gguf files in the directory found_gguf = False for root, dirs, files in os.walk(gguf_path): for file in files: if file.endswith(".gguf"): found_gguf = True file_name = os.path.join(root, file) with open(file_name, "rb") as f: self.load_gguf(f) if file_name not in self.file_data_map: self.file_data_map[file_name] = np.memmap(file_name, mode = 'r') if not found_gguf: raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}") def load_gguf(self, f): f.seek(0) assert f.read(4) == b'GGUF' values = struct.unpack("torch.Tensor: name = translate_name_to_gguf(name) t = self.tensor_info[name] shape = t["shape"] ggml_type = t["ggml_type"] if ggml_type not in GGML_NAMES: raise NotImplementedError(f"ggml_type {ggml_type} not implemented") ggml_name = GGML_NAMES[ggml_type] # TODO: experts may fused in quant block, split it assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant" blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name] block_size = GGML_BLOCK_SIZES[ggml_name] offset = expert_id * block_size * blocks_per_experts data = data[offset: offset + block_size * blocks_per_experts] if "cuda" in device.lower(): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype) else: values = GGML_DEQUANTIZE[ggml_name](data) values = torch.from_numpy(values.copy()) if ggml_name == "BF16": values = values.view(torch.bfloat16) values = values.view(shape[-2::-1]) return values def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor: name = translate_name_to_gguf(name) t = self.tensor_info[name] if target_dtype == None: target_dtype = torch.get_default_dtype() shape = t["shape"] ggml_type = t["ggml_type"] if ggml_type not in GGML_NAMES: raise NotImplementedError(f"ggml_type {ggml_type} not implemented") ggml_name = GGML_NAMES[ggml_type] data = self.get_mmap_tensor(name) block_size = GGML_BLOCK_SIZES[ggml_name] elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name] num_elements = int(np.prod(shape)) num_blocks = num_elements // elements_per_block blocks_per_iter = 16384 if num_blocks > blocks_per_iter: # dequant large tensor values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device) for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter): blocks_begin = i * blocks_per_iter blocks_end = min(blocks_begin + blocks_per_iter, num_blocks) if "cuda" in device.lower(): try: cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype) except: cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) cur_values = torch.from_numpy(cur_values.copy()).to(device) else: cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size]) cur_values = torch.from_numpy(cur_values.copy()) cur_values = cur_values.view(-1, elements_per_block) if ggml_name == "BF16": cur_values = cur_values.view(torch.bfloat16) values[blocks_begin : blocks_end] = cur_values else: if "cuda" in device.lower(): values = GGML_DEQUANTIZE_GPU[ggml_name](data, device) else: np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data)) values = torch.from_numpy(np_values).to(device) del np_values if ggml_name == "BF16": values = values.view(torch.bfloat16) values = values.view(shape[::-1]) if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: n_head = self.gguf_file_meta['llama.attention.head_count'] values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) .swapaxes(1, 2) .reshape(values.shape)) elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]: n_head = self.gguf_file_meta['llama.attention.head_count_kv'] values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:]) .swapaxes(1, 2) .reshape(values.shape)) return values def has_tensor(self, name: str): name = translate_name_to_gguf(name) return name in self.tensor_info def get_ggml_type(self, name: str): name = translate_name_to_gguf(name) if name not in self.tensor_info: raise KeyError(f"Key {name} not found in GGUF files") return self.tensor_info[name]["ggml_type"] class ModelLoaderFactory: """ Factory class for creating model loaders. Automatically detects the model format based on file extensions in the directory. """ @staticmethod def create_loader(path: str): """ Create a model loader for the given path by detecting the model format. The function checks for the presence of .safetensors or .gguf files in the specified path and creates the appropriate loader. Args: path: Path to the model directory or file Returns: An appropriate ModelLoader instance (SafeTensorLoader or GGUFLoader) Raises: FileNotFoundError: If no supported model files are found in the path """ if not os.path.exists(path): raise FileNotFoundError(f"Path not found: {path}") # Normalize to directory path if a file was provided if os.path.isfile(path): if path.endswith(".safetensors"): return SafeTensorLoader(path) elif path.endswith(".gguf"): return GGUFLoader(path) else: folder_path = os.path.dirname(path) else: folder_path = path # Check for safetensors files has_safetensors = False has_gguf = False for root, _, files in os.walk(folder_path): for file in files: if file.endswith(".safetensors"): has_safetensors = True break elif file.endswith(".gguf"): has_gguf = True break if has_safetensors or has_gguf: break # Create the appropriate loader based on detected file types # Prioritize SafeTensor over GGUF if both are present if has_safetensors: try: return SafeTensorLoader(folder_path) except Exception as e: print(f"Failed to create SafeTensorLoader: {e}") # Fall through to try GGUF if SafeTensor fails if not has_gguf: raise if has_gguf: try: return GGUFLoader(folder_path) except Exception as e: print(f"Failed to create GGUFLoader: {e}") raise # No supported model files found raise FileNotFoundError(f"No .safetensors or .gguf files found in: {folder_path}") class W8A8SafeTensorLoader(SafeTensorLoader): def load_tensor(self, key: str, device: str = "cpu"): if key not in self.tensor_file_map: raise KeyError(f"Key {key} not found in Safetensor files") file = self.tensor_file_map[key] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") tensor = f.get_tensor(key) if 'deq_scale' in key: tensor = torch.from_numpy( np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64)) if 'input_scale' in key: tensor = tensor.to(torch.float16) if "weight_scale" in key or "weight_offset" in key: if "ffn" in key: tensor = tensor.to(torch.float32) else: tensor = tensor.to(torch.float16) if 'input_offset' in key: tensor = tensor.to(torch.int8) if tensor.dtype == torch.bfloat16: tensor = tensor.to(torch.float16) return tensor.to(device) def load_dequantized_tensor(self, key: str, device: str = "cpu"): tensor = self.load_tensor(key, device) return tensor ================================================ FILE: archive/ktransformers/util/modeling_rope_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Optional, Tuple from transformers.configuration_utils import PretrainedConfig from transformers.utils import is_torch_available, logging logger = logging.get_logger(__name__) if is_torch_available(): import torch def _compute_default_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ if config is not None and len(rope_kwargs) > 0: raise ValueError( "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" ) if len(rope_kwargs) > 0: base = rope_kwargs["base"] dim = rope_kwargs["dim"] elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) return inv_freq, attention_factor def _compute_linear_scaling_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ if config is not None and len(rope_kwargs) > 0: raise ValueError( "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" ) if len(rope_kwargs) > 0: factor = rope_kwargs["factor"] elif config is not None: factor = config.rope_scaling["factor"] # Gets the default RoPE parameters inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) # Then applies linear scaling to the frequencies. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so # applying scaling to the inverse frequencies is equivalent. inv_freq /= factor return inv_freq, attention_factor def _compute_dynamic_ntk_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, **rope_kwargs, ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length, used to update the dynamic RoPE at inference time. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling if config is not None and len(rope_kwargs) > 0: raise ValueError( "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" ) if len(rope_kwargs) > 0: base = rope_kwargs["base"] dim = rope_kwargs["dim"] max_position_embeddings = rope_kwargs["max_position_embeddings"] factor = rope_kwargs["factor"] elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. at init time seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) return inv_freq, attention_factor def _compute_yarn_parameters( config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Please refer to the [original paper](https://arxiv.org/abs/2309.00071) Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # No need to keep BC with yarn, unreleased when this new pattern was created. if len(rope_kwargs) > 0: raise ValueError( f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" ) base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) factor = config.rope_scaling["factor"] attention_factor = config.rope_scaling.get("attention_factor") mscale = config.rope_scaling.get("mscale") mscale_all_dim = config.rope_scaling.get("mscale_all_dim") # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. if "original_max_position_embeddings" in config.rope_scaling: original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] factor = config.max_position_embeddings / original_max_position_embeddings else: original_max_position_embeddings = config.max_position_embeddings def get_mscale(scale, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 # Sets the attention factor as suggested in the paper if attention_factor is None: if mscale and mscale_all_dim: attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) else: attention_factor = get_mscale(factor) # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) beta_fast = config.rope_scaling.get("beta_fast") or 32 beta_slow = config.rope_scaling.get("beta_slow") or 1 # Compute the inverse frequencies def find_correction_dim(num_rotations, dim, base, max_position_embeddings): """Inverse dimension formula to find the dimension based on the number of rotations""" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): """Find dimension range bounds based on rotations""" low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) def linear_ramp_factor(min, max, dim): if min == max: max += 0.001 # Prevent singularity linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs # to expand the possible context length. In other words, interpolation = apply scaling factor. pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) # Get n-dimensional rotational scaling corrected for extrapolation inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) inv_freq = ( inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation * inv_freq_extrapolation_factor ) return inv_freq, attention_factor def _compute_longrope_parameters( config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies with LongRoPE scaling. Please refer to the [original implementation](https://github.com/microsoft/LongRoPE) Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling # No need to keep BC with longrope, unreleased when this new pattern was created. if len(rope_kwargs) > 0: raise ValueError( "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got " f"{rope_kwargs}" ) base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] factor = config.rope_scaling.get("factor") attention_factor = config.rope_scaling.get("attention_factor") # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. if hasattr(config, "original_max_position_embeddings"): original_max_position_embeddings = config.original_max_position_embeddings factor = config.max_position_embeddings / config.original_max_position_embeddings else: original_max_position_embeddings = config.max_position_embeddings # Sets the attention factor as suggested in the paper if attention_factor is None: if factor <= 1.0: attention_factor = 1.0 else: attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) # Compute the inverse frequencies -- scaled based on the target sequence length if seq_len and seq_len > original_max_position_embeddings: ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) else: ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) return inv_freq, attention_factor def _compute_llama3_parameters( config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs ) -> Tuple["torch.Tensor", float]: """ Computes the inverse frequencies for llama 3.1. Args: config ([`~transformers.PretrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # Gets the default RoPE parameters inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs) factor = config.rope_scaling["factor"] # `8` in the original implementation low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor wavelen = 2 * math.pi / inv_freq # wavelen < high_freq_wavelen: do nothing # wavelen > low_freq_wavelen: divide by factor inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) # otherwise: interpolate between the two, using a smooth factor smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) return inv_freq_llama, attention_factor # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE # parameterizations, as long as the callable has the same signature. ROPE_INIT_FUNCTIONS = { "default": _compute_default_rope_parameters, "linear": _compute_linear_scaling_rope_parameters, "dynamic": _compute_dynamic_ntk_parameters, "yarn": _compute_yarn_parameters, "longrope": _compute_longrope_parameters, "llama3": _compute_llama3_parameters, } def _check_received_keys( rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None, ignore_keys: Optional[set] = None, ): """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present if "type" in received_keys: received_keys -= {"type"} required_keys.add("rope_type") # Some models need to store model-specific keys, and we don't want to throw warning at them if ignore_keys is not None: received_keys -= ignore_keys missing_keys = required_keys - received_keys if missing_keys: raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") if optional_keys is not None: unused_keys = received_keys - required_keys - optional_keys else: unused_keys = received_keys - required_keys if unused_keys: logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type"} received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` optional_keys = {"original_max_position_embeddings"} received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor"} optional_keys = { "attention_factor", "beta_fast", "beta_slow", "original_max_position_embeddings", "mscale", "mscale_all_dim", } received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") attention_factor = rope_scaling.get("attention_factor") if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): logger.warning( f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" ) beta_fast = rope_scaling.get("beta_fast") if beta_fast is not None and not isinstance(beta_fast, float): logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") beta_slow = rope_scaling.get("beta_slow") if beta_slow is not None and not isinstance(beta_slow, float): logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") if (beta_fast or 32) < (beta_slow or 1): logger.warning( f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" ) def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "short_factor", "long_factor"} # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor") if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") if not len(short_factor) == dim // 2: logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") long_factor = rope_scaling.get("long_factor") if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") if not len(long_factor) == dim // 2: logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is # unique to longrope (= undesirable) if hasattr(config, "original_max_position_embeddings"): logger.warning_once( "This model has set a `original_max_position_embeddings` field, to be used together with " "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " "as it is compatible with most model architectures." ) else: factor = rope_scaling.get("factor") if factor is None: logger.warning("Missing required keys in `rope_scaling`: 'factor'") elif not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") attention_factor = rope_scaling.get("attention_factor") if attention_factor is not None: if not isinstance(attention_factor, float) or attention_factor < 0.0: logger.warning( f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" ) def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): rope_scaling = config.rope_scaling rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} received_keys = set(rope_scaling.keys()) _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) factor = rope_scaling["factor"] if factor is None or not isinstance(factor, float) or factor < 1.0: logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] if low_freq_factor is None or not isinstance(low_freq_factor, float): logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") if high_freq_factor is None or not isinstance(high_freq_factor, float): logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") if high_freq_factor <= low_freq_factor: logger.warning( "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" f"{high_freq_factor} and low_freq_factor={low_freq_factor}" ) original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): logger.warning( "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " f"{original_max_position_embeddings}" ) if original_max_position_embeddings >= config.max_position_embeddings: logger.warning( "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" ) # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. ROPE_VALIDATION_FUNCTIONS = { "default": _validate_default_rope_parameters, "linear": _validate_linear_scaling_rope_parameters, "dynamic": _validate_dynamic_scaling_rope_parameters, "yarn": _validate_yarn_parameters, "longrope": _validate_longrope_parameters, "llama3": _validate_llama3_parameters, } def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): """ Validate the RoPE config arguments, given a `PretrainedConfig` object """ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` if rope_scaling is None: return # BC: "rope_type" was originally "type" rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) if validation_fn is not None: validation_fn(config, ignore_keys=ignore_keys) else: logger.warning( f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" ) ================================================ FILE: archive/ktransformers/util/npu_graph_runner.py ================================================ ''' Description : Author : Boxin Zhang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' from typing import Dict import threading import torch import torch_npu class NPUGraphRunner: def __init__(self, deviceId): torch.npu.set_compile_mode(jit_compile=False) self.deviceId = deviceId self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} self.past_key_value = None def init(self, batch_size, seq_length): self.graph = torch.npu.NPUGraph() self.main_stream = torch_npu.npu.Stream(device=self.deviceId) self.share_experts_stream = torch_npu.npu.Stream(device=self.deviceId) self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId) # deepseekV3 hidden_size self.workspace = None self.model_capture = True torch_npu.npu._subscribe_report(self.main_stream) def destroy(self): torch_npu.npu._unsubscribe_report(self.main_stream) del self.graph destory_runner(self.deviceId) def capture( self, model, cur_token, position_ids, cache_position, past_key_values, main_device, **kwargs, ) -> None: inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(main_device) with torch.no_grad(): with torch.npu.graph(self.graph, stream=self.main_stream, auto_dispatch_capture=True): logits = model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, is_prefill=False, **kwargs) self.input_buffers = { "inputs_embeds": inputs_embeds, "position_ids": position_ids, "cache_position": cache_position, } self.output_buffers = { "logits": logits, } def forward( self, inputs_embeds, position_ids, cache_position, ) -> torch.Tensor: thread = threading.Thread(target=self.graph.update, kwargs={"cpu_update_input": [{"actual_seq_lengths_kv": self.past_key_value.position}]}) thread.start() self.input_buffers["inputs_embeds"].copy_(inputs_embeds) self.input_buffers["position_ids"].copy_(position_ids) self.input_buffers["cache_position"].copy_(cache_position) torch_npu.npu.synchronize() with torch_npu.npu.stream(self.main_stream): # Run the graph. self.graph.replay() thread.join() # Return the output tensor. return self.output_buffers["logits"] def launch_callback(self, func, data, block, stream): torch_npu.npu._launch_host_func(stream, func, data) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) runner_dict = dict() def check_runner(deviceId: int): runner = runner_dict.get(deviceId) if runner is None: return True else: return False def destory_runner(deviceId: int): # print("the new NPUGraphRunner and deviceId is ", deviceId) runner = runner_dict.get(deviceId) if runner is not None: runner_dict[deviceId] = None def get_or_create_runner(deviceId: int): runner = runner_dict.get(deviceId) if runner is None: runner = NPUGraphRunner(deviceId) runner_dict[deviceId] = runner return runner ================================================ FILE: archive/ktransformers/util/textstream.py ================================================ from typing import Any, List, Optional, Set class TextStreamer: def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): self.tokenizer = tokenizer self.skip_prompt = skip_prompt self.decode_kwargs = decode_kwargs # variables used in the streaming process self.token_cache = [] self.print_len = 0 self.next_tokens_are_prompt = True def reset(self): self.token_cache = [] self.print_len = 0 def put(self, value)->Optional[str]: """ Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. """ if not isinstance(value,int): raise ValueError("TextStreamer only supports batch size 1, and int type input") if self.skip_prompt and self.next_tokens_are_prompt: self.next_tokens_are_prompt = False return None # Add the new token to the cache and decodes the entire thing. self.token_cache.append(value) text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs) # After the symbol for a new line, we flush the cache. if text.endswith("\n"): printable_text = text[self.print_len :] self.reset() # If the last token is a CJK character, we print the characters. elif len(text) > 0 and self._is_chinese_char(ord(text[-1])): printable_text = text[self.print_len :] self.print_len += len(printable_text) # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, # which may change with the subsequent token -- there are probably smarter ways to do this!) else: printable_text = text[self.print_len : text.rfind(" ") + 1] self.print_len += len(printable_text) return printable_text def end(self)->Optional[str]: """Flushes any remaining cache and prints a newline to stdout.""" # Flush the cache, if it exists if len(self.token_cache) > 0: text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs) printable_text = text[self.print_len :] self.reset() else: printable_text = "" self.next_tokens_are_prompt = True return printable_text def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) # # Note that the CJK Unicode block is NOT all Japanese and Korean characters, # despite its name. The modern Korean Hangul alphabet is a different block, # as is Japanese Hiragana and Katakana. Those alphabets are used to write # space-separated words, so they are not treated specially and handled # like the all of the other languages. if ( (cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) # or (cp >= 0x20000 and cp <= 0x2A6DF) # or (cp >= 0x2A700 and cp <= 0x2B73F) # or (cp >= 0x2B740 and cp <= 0x2B81F) # or (cp >= 0x2B820 and cp <= 0x2CEAF) # or (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F) # ): # return True return False ================================================ FILE: archive/ktransformers/util/utils.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : Boxin Zhang, Azure-Tang Version : 0.1.0 Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import re import sys import threading import torch import torch.distributed as dist from torch import nn import itertools import time import enum from transformers import ( LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, MinPLogitsWarper, TypicalLogitsWarper, EpsilonLogitsWarper, EtaLogitsWarper, ) from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, translate_name_to_gguf from ktransformers.operators import base_operator from ktransformers.models.custom_cache import StaticCache from ktransformers.util.cuda_graph_runner import CUDAGraphRunner from ktransformers.util.textstream import TextStreamer if not torch.xpu.is_available(): from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton # from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton import socket warm_uped = False CUR_DEVICE = None W8A8_ENABLE = False Q4_GGUF_LODER = None _USE_NPU_GRAPH = False _MAX_DECODE_PROFILE = 1 WARM_UP_SKIP_CNT = [1, 1] _SPECULATE_STEP = 1 try: import torch_npu use_torch_npu = torch_npu.npu.is_available() from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size except: use_torch_npu = False def get_use_npu_graph(): assert _USE_NPU_GRAPH is not None, "use npu graph is not setting" return _USE_NPU_GRAPH from enum import StrEnum class StatKey(StrEnum): Embedding = "Embedding" GraphCapture = "GraphCapture" GraphReplay = "GraphReplay" ExpertsForward1 = "ExpertsForward1" ExpertsForward2 = "ExpertsForward2" CPUExperts = "CPUExperts" GraphDestroy = "GraphDestroy" DecodeOneTokenPost = "DecodeOneTokenPost" DecodeOneToken = "DecodeOneToken" GraphInit = "GraphInit" class TimeStat: def __init__(self): # open_status = os.environ["KT_PERF_STAT"] if "KT_PERF_STAT" in os.environ else "0" # if open_status == "0": # self.on = False # else: # self.on = True self.on = True self.prefill_stats = dict() self.decode_stats = dict() for key in StatKey: self.prefill_stats[key] = StatItem() self.decode_stats[key] = StatItem() self.reset_all() def record_start_time(self): start_time = time.time_ns() return start_time def add_time_stat(self, key: StatKey, time_ns, is_prefill): if not key: return # torch.cuda.synchronize() cost = time.time_ns() - time_ns if is_prefill: item = self.prefill_stats[key] else: item = self.decode_stats[key] item.add_item(cost) def print_all(self): # rank = f"[rank:{torch.distributed.get_rank()}]" rank = f"[rank:0]" msg = f"\n{rank} Prefill Time Stat\n" msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)") for key, value in self.prefill_stats.items(): msg += rank + f" {key.value:<25}:{value.get_stat()}\n" msg += f"\n{rank} Decode Time Stat\n" msg += rank + " {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\n".format("", "min(ms)", "max(ms)", "avg(ms)", "count", "total(ms)") for key, value in self.decode_stats.items(): msg += rank + f" {key.value:<25}:{value.get_stat()}\n" print(msg) def reset_all(self): for _, value in self.prefill_stats.items(): value.reset() for _, value in self.decode_stats.items(): value.reset() class StatItem: def __init__(self): self.min_time = 100000000 self.max_time = 0 self.total_time_ns = 0 self.count = 0 def add_item(self, cost_time_ns): self.count += 1 self.total_time_ns += cost_time_ns self.min_time = min(self.min_time, cost_time_ns) self.max_time = max(self.max_time, cost_time_ns) def reset(self): self.min_time = 100000000 self.max_time = 0 self.total_time_ns = 0 self.count = 0 def get_stat(self): min_time = self.min_time / 1000 / 1000 max_time = self.max_time / 1000 / 1000 if self.count != 0: avg_time = self.total_time_ns / self.count / 1000 / 1000 else: avg_time = 0 total = self.total_time_ns / 1000 / 1000 return f"{min_time:15.2f}{max_time:15.2f}{avg_time:15.2f}{self.count:15}{total:15.2f}" timeStat = TimeStat() def get_free_ports(n: int, continue_prot: list): sockets = [] ports = [] for _ in range(n): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) port = s.getsockname()[1] if port in continue_prot: s.close() continue ports.append(port) sockets.append(s) for s in sockets: s.close() return ports def get_current_device(): if use_torch_npu: return f"npu:{torch.npu.current_device()}" else: return f"cuda:{torch.npu.current_device()}" def get_compute_capability(device:torch.device = None): if use_torch_npu: return 0 if torch.cuda.is_available(): if device is None: num_gpus = torch.cuda.device_count() min_compute_capability_major = 100 for gpu_id in range(num_gpus): gpu_props = torch.cuda.get_device_properties(gpu_id) min_compute_capability_major = min(min_compute_capability_major, gpu_props.major) return min_compute_capability_major else: return torch.cuda.get_device_properties(device) def set_module(model, submodule_key, module): tokens = submodule_key.split('.') sub_tokens = tokens[:-1] cur_mod = model for s in sub_tokens: if hasattr(cur_mod, s): cur_mod = getattr(cur_mod, s) else: # nn.ModuleList or nn.ModuleList cur_mod=cur_mod[int(s)] if hasattr(cur_mod, tokens[-1]): setattr(cur_mod, tokens[-1], module) else: # nn.ModuleList or nn.ModuleList cur_mod[int(tokens[-1])] = module def set_param(module: nn.Module, name: str, weights: torch.Tensor): param=nn.parameter.Parameter(weights, requires_grad=False) if isinstance(module, nn.Linear) and len(weights.shape)==1: param.unsqueeze_(0) setattr(module, name, param) def get_device(gguf_module_key:str, device_map:dict): if gguf_module_key in device_map: return device_map[gguf_module_key]["generate_device"] else: return "cuda" def get_all_used_cuda_device(device_map:dict): all_device_list = set() for key in device_map: all_device_list.add(device_map[key]["generate_device"]) if "generate_device" in device_map[key] else None all_device_list.add(device_map[key]["prefill_device"]) if "prefill_device" in device_map[key] else None if "cpu" in all_device_list: all_device_list.remove("cpu") if use_torch_npu: all_device_list = set([device.replace('cuda', 'npu') for device in all_device_list]) all_device_list = list(all_device_list) return all_device_list def load_cur_state_dict_npu(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="npu"): prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} for name, param in local_state.items(): key = prefix + name translated_key = translate_name_to_gguf(key) # TODO: Merge all loader. # I know this is ugly but lets do it for now. if gguf_loader.safetensor_loader is not None: load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map else: load_dequantized_tensor = gguf_loader.load_gguf_tensor tensor_file_map = gguf_loader.tensor_file_map if translated_key in tensor_file_map: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) # Todo need fix device = "cpu" if "embd" in translated_key else get_current_device() print(f"loading layer {translated_key} to {device}") torch.cuda.empty_cache() weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) set_param(module, name, weights) del weights else: #print(load_config.tensor_file_map.keys()) raise Exception(f"can't find {translated_key} in GGUF file!") def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"): if use_torch_npu: load_cur_state_dict_npu(module, gguf_loader, prefix, device) return prefix = prefix.replace("orig_module.", "") persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set} local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} for name, param in local_state.items(): key = prefix + name translated_key = key # TODO: Merge all loader. # I know this is ugly but lets do it for now. if isinstance(gguf_loader, SafeTensorLoader): load_dequantized_tensor = gguf_loader.load_dequantized_tensor else: load_dequantized_tensor = gguf_loader.load_gguf_tensor tensor_file_map = gguf_loader.tensor_file_map if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key: target_dtype = torch.get_default_dtype() device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) print(f"loading {translated_key} to {device}") if torch.cuda.is_available(): torch.cuda.empty_cache() elif torch.xpu.is_available(): torch.xpu.empty_cache() if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key): attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype) attn_k_b = attn_k_b.transpose(1, 2).contiguous() attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype) kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1) kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous() set_param(module, name, kv_b_proj) del attn_k_b del attn_v_b else: weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) set_param(module, name, weights) del weights else: #print(load_config.tensor_file_map.keys()) raise Exception(f"can't find {translated_key} in GGUF file!") def sync_all_device(all_device_list): for device in all_device_list: if "cuda" in device.lower(): torch.cuda.synchronize(device) elif "xpu" in device.lower(): torch.xpu.synchronize(device) elif use_torch_npu: torch_npu.synchronize(device) else: raise RuntimeError("The device {} is not available".format(device)) torch_device_mapping ={"cuda": "cuda:0", "xpu": "xpu:0"} def xpu_fp16_model(config): # This function is to check if we run this model on XPU with FP16 dtype if not torch.xpu.is_available(): return False if config.architectures[0] == "DeepseekV3ForCausalLM": return True if config.architectures[0] == "Qwen3MoeForCausalLM" and config.hidden_size == 4096: # Qwen3-30B seems have precision issue with FP16 # so we only use FP16 for Qwen3-235B now return True return False def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"): #print(f"recursively loading weights {prefix}") if not isinstance(module, base_operator.BaseInjectedModule): load_cur_state_dict(module, gguf_loader, prefix, device=device) for name, child in module._modules.items(): load_weights(child, gguf_loader, prefix+name+".", device=device) else: module.load() def tf_logits_warper(generation_config): """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. """ # instantiate warpers list warpers = LogitsProcessorList() # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) if generation_config.num_beams > 1: if isinstance(generation_config._eos_token_tensor, list): min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 elif isinstance(generation_config._eos_token_tensor, torch.Tensor): min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 else: min_tokens_to_keep = 2 else: min_tokens_to_keep = 1 # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` if generation_config.temperature is not None and generation_config.temperature != 1.0: warpers.append(TemperatureLogitsWarper(generation_config.temperature)) if generation_config.top_k is not None and generation_config.top_k != 0: warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.top_p is not None and generation_config.top_p < 1.0: warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.min_p is not None: # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.typical_p is not None and generation_config.typical_p < 1.0: warpers.append( TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) ) if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: warpers.append( EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) ) if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: warpers.append( EtaLogitsWarper( epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device ) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: warpers.append(LogitNormalization()) return warpers def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False, num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None, static_cache = None, draft_model=None, draft_cache=None): import os os.environ["TOKENIZERS_PARALLELISM"] = "false" torch._dynamo.config.suppress_errors = True batch_size, seq_length = inputs.shape device_map = model.gguf_loader.tensor_device_map if use_torch_npu: CUR_DEVICE = f"npu:{torch.npu.current_device()}" vocabulary_size = model.config.vocab_size topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu() topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu() temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu() next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu() next_token_probs = torch.tensor([[1.0]], dtype=torch.float16).npu() torch_device = torch.npu.current_device() else: torch_device = get_device('model.layers.0.self_attn', device_map) torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device inputs = inputs.to(torch_device) all_cuda_device = get_all_used_cuda_device(device_map) tokens = [] def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True): if cuda_graph_runner is None: use_cuda_graph = False inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device) if use_cuda_graph: if cuda_graph_runner.model_capture: cuda_graph_runner.capture(model, cur_token, position_ids, cache_position, past_key_values, CUR_DEVICE, return_dict=False, use_cache=True) cuda_graph_runner.model_capture = False ret = cuda_graph_runner(inputs_embeds, position_ids, cache_position) logits = ret[0] next_token = torch.argmax(logits, dim=-1) else: torch_npu.npu.set_device(torch_device) logits = model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, is_prefill=False)[0] if past_key_values != None: past_key_values.change_seq_length(1) if generation_config.do_sample: logits = logits / temperature torch.manual_seed(0) probs = logits.view(batch_size, vocabulary_size) sm = nn.Softmax(dim=-1) probs = sm(probs).half().npu() next_token = next_token_fake torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs) next_token = next_token.squeeze(-1) else: next_token_scores = logits_warper(inputs, logits[:, -1, :]) next_token = torch.argmax(next_token_scores, dim=-1) return next_token def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True): if use_torch_npu: return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph) if cuda_graph_runner is None: use_cuda_graph = False if use_cuda_graph: logits = cuda_graph_runner(cur_token, position_ids, cache_position) else: # custom_stream = torch.cuda.Stream() if torch.cuda.is_available(): torch.cuda.set_device(torch_device) elif torch.xpu.is_available(): torch.xpu.set_device(torch_device) elif use_torch_npu: torch_npu.set_device(torch_device) else: raise RuntimeError(f"The device: {torch_device} is not available") inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to(torch_device) # with torch.cuda.stream(custom_stream): logits=model(inputs_embeds=inputs_embeds, position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0] if past_key_values != None and isinstance(past_key_values, StaticCache): past_key_values.change_seq_length(1) sync_all_device(all_cuda_device) next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) next_token = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_token = torch.argmax(next_token_scores, dim=-1) return next_token # TODO: use CUDA Graph for chunk prefill, may get small improvement def chunk_prefill(inputs, cache_position, past_key_values): if mode == "long_context": inputs_embeds = model.model.embed_tokens(inputs.to("cpu")) else: inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device) # inputs_embeds = torch_npu.npu_format_cast_(inputs_embeds, 29) if use_flashinfer_mla: MLAWrapperSingleton.update_buffer(past_key_values.max_pages) MLAWrapperSingleton.need_plan_all() ret = model( inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, is_prefill=True ) logits = ret[0][:,-1,:].unsqueeze(0).clone().to(torch_device) return logits def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None): global warm_uped global _USE_NPU_GRAPH if use_cuda_graph: from ktransformers.util.npu_graph_runner import get_or_create_runner npu_graph_runner = get_or_create_runner(CUR_DEVICE) npu_graph_runner.init(batch_size, seq_length) with torch_npu.npu.stream(npu_graph_runner.main_stream): gen_num_tokens = 1 while gen_num_tokens < max_new_tokens: start_time = timeStat.record_start_time() if use_flashinfer_mla: MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None, num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16) if gen_num_tokens == 1: warm_uped = True _USE_NPU_GRAPH = True #np_graph_runner.capture(model, draft_model, next_token, torch.tensor(draft_token), position_ids, cache_position, past_key_values, draft_cache, torch_device, return_dict=False, use_cache=True) cuda_graph_runner = npu_graph_runner next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph) next_token = next_token.to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int() tokens.append(int(next_token)) seq_length += 1 if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': print(stream.end(), end="", flush=True) break else: if torch.distributed.get_rank() % get_tensor_parallel_size() == 0: print(stream.put(next_token.item()), end="", flush=True) cache_position += 1 past_key_values.position[0] += 1 position_ids = cache_position.unsqueeze(0) gen_num_tokens += 1 if prof is not None: prof.step() npu_graph_runner.destroy() _USE_NPU_GRAPH = False else: gen_num_tokens = 1 while gen_num_tokens < max_new_tokens: if use_flashinfer_mla: MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None, num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16) next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph) next_token = next_token.to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int() tokens.append(int(next_token)) seq_length += 1 if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': print(stream.end(), end="", flush=True) break else: if torch.distributed.get_rank() % get_tensor_parallel_size() == 0: print(stream.put(next_token.item()), end="", flush=True) cache_position += 1 past_key_values.position[0] += 1 position_ids = cache_position.unsqueeze(0) gen_num_tokens += 1 if prof is not None: prof.step() if prof is not None: prof.stop() if torch.cuda.is_available(): torch.cuda.set_device(torch_device) elif torch.xpu.is_available(): torch.xpu.set_device(torch_device) elif use_torch_npu: torch_npu.set_device(torch_device) else: raise RuntimeError(f"The device: {torch_device} is not available") with torch.no_grad(): stream = TextStreamer(tokenizer) if torch.xpu.is_available(): from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache if model.config.architectures[0] in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]: past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None) else: past_key_values = DynamicNormalCache.from_legacy_cache(None) elif use_torch_npu and static_cache: assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache' past_key_values = static_cache if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens: print('[WARN] current staticCache size exceeded, try create new staticCache...') past_key_values = StaticCache( config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype ) else: past_key_values.reset() elif mode != 'long_context': past_key_values = StaticCache( config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype ) else: past_key_values = None generation_config, model_kwargs = model._prepare_generation_config( None, do_sample=False # change this to modify generate config #top_k=5, top_p=0.85, temperature=0.1 ) logits_warper = tf_logits_warper(generation_config) cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32) if use_torch_npu: past_key_values.position[0] = seq_length + 1 generated_ids = torch.zeros( batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device ) generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int) start_time = time.time() logits = None def prefill_wrapper(prof=None): nonlocal logits chunk_start = 0 while chunk_start < seq_length: chunk_end = min(chunk_start + chunk_size, seq_length) if past_key_values != None: past_key_values.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values) chunk_start += chunk_size if prof is not None: prof.step() if prof is not None: prof.stop() if logits is None: raise ValueError('logits cannot be None') if use_torch_npu: global WARM_UP_SKIP_CNT prof_prefill = os.environ["PROF_PREFILL"] if "PROF_PREFILL" in os.environ else "0" if prof_prefill == "1" and WARM_UP_SKIP_CNT[0] <= 0: experimental_config = torch_npu.profiler._ExperimentalConfig( aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False ) with torch_npu.profiler.profile( activities=[ torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU ], schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prefill_prof"), record_shapes=True, profile_memory=True, with_stack=False, with_flops=False, with_modules=False, experimental_config=experimental_config) as prof: prefill_wrapper(prof) else: prefill_wrapper() WARM_UP_SKIP_CNT[0] -= 1 else: chunk_start = 0 while chunk_start < seq_length: chunk_end = min(chunk_start + chunk_size, seq_length) if past_key_values != None: past_key_values.cur_idx=cache_position[chunk_start:chunk_end] logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values) chunk_start += chunk_size next_token_scores = logits_warper(inputs, logits[:, -1, :]) if generation_config.do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) next_token = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_token = torch.argmax(next_token_scores, dim=-1) first_token_time = time.time() - start_time # print(f"------------------------------------- prefill next_token {next_token} draft_token {draft_token} ") if use_flashinfer_mla: MLAWrapperSingleton.reset_buffer() prefill_count = seq_length prefill_time = first_token_time if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0: if force_think: print("") print(stream.put(next_token.item()), end="", flush=True) elif not use_torch_npu: if force_think: print("") print(stream.put(next_token.item()), end="", flush=True) generated_ids[:, seq_length] = next_token tokens.append(int(next_token)) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32) position_ids = cache_position.unsqueeze(0) seq_length += 1 cuda_graph_runner = None start_time = time.time() if not use_torch_npu: for i in range(1, max_new_tokens): if use_flashinfer_mla: MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None, num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16) global warm_uped if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ): warm_uped = True cuda_graph_runner = CUDAGraphRunner() cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True) next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device) inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1) generated_ids[:, cache_position] = next_token.int() tokens.append(int(next_token)) seq_length += 1 if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>': print(stream.end(), end="", flush=True) break else: print(stream.put(next_token.item()), end="", flush=True) cache_position += 1 position_ids = cache_position.unsqueeze(0) else: prof_decode = os.environ["PROF_DECODE"] if "PROF_DECODE" in os.environ else "0" prof_ranks = os.environ["PROF_RANK"] if "PROF_RANK" in os.environ else "0" prof_ranks = [int(r.strip()) for r in prof_ranks.split(",")] if prof_decode == "1" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0: experimental_config = torch_npu.profiler._ExperimentalConfig( aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False ) with torch_npu.profiler.profile( activities=[ torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU ], schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./decode_prof"), record_shapes=True, profile_memory=True, with_stack=False, with_flops=False, with_modules=False, experimental_config=experimental_config) as prof: decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof) else: decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length) WARM_UP_SKIP_CNT[1] -= 1 total_time = time.time() - start_time tokens_generated = len(tokens) tokens_per_second = tokens_generated / total_time if not use_torch_npu: print("") print(f"prompt eval count: {prefill_count} token(s)") print(f"prompt eval duration: {prefill_time}s") print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s") print(f"eval count: {tokens_generated} token(s)") print(f"eval duration: {total_time}s") print(f"eval rate: {tokens_per_second} tokens/s") else: tp_size = get_tensor_parallel_size() if torch.distributed.get_rank() % tp_size == 0: rank = f"[rank:{torch.distributed.get_rank()}]" msg = f"\n{rank} Eval Time\n" msg += rank + f"prompt eval count: {prefill_count} token(s)\n" msg += rank + f"prompt eval duration: {prefill_time:.9f}s\n" msg += rank + f"prompt eval rate: {prefill_count/prefill_time:.9f} tokens/s\n" msg += rank + f"eval count: {tokens_generated} token(s)\n" msg += rank + f"eval duration: {total_time:.9f}s\n" msg += rank + f"eval rate: {tokens_per_second:.9f} tokens/s\n" print(msg) return tokens class InferenceState(enum.Enum): UNLOAD = 0 PREFILL = 1 GENERATE = 2 RESTORE = 3 ================================================ FILE: archive/ktransformers/util/vendors.py ================================================ from __future__ import annotations from enum import IntEnum, auto from typing import Optional, Union, List import torch class GPUVendor(IntEnum): NVIDIA = auto() AMD = auto() MooreThreads = auto() MetaX = auto() MUSA = auto() Unknown = auto() class DeviceManager: """ Device manager that provides a unified interface for handling different GPU vendors """ def __init__(self): self.gpu_vendor = self._detect_gpu_vendor() self.available_devices = self._get_available_devices() def _detect_gpu_vendor(self) -> GPUVendor: """Detect GPU vendor type""" if not torch.cuda.is_available(): # Check MUSA availability (assuming a musa module exists) try: import musa if musa.is_available(): return GPUVendor.MUSA except (ImportError, AttributeError): pass return GPUVendor.Unknown device_name = torch.cuda.get_device_name(0).lower() if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]): return GPUVendor.NVIDIA elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]): return GPUVendor.AMD elif any(name in device_name for name in ["mthreads", "moore", "mtt"]): return GPUVendor.MooreThreads elif any(name in device_name for name in ["metax", "meta"]): return GPUVendor.MetaX elif "musa" in device_name: return GPUVendor.MUSA # Backend check try: if hasattr(torch.version, 'hip') and torch.version.hip is not None: return GPUVendor.AMD elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None: return GPUVendor.NVIDIA except: pass return GPUVendor.Unknown def _get_available_devices(self) -> List[int]: """Get list of available device indices""" devices = [] if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: devices = list(range(torch.cuda.device_count())) elif self.gpu_vendor == GPUVendor.MUSA: try: import musa devices = list(range(musa.device_count())) except (ImportError, AttributeError): pass return devices def get_device_str(self, device_id: Union[int, str]) -> str: """ Get device string for the given device ID Args: device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string Returns: Device string representation (e.g., "cuda:0", "musa:1", "cpu") """ if device_id == -1 or device_id == "cpu": return "cpu" if isinstance(device_id, int): if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD: if device_id < torch.cuda.device_count(): return f"cuda:{device_id}" elif self.gpu_vendor == GPUVendor.MUSA: try: import musa if device_id < musa.device_count(): return f"musa:{device_id}" except (ImportError, AttributeError): pass return "cpu" def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device: """ Convert device ID to torch.device object Args: device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string Returns: torch.device object """ device_str = self.get_device_str(device_id) # Handle MUSA device if device_str.startswith("musa:"): try: import musa index = int(device_str.split(":")[-1]) return musa.device(index) except (ImportError, ValueError, AttributeError): return torch.device("cpu") # Standard PyTorch device return torch.device(device_str) def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: """ Move tensor to specified device Args: tensor: PyTorch tensor to move device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string Returns: Tensor moved to the specified device """ device = self.to_torch_device(device_id) return tensor.to(device) def is_available(self, index: int = 0) -> bool: """ Check if device at specified index is available Args: index: Device index to check Returns: True if the device is available, False otherwise """ if index < 0: return True # CPU is always available return index in self.available_devices def get_all_devices(self) -> List[int]: """ Get all available device indices Returns: List of available device indices (0, 1, 2, etc.) """ return self.available_devices # Create global device manager instance device_manager = DeviceManager() # Convenience functions def get_device(device_id: Union[int, str] = 0) -> torch.device: """ Get torch.device object for the specified device ID Args: device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string Returns: torch.device object """ return device_manager.to_torch_device(device_id) def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor: """ Move tensor to specified device Args: tensor: PyTorch tensor to move device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string Returns: Tensor moved to the specified device """ return device_manager.move_tensor_to_device(tensor, device_id) # Get devices cpu_device = get_device(-1) # CPU using index -1 cpu_device2 = get_device("cpu") # CPU using string "cpu" gpu0 = get_device(0) # First GPU # Move tensors x = torch.randn(3, 3) x_gpu = to_device(x, 0) # Move to first GPU x_cpu1 = to_device(x, -1) # Move to CPU using index -1 x_cpu2 = to_device(x, "cpu") # Move to CPU using string "cpu" ================================================ FILE: archive/ktransformers/util/weight_loader.py ================================================ from abc import ABC, abstractmethod import os import torch import numpy as np from safetensors import safe_open from typing import Dict, Any, Optional, Union class ModelLoader(ABC): """ Abstract base class for model loaders. Defines the interface that all model loaders must implement. """ @abstractmethod def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: """ Load a tensor by name. Args: name: Name of the tensor to load device: Device to load the tensor to Returns: The loaded tensor """ pass @classmethod @abstractmethod def supports_format(cls, path: str) -> bool: """ Check if this loader supports the given path format. Args: path: Path to check Returns: True if this loader supports the given path, False otherwise """ pass class SafeTensorLoader(ModelLoader): """ Loader for SafeTensor format models. """ def __init__(self, path: str): """ Initialize the SafeTensor loader. Args: path: Path to the model directory or file """ self.tensor_file_map = {} # Maps tensor names to file paths self.file_handle_map = {} # Maps file names to file handles self._load_tensor_file_map(path) def _load_tensor_file_map(self, path: str) -> None: """ Load the tensor file map from the given path. Args: path: Path to the model directory or file """ # Normalize path to directory if not os.path.exists(path): raise FileNotFoundError(f"Path not found: {path}") if os.path.isfile(path): folder_path = os.path.dirname(path) else: folder_path = path found_safetensor = False for root, _, files in os.walk(folder_path): files = sorted(files) for file in files: if file.endswith(".safetensors"): found_safetensor = True file_path = os.path.join(root, file) if file not in self.file_handle_map: try: handle = safe_open(file_path, framework="pt") self.file_handle_map[file] = handle except Exception as e: print(f"Error opening Safetensor file {file_path}: {e}") continue f = self.file_handle_map.get(file) if f is None: continue try: for key in f.keys(): self.tensor_file_map[key] = file except Exception as e: print(f"Error reading Safetensor file {file_path}: {e}") if not found_safetensor: # Not raising an error here allows for the factory to try other loaders print(f"No Safetensor files found in {folder_path}") def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: """ Load a tensor by name. Args: name: Name of the tensor to load device: Device to load the tensor to Returns: The loaded tensor """ if name not in self.tensor_file_map: raise KeyError(f"Key {name} not found in Safetensor files") file = self.tensor_file_map[name] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") tensor = f.get_tensor(name) return tensor.to(device) def load_dequantized_tensor(self, name: str, device: str = "cpu") -> torch.Tensor: """ Load and dequantize a tensor. Args: name: Name of the tensor to load device: Device to load the tensor to Returns: The dequantized tensor """ if name not in self.tensor_file_map: raise KeyError(f"Key {name} not found in Safetensor files") file = self.tensor_file_map[name] f = self.file_handle_map.get(file) if f is None: raise FileNotFoundError(f"File {file} not found in Safetensor files") tensor = f.get_tensor(name).to(device) if name.endswith(".weight"): if name[:-7] + ".weight_scale_inv" in self.tensor_file_map: weight_scale_inv = f.get_tensor(name[:-7] + ".weight_scale_inv").to(device) # Assuming weight_dequant function is imported from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant tensor = weight_dequant(tensor, weight_scale_inv) return tensor.to(device) def close_all_handles(self) -> None: """ Close all file handles. """ for handle in self.file_handle_map.values(): handle.close() self.file_handle_map.clear() @classmethod def supports_format(cls, path: str) -> bool: """ Check if this loader supports the given path format. Args: path: Path to check Returns: True if safetensor files are found in the path, False otherwise """ # Normalize path to directory if not os.path.exists(path): return False if os.path.isfile(path): if path.endswith(".safetensors"): return True folder_path = os.path.dirname(path) else: folder_path = path # Check if any safetensor files exist in the folder for root, _, files in os.walk(folder_path): for file in files: if file.endswith(".safetensors"): return True return False class GGUFLoader(ModelLoader): """ Loader for GGUF format models. """ def __init__(self, path: str): """ Initialize the GGUF loader. Args: path: Path to the model directory or file """ # Check if path exists if not os.path.exists(path): raise FileNotFoundError(f"GGUF dir not found: {path}") if os.path.isfile(path): self.gguf_path = os.path.dirname(path) else: self.gguf_path = path self.tensor_info = {} # Stores tensor metadata self.tensor_file_map = {} # Maps tensor names to file paths self.file_data_map = {} # Maps file paths to memory-mapped data self.gguf_file_meta = {} # Stores GGUF metadata # For compatibility with the factory pattern self.safetensor_loader = None # Scan all GGUF files in the directory found_gguf = False for root, _, files in os.walk(self.gguf_path): for file in files: if file.endswith(".gguf"): found_gguf = True file_path = os.path.join(root, file) with open(file_path, "rb") as f: self._load_gguf(f) if file_path not in self.file_data_map: self.file_data_map[file_path] = np.memmap(file_path, mode='r') if not found_gguf: raise FileNotFoundError(f"Cannot find any .gguf files in: {self.gguf_path}") def _load_gguf(self, f) -> None: """ Load GGUF file metadata and tensor info. Args: f: File handle of the GGUF file """ # Implementation should follow the original GGUFLoader._load_gguf # This is a simplified version for illustration f.seek(0) assert f.read(4) == b'GGUF' # Read header values = struct.unpack(" Any: """ Read a value from the file according to its data type. Args: f: File handle data_type: Type of data to read Returns: The read value """ # Simplified implementation # In a complete implementation, this would handle all data types if data_type == 8: # DATA_TYPES["string"] length = struct.unpack(" torch.Tensor: """ Load a tensor by name. Args: name: Name of the tensor to load device: Device to load the tensor to Returns: The loaded tensor """ # This should call load_gguf_tensor with the appropriate parameters return self.load_gguf_tensor(name, device) def load_gguf_tensor(self, name: str, device: str = "cpu", target_dtype = None) -> torch.Tensor: """ Load a GGUF tensor by name. Args: name: Name of the tensor to load device: Device to load the tensor to target_dtype: Target data type for the tensor Returns: The loaded tensor """ # Implementation would follow the original GGUFLoader.load_gguf_tensor # This is a placeholder for illustration if name not in self.tensor_info: raise KeyError(f"Tensor {name} not found") # Actual implementation would dequantize the tensor data # and return a torch.Tensor return torch.zeros(1, device=device) # Placeholder @classmethod def supports_format(cls, path: str) -> bool: """ Check if this loader supports the given path format. Args: path: Path to check Returns: True if GGUF files are found in the path, False otherwise """ # Normalize path to directory if not os.path.exists(path): return False if os.path.isfile(path): return path.endswith(".gguf") # Check if any GGUF files exist in the folder for root, _, files in os.walk(path): for file in files: if file.endswith(".gguf"): return True return False ================================================ FILE: archive/ktransformers/website/.browserslistrc ================================================ > 1% last 2 versions not dead not ie 11 ================================================ FILE: archive/ktransformers/website/.eslintrc.js ================================================ module.exports = { root: true, env: { node: true }, 'extends': [ 'plugin:vue/vue3-essential', 'eslint:recommended', '@vue/typescript/recommended' ], parserOptions: { ecmaVersion: 2020 }, rules: { 'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off', 'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off' }, overrides: [ { files: [ '**/__tests__/*.{j,t}s?(x)', '**/tests/unit/**/*.spec.{j,t}s?(x)' ], env: { jest: true } } ] } ================================================ FILE: archive/ktransformers/website/.gitignore ================================================ .DS_Store node_modules /dist # local env files .env.local .env.*.local # Log files npm-debug.log* yarn-debug.log* yarn-error.log* pnpm-debug.log* # Editor directories and files .idea .vscode *.suo *.ntvs* *.njsproj *.sln *.sw? ================================================ FILE: archive/ktransformers/website/README.md ================================================ # ## Project setup ``` npm install ``` ### Compiles and hot-reloads for development ``` npm run serve ``` ### Compiles and minifies for production ``` npm run build ``` ### Run your unit tests ``` npm run test:unit ``` ### Lints and fixes files ``` npm run lint ``` ### Customize configuration See [Configuration Reference](https://cli.vuejs.org/config/). ================================================ FILE: archive/ktransformers/website/config.d.ts ================================================ declare module '*.js' { const config: { apiUrl: string; port:number; }; export { config }; } ================================================ FILE: archive/ktransformers/website/jest.config.js ================================================ module.exports = { preset: '@vue/cli-plugin-unit-jest/presets/typescript' } ================================================ FILE: archive/ktransformers/website/package.json ================================================ { "name": "", "version": "", "private": true, "scripts": { "serve": "vue-cli-service serve", "build": "vue-cli-service build", "test:unit": "vue-cli-service test:unit", "lint": "vue-cli-service lint" }, "dependencies": { "@types/pdfjs-dist": "^2.10.378", "@types/websocket": "^1.0.10", "@vue/cli": "^5.0.8", "ant-design-vue": "^4.2.1", "apexcharts": "^3.49.1", "axios": "^1.7.0", "axios-extensions": "^3.1.6", "better-scroll": "^2.5.1", "element-plus": "^2.7.3", "marked": "^12.0.2", "marked-highlight": "^2.1.1", "pdf-lib": "^1.17.1", "pdfobject": "^2.3.0", "v-clipboard": "^3.0.0-next.1", "vue": "^3.4.27", "vue-i18n": "^9.13.1", "vue-pdf": "^4.3.0", "vue-router": "^4.0.3", "vue3-apexcharts": "^1.5.3", "vuex": "^4.0.0", "webpack": "^5.91.0", "webpack-cli": "^5.1.4", "websocket": "^1.0.35" }, "devDependencies": { "@types/jest": "^27.0.1", "@types/pdfobject": "^2.2.5", "@typescript-eslint/eslint-plugin": "^5.4.0", "@typescript-eslint/parser": "^5.4.0", "@vue/cli-plugin-eslint": "~5.0.0", "@vue/cli-plugin-router": "~5.0.0", "@vue/cli-plugin-typescript": "~5.0.0", "@vue/cli-plugin-unit-jest": "~5.0.0", "@vue/cli-plugin-vuex": "~5.0.0", "@vue/cli-service": "~5.0.0", "@vue/eslint-config-typescript": "^9.1.0", "@vue/test-utils": "^2.0.0-0", "@vue/vue3-jest": "^27.0.0-alpha.1", "babel-jest": "^27.0.6", "eslint": "^7.32.0", "eslint-plugin-vue": "^8.0.3", "jest": "^27.0.5", "stylus": "^0.55.0", "stylus-loader": "^6.1.0", "ts-jest": "^27.0.4", "typescript": "~4.5.5" }, "_id": "@", "readme": "ERROR: No README data found!" } ================================================ FILE: archive/ktransformers/website/public/config.js ================================================ window.configWeb = { apiUrl: 'http://119.255.238.12:15670/v1', port: 8080, }; ================================================ FILE: archive/ktransformers/website/public/css/reset.css ================================================ html, body, div, span, applet, object, iframe, h1, h2, h3, h4, h5, h6, p, blockquote, pre, a, abbr, acronym, address, big, cite, code, del, dfn, em, img, ins, kbd, q, s, samp, small, strike, strong, sub, sup, tt, var, b, u, i, center, dl, dt, dd, ol, ul, li, fieldset, form, label, legend,textarea, table, caption, tbody, tfoot, thead, tr, th, td, article, aside, canvas, details, embed, figure, figcaption, footer, header, hgroup, menu, nav, output, ruby, section, summary, time, mark, audio, video { margin: 0; padding: 0; border: 0; font-size: 100%; *font: inherit; font-family: Arial, Microsoft YaHei, SimHei, Tahoma, sans-serif !important; vertical-align: baseline; } /* HTML5 display-role reset for older browsers */ article, aside, details, figcaption, figure, footer, header, hgroup, menu, nav, section { display: block; } body { line-height: 1; -webkit-text-size-adjust: 100%!important; margin: 0; } html,body { height: 100%; width: 100%; overflow: hidden; } ol, ul { list-style: none; } blockquote, q { quotes: none; } blockquote:before, blockquote:after, q:before, q:after { content: ''; content: none; } table { border-collapse: collapse; border-spacing: 0; } .clearfix:before, .clearfix:after { content:""; display:table } .clearfix:after { clear:both } /*显示省略号*/ .ellipsis{ overflow: hidden; text-overflow: ellipsis; white-space: nowrap; } ================================================ FILE: archive/ktransformers/website/public/index.html ================================================ KTransformers
================================================ FILE: archive/ktransformers/website/src/App.vue ================================================ ================================================ FILE: archive/ktransformers/website/src/api/api-client.ts ================================================ import axios, { AxiosInstance } from 'axios'; import {baseURL} from '@/conf/config'; const apiClient: AxiosInstance = axios.create({ baseURL: baseURL, // baseURL: '/api', headers: { 'Content-Type': 'application/json', }, withCredentials: true, }); export default apiClient; ================================================ FILE: archive/ktransformers/website/src/api/assistant.ts ================================================ import apiClient from './api-client'; import { IAssistant,IDeleteResult, IAssistantWithStatus } from '../utils/types'; function filterAndConvert( assistantsWithStatus: IAssistantWithStatus[], statusCondition: string ): IAssistant[] { return assistantsWithStatus .filter((assistant) => assistant.build_status.status === statusCondition) .map(({ build_status, ...rest }) => rest); } interface IAssistantData { model: string; prefix_system_prompt?: string; suffix_system_prompt?: string; name?: string; description?: string; tools?: any[]; tool_resources?: object; metadata?:{[key:string]:any} top_p?: number; temperature?: number; response_format?: string; instructions?: string; } export const createAssistant = async (data: IAssistantData): Promise => { const assistant_data: { model: string; instructions?: string; name?: string; description?: string; tools?: any[]; tool_resources?: object; metadata?:{[key:string]:any} top_p?: number; temperature?: number; response_format?: string; } = { model: data.model }; if (data.prefix_system_prompt) { assistant_data.instructions = data.prefix_system_prompt; } if (data.suffix_system_prompt) { assistant_data.instructions = data.suffix_system_prompt; } if (data.name) { assistant_data.name = data.name; } if (data.description) { assistant_data.description = data.description; } if (data.tools) { assistant_data.tools = data.tools; } if (data.tool_resources) { assistant_data.tool_resources = data.tool_resources; } if (data.metadata) { assistant_data.metadata = data.metadata } if (typeof data.top_p !== 'undefined') { assistant_data.top_p = data.top_p; } if (typeof data.temperature !== 'undefined') { assistant_data.temperature = data.temperature; } if (data.response_format) { assistant_data.response_format = data.response_format; } if (data.instructions) { assistant_data.instructions = data.instructions; } console.log(assistant_data) const response = await apiClient.post( '/assistants/', assistant_data ); console.log("response", response) return response.data; }; export const listAssistants = async ( limit?: number, order?: string, after?: string, before?: string, run_id?: string, ): Promise => { const params: { limit?: number, order?: string, after?: string, before?: string, run_id?: string } = {}; if (typeof limit !== 'undefined') { params.limit = limit; } if (typeof order !== 'undefined') { params.order = order; } if (typeof after !== 'undefined') { params.after = after; } if (typeof before !== 'undefined') { params.before = before; } if (typeof run_id !== 'undefined') { params.run_id = run_id; } const response = await apiClient.get('/assistants/status', { params }); let tmp = response.data let result = [] as IAssistant[] const filteredAssistants = filterAndConvert(tmp, 'completed'); return filteredAssistants }; export const getAssistant = async ( assistant_id: string ): Promise => { const response = await apiClient.get(`/assistants/${assistant_id}`); return response.data; } export const deleteAssistant = async ( assistant_id: string ): Promise => { const response = await apiClient.delete(`/assistants/${assistant_id}`); return response.data; } export const getRelatedThreadId = async ( assistant_id: string ): Promise => { const response = await apiClient.get(`/assistants/${assistant_id}/related_thread`); return response.data; } export const listAssistantsWithStatus = async ( limit?: number, order?: string, after?: string, before?: string, run_id?: string, ): Promise => { const params: { limit?: number, order?: string, after?: string, before?: string, run_id?: string } = {}; if (typeof limit !== 'undefined') { params.limit = limit; } if (typeof order !== 'undefined') { params.order = order; } if (typeof after !== 'undefined') { params.after = after; } if (typeof before !== 'undefined') { params.before = before; } if (typeof run_id !== 'undefined') { params.run_id = run_id; } console.log(params) const response = await apiClient.get('/assistants/status', { params }); return response.data; }; ================================================ FILE: archive/ktransformers/website/src/api/message.ts ================================================ import apiClient from './api-client'; import { IMessage,IDeleteResult } from '../utils/types'; export const createMessage = async ( thread_id: string, content: string, role?: string, attachments?: any[], metadata?:{[key:string]:any} ): Promise => { const message_data: { content: string; role?: string; attachments?: any[]; metadata?:{[key:string]:any} } = { content, }; if (metadata) { message_data.metadata = metadata; } if (role) { message_data.role = role; } if (attachments) { message_data.attachments = attachments; } const response = await apiClient.post(`/threads/${thread_id}/messages`, message_data); return response.data; }; export const listMessages = async ( thread_id: string, limit?: number, order?: string, after?: string, before?: string, run_id?: string, ): Promise => { const params: { limit?: number, order?: string, after?: string, before?: string, run_id?: string } = {}; if (typeof limit !== 'undefined') { params.limit = limit; } if (typeof order !== 'undefined') { params.order = order; } if (typeof after !== 'undefined') { params.after = after; } if (typeof before !== 'undefined') { params.before = before; } if (typeof run_id !== 'undefined') { params.run_id = run_id; } const response = await apiClient.get(`/threads/${thread_id}/messages`, { params }); return response.data; }; export const deleteMessage = async(thread_id:string, message_id:string): Promise => { const response = await apiClient.delete(`/threads/${thread_id}/messages/${message_id}`); return response.data; } ================================================ FILE: archive/ktransformers/website/src/api/run.ts ================================================ import apiClient from './api-client'; import { IRun } from '../utils/types'; import {baseURL} from '@/conf/config'; interface IRunData { assistant_id: string; model?: string; instructions?: string; additional_instructions?: string; additional_messages?: any[]; tools?: any[]; metadata?: { [key: string]: any } temperature?: number; top_p?: number; stream?: boolean; max_prompt_tokens?: number; max_completion_tokens?: number; truncation_strategy?: object; tool_choice?: string; response_format?: string | object; } export async function* createRun( data: IRunData, thread_id: string ): AsyncGenerator { const run_data = { ...data, assistant_id: data.assistant_id, }; const response = await fetch(`${baseURL}/threads/${thread_id}/runs`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify(run_data), }); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } if (!response.body) { throw new Error('Response body is missing'); } const reader = response.body.getReader(); const decoder = new TextDecoder(); let buffer = ''; try { while (true) { const { done, value } = await reader.read(); if (done) return; buffer += decoder.decode(value, { stream: true }); let eventIndex = buffer.indexOf("\n\n"); while (eventIndex !== -1) { const event = buffer.slice(0, eventIndex); buffer = buffer.slice(eventIndex + 2); if (event.startsWith("event: thread.run.created")) { const dataIndex = event.indexOf("data: "); if (dataIndex !== -1) { const datads = event.slice(39, 75) yield datads; } } else if (event.startsWith("event: thread.message.delta")) { const dataIndex = event.indexOf("data: "); if (dataIndex !== -1) { const data = JSON.parse(event.slice(dataIndex + 6)); yield data.delta.content[0].text.value || ''; } } else if (event.startsWith("event: done")) { return; } eventIndex = buffer.indexOf("\n\n"); } } } catch (e) { console.error('An error occurred while reading the response stream:', e); // throw e; return e } } // 定义取消运行的函数 export async function cancelRun(threadId: string, runId: string){ const run_data = { thread_id:threadId, run_id:runId, }; try { const response = await fetch(`${baseURL}/threads/${threadId}/runs/${runId}/cancel`, { method: 'POST', }); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } return response; } catch (error) { console.error('An error occurred while cancelling the run:', error); throw error; } } ================================================ FILE: archive/ktransformers/website/src/api/thread.ts ================================================ import apiClient from './api-client'; import { IThread, IMessage, IThreadAndMessageAndAssistant, IDeleteResult } from '../utils/types'; export const createThread = async ( message?: IMessage, tool_resources?: object, metadata?: { [key: string]: any } ): Promise => { const thread_data: { message?: object, metadata?: { [key: string]: any } } = {}; if (message) { thread_data.message = message; } if (metadata) { thread_data.metadata = metadata; } const response = await apiClient.post( '/threads', thread_data); return response.data; }; export const listThreads = async ( limit?: number, order?: string, ): Promise => { const params: { limit?: number, order?: string, } = { limit, order }; const response = await apiClient.get('/threads', { params }); return response.data; }; export const deleteThread = async ( thread_id: string ): Promise => { const response = await apiClient.delete(`/threads/${thread_id}`); return response.data; } export const getThread = async ( thread_id: string ): Promise => { const response = await apiClient.get(`/threads/${thread_id}`); return response.data; } ================================================ FILE: archive/ktransformers/website/src/assets/css/mixins.styl ================================================ /*Define color variables*/ $bg_gray_light_normal = #F9F9F9 $bg_gray_light_hover = #E8E8E8 $bg_gray_light_active = #E8E8E8 $border_gray_light_normal = rgba(0, 0, 0, .15) $border_gray_light_hover = #8080FF $gray_20 = #333333 $gray_40 = #585858 $gray_50 = #7F7F7F $gray_60 = #9F9F9F $gray_70 = #BFBFBF $gray_80 = #DFDFDF $gray_85 = #F2F2F2 $gray_90 = #F7F7F7 $gray = #53525B $gray_dark = #42414a $gray_hover = #121212 $gray_action = #6C757D $primary = #409eff $primary_hover = #428bca $primary_middle = #9DDDF9 $primary_light = #D4F0FC $cyan = #66CCCC $cyan_hover = #46C2C2 /*Define common modules*/ $input-duration = .25s input-border() -webkit-transition: border-color ease-in-out $input-duration,-webkit-box-shadow ease-in-out $input-duration -o-transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration input-focus() border-color: #66afe9 outline: 0 z-index: 100 -webkit-box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6) box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6) /*Define common class*/ .flex-column display: -webkit-box display: -webkit-flex display: flex box-sizing: border-box -webkit-box-orient: vertical -webkit-box-direction: normal -webkit-flex-direction: column flex-direction: column height: 100% .flex-row position: relative display: -webkit-box display: -ms-flexbox display: flex box-sizing: border-box -webkit-box-align: center -ms-flex-align: center align-items: center .flex-unit -webkit-box-flex: 1 -ms-flex: 1 flex: 1 // overflow: hidden .clearfix &:after clear: both content: "\20" display: block height: 0 visibility: hidden a,a:hover text-decoration:none button:focus outline: none .btn display: inline-block margin-bottom: 0 padding:0px 15px font-size: 14px height: 34px line-height: 32px float: left /*去掉inline-block之间的空格*/ font-weight: normal text-align: center white-space: nowrap vertical-align: middle cursor: pointer background-image: none border-radius: 3px -webkit-user-select: none -moz-user-select: none -ms-user-select: none -o-user-select: none user-select: none &:hover .dropdown-list display: block i font-size: 16px .text float: right margin-left: 3px .btn-gray color: $gray_action background-color: #FFFFFF border: 1px solid $gray_action &:not(.is-disabled):hover color: #FFFFFF background-color: $gray_action border: 1px solid $gray_action .btn-primary color: #FFFFFF background-color: $primary border: 1px solid $primary &:not(.is-disabled):hover color: #FFFFFF background-color: $primary_hover border: 1px solid $primary_hover .chat-box position: relative .chat-input border: 1px solid $border_gray_light_normal height: 48px line-height: 48px font-size: 16px outline: 0 box-sizing: border-box padding:0 30px0 20px color: #7F7F7F width: 800px border-radius: 12px position: relative &:focus input-focus() i position: absolute font-size: 26px right: 13px bottom:0px color: $border_gray_light_normal z-index: 100 cursor: pointer &:hover color: $border_gray_light_hover ================================================ FILE: archive/ktransformers/website/src/assets/iconfont/demo.css ================================================ /* Logo 字体 */ @font-face { font-family: "iconfont logo"; src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834'); src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834#iefix') format('embedded-opentype'), url('https://at.alicdn.com/t/font_985780_km7mi63cihi.woff?t=1545807318834') format('woff'), url('https://at.alicdn.com/t/font_985780_km7mi63cihi.ttf?t=1545807318834') format('truetype'), url('https://at.alicdn.com/t/font_985780_km7mi63cihi.svg?t=1545807318834#iconfont') format('svg'); } .logo { font-family: "iconfont logo"; font-size: 160px; font-style: normal; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } /* tabs */ .nav-tabs { position: relative; } .nav-tabs .nav-more { position: absolute; right: 0; bottom: 0; height: 42px; line-height: 42px; color: #666; } #tabs { border-bottom: 1px solid #eee; } #tabs li { cursor: pointer; width: 100px; height: 40px; line-height: 40px; text-align: center; font-size: 16px; border-bottom: 2px solid transparent; position: relative; z-index: 1; margin-bottom: -1px; color: #666; } #tabs .active { border-bottom-color: #f00; color: #222; } .tab-container .content { display: none; } /* 页面布局 */ .main { padding: 30px 100px; width: 960px; margin: 0 auto; } .main .logo { color: #333; text-align: left; margin-bottom: 30px; line-height: 1; height: 110px; margin-top: -50px; overflow: hidden; *zoom: 1; } .main .logo a { font-size: 160px; color: #333; } .helps { margin-top: 40px; } .helps pre { padding: 20px; margin: 10px 0; border: solid 1px #e7e1cd; background-color: #fffdef; overflow: auto; } .icon_lists { width: 100% !important; overflow: hidden; *zoom: 1; } .icon_lists li { width: 100px; margin-bottom: 10px; margin-right: 20px; text-align: center; list-style: none !important; cursor: default; } .icon_lists li .code-name { line-height: 1.2; } .icon_lists .icon { display: block; height: 100px; line-height: 100px; font-size: 42px; margin: 10px auto; color: #333; -webkit-transition: font-size 0.25s linear, width 0.25s linear; -moz-transition: font-size 0.25s linear, width 0.25s linear; transition: font-size 0.25s linear, width 0.25s linear; } .icon_lists .icon:hover { font-size: 100px; } .icon_lists .svg-icon { /* 通过设置 font-size 来改变图标大小 */ width: 1em; /* 图标和文字相邻时,垂直对齐 */ vertical-align: -0.15em; /* 通过设置 color 来改变 SVG 的颜色/fill */ fill: currentColor; /* path 和 stroke 溢出 viewBox 部分在 IE 下会显示 normalize.css 中也包含这行 */ overflow: hidden; } .icon_lists li .name, .icon_lists li .code-name { color: #666; } /* markdown 样式 */ .markdown { color: #666; font-size: 14px; line-height: 1.8; } .highlight { line-height: 1.5; } .markdown img { vertical-align: middle; max-width: 100%; } .markdown h1 { color: #404040; font-weight: 500; line-height: 40px; margin-bottom: 24px; } .markdown h2, .markdown h3, .markdown h4, .markdown h5, .markdown h6 { color: #404040; margin: 1.6em 0 0.6em 0; font-weight: 500; clear: both; } .markdown h1 { font-size: 28px; } .markdown h2 { font-size: 22px; } .markdown h3 { font-size: 16px; } .markdown h4 { font-size: 14px; } .markdown h5 { font-size: 12px; } .markdown h6 { font-size: 12px; } .markdown hr { height: 1px; border: 0; background: #e9e9e9; margin: 16px 0; clear: both; } .markdown p { margin: 1em 0; } .markdown>p, .markdown>blockquote, .markdown>.highlight, .markdown>ol, .markdown>ul { width: 80%; } .markdown ul>li { list-style: circle; } .markdown>ul li, .markdown blockquote ul>li { margin-left: 20px; padding-left: 4px; } .markdown>ul li p, .markdown>ol li p { margin: 0.6em 0; } .markdown ol>li { list-style: decimal; } .markdown>ol li, .markdown blockquote ol>li { margin-left: 20px; padding-left: 4px; } .markdown code { margin: 0 3px; padding: 0 5px; background: #eee; border-radius: 3px; } .markdown strong, .markdown b { font-weight: 600; } .markdown>table { border-collapse: collapse; border-spacing:0; empty-cells: show; border: 1px solid #e9e9e9; width: 95%; margin-bottom: 24px; } .markdown>table th { white-space: nowrap; color: #333; font-weight: 600; } .markdown>table th, .markdown>table td { border: 1px solid #e9e9e9; padding: 8px 16px; text-align: left; } .markdown>table th { background: #F7F7F7; } .markdown blockquote { font-size: 90%; color: #999; border-left: 4px solid #e9e9e9; padding-left: 0.8em; margin: 1em 0; } .markdown blockquote p { margin: 0; } .markdown .anchor { opacity: 0; transition: opacity 0.3s ease; margin-left: 8px; } .markdown .waiting { color: #ccc; } .markdown h1:hover .anchor, .markdown h2:hover .anchor, .markdown h3:hover .anchor, .markdown h4:hover .anchor, .markdown h5:hover .anchor, .markdown h6:hover .anchor { opacity: 1; display: inline-block; } .markdown>br, .markdown>p>br { clear: both; } .hljs { display: block; background: white; padding: 0.5em; color: #333333; overflow-x: auto; } .hljs-comment, .hljs-meta { color: #969896; } .hljs-string, .hljs-variable, .hljs-template-variable, .hljs-strong, .hljs-emphasis, .hljs-quote { color: #df5000; } .hljs-keyword, .hljs-selector-tag, .hljs-type { color: #a71d5d; } .hljs-literal, .hljs-symbol, .hljs-bullet, .hljs-attribute { color: #0086b3; } .hljs-section, .hljs-name { color: #63a35c; } .hljs-tag { color: #333333; } .hljs-title, .hljs-attr, .hljs-selector-id, .hljs-selector-class, .hljs-selector-attr, .hljs-selector-pseudo { color: #795da3; } .hljs-addition { color: #55a532; background-color: #eaffea; } .hljs-deletion { color: #bd2c00; background-color: #ffecec; } .hljs-link { text-decoration: underline; } /* 代码高亮 */ /* PrismJS 1.15.0 https://prismjs.com/download.html#themes=prism&languages=markup+css+clike+javascript */ /** * prism.js default theme for JavaScript, CSS and HTML * Based on dabblet (http://dabblet.com) * @author Lea Verou */ code[class*="language-"], pre[class*="language-"] { color: black; background: none; text-shadow: 0 1px white; font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace; text-align: left; white-space: pre; word-spacing: normal; word-break: normal; word-wrap: normal; line-height: 1.5; -moz-tab-size: 4; -o-tab-size: 4; tab-size: 4; -webkit-hyphens: none; -moz-hyphens: none; -ms-hyphens: none; hyphens: none; } pre[class*="language-"]::-moz-selection, pre[class*="language-"] ::-moz-selection, code[class*="language-"]::-moz-selection, code[class*="language-"] ::-moz-selection { text-shadow: none; background: #b3d4fc; } pre[class*="language-"]::selection, pre[class*="language-"] ::selection, code[class*="language-"]::selection, code[class*="language-"] ::selection { text-shadow: none; background: #b3d4fc; } @media print { code[class*="language-"], pre[class*="language-"] { text-shadow: none; } } /* Code blocks */ pre[class*="language-"] { padding: 1em; margin: .5em 0; overflow: auto; } :not(pre)>code[class*="language-"], pre[class*="language-"] { background: #f5f2f0; } /* Inline code */ :not(pre)>code[class*="language-"] { padding: .1em; border-radius: .3em; white-space: normal; } .token.comment, .token.prolog, .token.doctype, .token.cdata { color: slategray; } .token.punctuation { color: #999; } .namespace { opacity: .7; } .token.property, .token.tag, .token.boolean, .token.number, .token.constant, .token.symbol, .token.deleted { color: #905; } .token.selector, .token.attr-name, .token.string, .token.char, .token.builtin, .token.inserted { color: #690; } .token.operator, .token.entity, .token.url, .language-css .token.string, .style .token.string { color: #9a6e3a; background: hsla(0, 0%, 100%, .5); } .token.atrule, .token.attr-value, .token.keyword { color: #07a; } .token.function, .token.class-name { color: #DD4A68; } .token.regex, .token.important, .token.variable { color: #e90; } .token.important, .token.bold { font-weight: bold; } .token.italic { font-style: italic; } .token.entity { cursor: help; } ================================================ FILE: archive/ktransformers/website/src/assets/iconfont/demo_index.html ================================================ iconfont Demo

  • 复制
    &#xe8b0;
  • 箭头下
    &#xe85e;
  • 进度
    &#xe651;
  • 环形进度条
    &#xe617;
  • 向左1
    &#xe779;
  • &#xe608;
  • 编辑
    &#xe7dd;
  • 删除
    &#xe614;
  • 上传
    &#xe618;
  • 探索-选中
    &#xe621;
  • ellipsis
    &#xe657;
  • 发送
    &#xe60c;
  • 列表
    &#xe62d;
  • 列表
    &#xe639;
  • 重试
    &#xe6bd;
  • Fork 记录
    &#xe826;

Unicode 引用


Unicode 是字体在网页端最原始的应用方式,特点是:

  • 支持按字体的方式去动态调整图标大小,颜色等等。
  • 默认情况下不支持多色,直接添加多色图标会自动去色。

注意:新版 iconfont 支持两种方式引用多色图标:SVG symbol 引用方式和彩色字体图标模式。(使用彩色字体图标需要在「编辑项目」中开启「彩色」选项后并重新生成。)

Unicode 使用步骤如下:

第一步:拷贝项目下面生成的 @font-face

@font-face {
  font-family: 'iconfont';
  src: url('iconfont.woff2?t=1717950820214') format('woff2'),
       url('iconfont.woff?t=1717950820214') format('woff'),
       url('iconfont.ttf?t=1717950820214') format('truetype'),
       url('iconfont.svg?t=1717950820214#iconfont') format('svg');
}

第二步:定义使用 iconfont 的样式

.iconfont {
  font-family: "iconfont" !important;
  font-size: 16px;
  font-style: normal;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}

第三步:挑选相应图标并获取字体编码,应用于页面

<span class="iconfont">&#x33;</span>

"iconfont" 是你项目下的 font-family。可以通过编辑项目查看,默认是 "iconfont"。

  • 复制
    .icon-copy
  • 箭头下
    .icon-arrow-down
  • 进度
    .icon-usage-progress
  • 环形进度条
    .icon-gen-progress
  • 向左1
    .icon-back
  • .icon-point
  • 编辑
    .icon-edit
  • 删除
    .icon-delete
  • 上传
    .icon-upload-1
  • 探索-选中
    .icon-explore
  • ellipsis
    .icon-ellipsis
  • 发送
    .icon-sent
  • 列表
    .icon-list-list
  • 列表
    .icon-list-icon
  • 重试
    .icon-zhongshi
  • Fork 记录
    .icon-log

font-class 引用


font-class 是 Unicode 使用方式的一种变种,主要是解决 Unicode 书写不直观,语意不明确的问题。

与 Unicode 使用方式相比,具有如下特点:

  • 相比于 Unicode 语意明确,书写更直观。可以很容易分辨这个 icon 是什么。
  • 因为使用 class 来定义图标,所以当要替换图标时,只需要修改 class 里面的 Unicode 引用。

使用步骤如下:

第一步:引入项目下面生成的 fontclass 代码:

<link rel="stylesheet" href="./iconfont.css">

第二步:挑选相应图标并获取类名,应用于页面:

<span class="iconfont icon-xxx"></span>

" iconfont" 是你项目下的 font-family。可以通过编辑项目查看,默认是 "iconfont"。

  • 复制
    #icon-copy
  • 箭头下
    #icon-arrow-down
  • 进度
    #icon-usage-progress
  • 环形进度条
    #icon-gen-progress
  • 向左1
    #icon-back
  • #icon-point
  • 编辑
    #icon-edit
  • 删除
    #icon-delete
  • 上传
    #icon-upload-1
  • 探索-选中
    #icon-explore
  • ellipsis
    #icon-ellipsis
  • 发送
    #icon-sent
  • 列表
    #icon-list-list
  • 列表
    #icon-list-icon
  • 重试
    #icon-zhongshi
  • Fork 记录
    #icon-log

Symbol 引用


这是一种全新的使用方式,应该说这才是未来的主流,也是平台目前推荐的用法。相关介绍可以参考这篇文章 这种用法其实是做了一个 SVG 的集合,与另外两种相比具有如下特点:

  • 支持多色图标了,不再受单色限制。
  • 通过一些技巧,支持像字体那样,通过 font-size, color 来调整样式。
  • 兼容性较差,支持 IE9+,及现代浏览器。
  • 浏览器渲染 SVG 的性能一般,还不如 png。

使用步骤如下:

第一步:引入项目下面生成的 symbol 代码:

<script src="./iconfont.js"></script>

第二步:加入通用 CSS 代码(引入一次就行):

<style>
.icon {
  width: 1em;
  height: 1em;
  vertical-align: -0.15em;
  fill: currentColor;
  overflow: hidden;
}
</style>

第三步:挑选相应图标并获取类名,应用于页面:

<svg class="icon" aria-hidden="true">
  <use xlink:href="#icon-xxx"></use>
</svg>
================================================ FILE: archive/ktransformers/website/src/assets/iconfont/iconfont.css ================================================ @font-face { font-family: "iconfont"; /* Project id 4550268 */ src: url('iconfont.woff2?t=1717950820214') format('woff2'), url('iconfont.woff?t=1717950820214') format('woff'), url('iconfont.ttf?t=1717950820214') format('truetype'), url('iconfont.svg?t=1717950820214#iconfont') format('svg'); } .iconfont { font-family: "iconfont" !important; font-size: 16px; font-style: normal; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } .icon-copy:before { content: "\e8b0"; } .icon-arrow-down:before { content: "\e85e"; } .icon-usage-progress:before { content: "\e651"; } .icon-gen-progress:before { content: "\e617"; } .icon-back:before { content: "\e779"; } .icon-point:before { content: "\e608"; } .icon-edit:before { content: "\e7dd"; } .icon-delete:before { content: "\e614"; } .icon-upload-1:before { content: "\e618"; } .icon-explore:before { content: "\e621"; } .icon-ellipsis:before { content: "\e657"; } .icon-sent:before { content: "\e60c"; } .icon-list-list:before { content: "\e62d"; } .icon-list-icon:before { content: "\e639"; } .icon-zhongshi:before { content: "\e6bd"; } .icon-log:before { content: "\e826"; } ================================================ FILE: archive/ktransformers/website/src/assets/iconfont/iconfont.js ================================================ window._iconfont_svg_string_4550268='',function(l){var t=(t=document.getElementsByTagName("script"))[t.length-1],c=t.getAttribute("data-injectcss"),t=t.getAttribute("data-disable-injectsvg");if(!t){var i,o,e,a,h,n=function(t,c){c.parentNode.insertBefore(t,c)};if(c&&!l.__iconfont__svg__cssinject__){l.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(t){console&&console.log(t)}}i=function(){var t,c=document.createElement("div");c.innerHTML=l._iconfont_svg_string_4550268,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(t=document.body).firstChild?n(c,t.firstChild):t.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(i,0):(o=function(){document.removeEventListener("DOMContentLoaded",o,!1),i()},document.addEventListener("DOMContentLoaded",o,!1)):document.attachEvent&&(e=i,a=l.document,h=!1,d(),a.onreadystatechange=function(){"complete"==a.readyState&&(a.onreadystatechange=null,s())})}function s(){h||(h=!0,e())}function d(){try{a.documentElement.doScroll("left")}catch(t){return void setTimeout(d,50)}s()}}(window); ================================================ FILE: archive/ktransformers/website/src/assets/iconfont/iconfont.json ================================================ { "id": "4550268", "name": "Lexllama", "font_family": "iconfont", "css_prefix_text": "icon-", "description": "Lexllama开源项目使用", "glyphs": [ { "icon_id": "11372665", "name": "复制", "font_class": "copy", "unicode": "e8b0", "unicode_decimal": 59568 }, { "icon_id": "34202237", "name": "箭头下", "font_class": "arrow-down", "unicode": "e85e", "unicode_decimal": 59486 }, { "icon_id": "7766233", "name": "进度", "font_class": "usage-progress", "unicode": "e651", "unicode_decimal": 58961 }, { "icon_id": "38865122", "name": "环形进度条", "font_class": "gen-progress", "unicode": "e617", "unicode_decimal": 58903 }, { "icon_id": "577406", "name": "向左1", "font_class": "back", "unicode": "e779", "unicode_decimal": 59257 }, { "icon_id": "1920286", "name": "点", "font_class": "point", "unicode": "e608", "unicode_decimal": 58888 }, { "icon_id": "8866967", "name": "编辑", "font_class": "edit", "unicode": "e7dd", "unicode_decimal": 59357 }, { "icon_id": "10199175", "name": "删除", "font_class": "delete", "unicode": "e614", "unicode_decimal": 58900 }, { "icon_id": "1010111", "name": "上传", "font_class": "upload-1", "unicode": "e618", "unicode_decimal": 58904 }, { "icon_id": "351773", "name": "探索-选中", "font_class": "explore", "unicode": "e621", "unicode_decimal": 58913 }, { "icon_id": "564941", "name": "ellipsis", "font_class": "ellipsis", "unicode": "e657", "unicode_decimal": 58967 }, { "icon_id": "1048859", "name": "发送", "font_class": "sent", "unicode": "e60c", "unicode_decimal": 58892 }, { "icon_id": "1304951", "name": "列表", "font_class": "list-list", "unicode": "e62d", "unicode_decimal": 58925 }, { "icon_id": "8676284", "name": "列表", "font_class": "list-icon", "unicode": "e639", "unicode_decimal": 58937 }, { "icon_id": "22290034", "name": "重试", "font_class": "zhongshi", "unicode": "e6bd", "unicode_decimal": 59069 }, { "icon_id": "22961085", "name": "Fork 记录", "font_class": "log", "unicode": "e826", "unicode_decimal": 59430 } ] } ================================================ FILE: archive/ktransformers/website/src/components/chat/index.vue ================================================ ================================================ FILE: archive/ktransformers/website/src/conf/config.ts ================================================ declare global { interface Window { configWeb: { apiUrl: string; port: string; }; } } export const baseURL = window.configWeb.apiUrl; export const basePort = window.configWeb.port; ================================================ FILE: archive/ktransformers/website/src/locals/en.js ================================================ // en.js export default { home: { explore: 'Explore', language: 'Choose Language', english: 'English', chinese: 'Chinese', today: 'Today', previous:'Previous', withoutAssistantTip:'The KTransformers of this record has been deleted. The user can only view historical conversation information and cannot continue the conversation!', deleteThreadTip:'Deleting records will clear historical information~' }, chat:{ inputTip:"Send a message and chat with the KTransformers ~", }, explore:{ description: "Based on Lexllama, let’s create your own KTransformers~", configuring: "Configuring", completed: "Completed", assistantName: "Name", assistantDescription: "Description", assistantStatus: "Status", createAssistant: "Create New KTransformers", deleteAssistant: "Are you sure to delete this? After deleting the KTransformers, its KVCache will also be cleared simultaneously~", }, config:{ title:'Configure your KTransformers', fileTip:"Only support text, docx, .ppt, .pdf format.", reConfigTip:'Reconfig KTransformers needs to delete kvcache, please choose carefully', secletFile:'Select Files', outOfSize:'File size exceeds 10MB, please reselect', fileExist:'The file already exists, please reselect', createAssistant:'Assistant created successfully, click the build button to start building KVCache', }, build:{ title:'Building Logs', step1:'Parse uploded files', parsingFileStep1:'File upload and reception completed', parsingFileStep2:{ parse:"Parsing", file:"file(s)", total:'total', }, parsingFileStep3:'Prompt loaded, ready to generate KVCache', step2:'Generate KVCache', generateStep1:'Generate KVCache calculation plan', generateStep2:{ calculate:"calculating", token:"tokens", total:'total', }, generateStep3:'KVCache has been generated successfully', durationTime:'Duration:', remainTime:'Time left:', buildProgress:'Building Progress', storageUsage:'KVCache Storage Usage', } } ================================================ FILE: archive/ktransformers/website/src/locals/index.js ================================================ // index.js import { createI18n } from 'vue-i18n' import zh from './zh' import en from './en' const messages = { en, zh, } const language = (navigator.language || 'en').toLocaleLowerCase() // 这是获取浏览器的语言 const i18n = createI18n({ legacy: false, // you must set `false`, to use Compostion API locale: localStorage.getItem('lang') || language.split('-')[0] || 'en', // 首先从缓存里拿,没有的话就用浏览器语言, fallbackLocale: 'en', // 设置备用语言 messages, }) export default i18n ================================================ FILE: archive/ktransformers/website/src/locals/zh.js ================================================ // zh.js export default { home: { explore: '探索', language: '选择语言', english: '英语', chinese: '中文', today: '今天', previous:'历史', withoutAssistantTip:'本记录的KTransformers已被删除,用户只能查看历史对话信息而无法继续对话!', deleteThreadTip:'删除记录会清除历史信息哦~' }, chat:{ inputTip:"发送信息和 KTransformers 畅聊吧~", }, explore:{ description: "基于Lexllama,一起来创建你的专属KTransformers吧~", configuring: "配置中", completed: "完成", assistantName: "名称", assistantDescription: "描述", assistantStatus: "Status", createAssistant: "创建新的KTransformers", deleteAssistant: "是否确认删除KTransformers,删除KTransformers之后其KVCache也会被同步清理掉哦~", }, config:{ title:'配置你的KTransformers', fileTip:"仅支持上传文件格式为 .text, docx, .ppt, .pdf format.", secletFile:'选择文件', outOfSize:'文件大小超出10MB,请重新选择', fileExist:'文件已存在,请重新选择', createAssistant:'KTransformers创建成功,点击build按钮开始构建KVCache', }, build:{ title:'构建日志', step1:'解析上传文件', parsingFileStep1:'文件上传接收完成', parsingFileStep2:{ parse:"正在解析第", file:"文件", total:'共', }, parsingFileStep3:'Prompt装载完毕,准备生成KVCache', step2:'生成 KVCache', generateStep1:'生成KVCache计算计划', generateStep2:{ calculate:"正在计算", token:"tokens", total:'共', }, generateStep3:'KVCache已生成完成', durationTime:'持续时间:', remainTime:'剩余时间:', buildProgress:'构建进度', storageUsage:'存储使用:', } } ================================================ FILE: archive/ktransformers/website/src/main.ts ================================================ import { createApp } from 'vue' import App from './App.vue' import router from './router' import store from './store' import ElementPlus from 'element-plus' import 'element-plus/dist/index.css' import VueApexCharts from "vue3-apexcharts" import i18n from '@/locals' const app = createApp(App) app.use(ElementPlus) app.use(i18n) app.use(VueApexCharts) app.use(store) app.use(router) app.mount('#app') ================================================ FILE: archive/ktransformers/website/src/router/index.ts ================================================ import { createRouter, createWebHashHistory, RouteRecordRaw, createWebHistory } from 'vue-router' import HomeView from '@/views/home.vue' const routes: Array = [ { path: '/', name: 'home', component: HomeView, redirect: '/chat', children: [{ path: '/chat', name: '', component: () => import(/* webpackChunkName: "about" */ '../components/chat/index.vue') },] }, ] const router = createRouter({ history: createWebHashHistory(), routes }) export default router ================================================ FILE: archive/ktransformers/website/src/shims-vue.d.ts ================================================ /* eslint-disable */ declare module '*.vue' { import type { DefineComponent } from 'vue' const component: DefineComponent<{}, {}, any> export default component } declare module '@/locals' declare module 'pdfobject'; ================================================ FILE: archive/ktransformers/website/src/store/index.ts ================================================ import { createStore } from 'vuex' export default createStore({ state: { }, getters: { }, mutations: { }, actions: { }, modules: { } }) ================================================ FILE: archive/ktransformers/website/src/utils/copy.ts ================================================ import { ElMessage } from "element-plus"; const copy = (value: string) => { //Try using the navigator.clipboard.writeText method if (navigator.clipboard && window.isSecureContext) { navigator.clipboard.writeText(value) .then(() => { //Using ElMessage to Display Success Messages in Windows Systems if (navigator.appVersion.includes("Win")) { ElMessage({ message: "内容复制成功!", type: "success", plain: true, }); } else { //Using custom DOM elements to display success messages in macOS system showCopySuccessMessage(); } }) .catch(() => { //Using ElMessage to Display Failure Messages in Windows Systems if (navigator.appVersion.includes("Win")) { ElMessage({ message: "内容复制失败!", type: "error", plain: true, }); } else { //Using custom DOM elements to display failure messages in macOS system showCopyErrorMessage(); } }); } else { const textarea = document.createElement("textarea"); textarea.value = value; document.body.appendChild(textarea); textarea.select(); try { const successful = document.execCommand('copy'); if (successful) { if (navigator.appVersion.includes("Win")) { ElMessage({ message: "内容复制成功!", type: "success", plain: true, }); } else { showCopySuccessMessage(); } } else { if (navigator.appVersion.includes("Win")) { ElMessage({ message: "内容复制失败!", type: "error", plain: true, }); } else { showCopyErrorMessage(); } } } catch (err) { if (navigator.appVersion.includes("Win")) { ElMessage({ message: "内容复制失败!", type: "error", plain: true, }); } else { showCopyErrorMessage(); } } document.body.removeChild(textarea); } }; function showCopySuccessMessage() { const messageElement = document.createElement('div'); messageElement.textContent = '内容复制成功!'; messageElement.style.position = 'fixed'; messageElement.style.bottom = '10px'; messageElement.style.left = '50%'; messageElement.style.transform = 'translateX(-50%)'; messageElement.style.padding = '10px'; messageElement.style.backgroundColor = '#4CAF50'; messageElement.style.color = 'white'; messageElement.style.borderRadius = '15px'; messageElement.style.zIndex = '1000'; document.body.appendChild(messageElement); setTimeout(() => { document.body.removeChild(messageElement); }, 3000); } function showCopyErrorMessage() { const messageElement = document.createElement('div'); messageElement.textContent = '内容复制失败!'; messageElement.style.position = 'fixed'; messageElement.style.bottom = '10px'; messageElement.style.left = '50%'; messageElement.style.transform = 'translateX(-50%)'; messageElement.style.padding = '10px'; messageElement.style.backgroundColor = '#F44336'; messageElement.style.color = 'white'; messageElement.style.borderRadius = '5px'; messageElement.style.zIndex = '1000'; document.body.appendChild(messageElement); setTimeout(() => { document.body.removeChild(messageElement); }, 3000); } export default copy; ================================================ FILE: archive/ktransformers/website/src/utils/types.ts ================================================ export interface IAssistant { id: string; object: string; created_at: number; name?: string; description?: string; model: string; instructions?: string; tools: any[]; tool_resources?: object; metadata?:{[key:string]:any} top_p?: number; temperature?: number; response_format: string | object; } export interface IAssistantWithStatus { build_status:{status:string} id: string; object: string; created_at: number; name?: string; description?: string; model: string; instructions?: string; tools: any[]; tool_resources?: object; metadata?:{[key:string]:any} top_p?: number; temperature?: number; response_format: string | object; } export interface IMessage { id: string; object: string; created_at: number; thread_id: string; status: string; incomplete_details?: object; completed_at?: number; incomplete_at?: number; role: string; content: any[]; assistant_id?: string; run_id?: string; attachments?: any[]; metadata:{[key:string]:any} } export interface IThread { id: string; object: string; created_at: number; tool_resources?: object; metadata?:{[key:string]:any} } export interface IRun { id: string; object: string; created_at: number; thread_id: string, assistant_id: string, status: string, required_action?: object, last_error?: object, expires_at?: number, started_at?: number, cancelled_at?: number, failed_at?: number, completed_at?: number, incomplete_details?: object, model: string, instructions: string, tools: any[], metadata: Map, usage?: object, temperature?: number, top_p?: number, max_prompt_tokens?: number, max_completion_tokens?: number, truncation_strategy: object, tool_choice: string | object, response_format: string | object, } export interface IFile { id: string, bytes: number, created_at: number, filename: string, object: string, purpose: string, } export interface IMessageData { role: string; content: any[]; created_at?: number; assistant_id?: string, } export interface IThreadAndMessageAndAssistant { thread: IThread; first_message: IMessage; assistant: IAssistantWithStatus } export interface IDeleteResult { id: string; object: string; deleted: boolean; } export interface IBuildData { parsed_file_count:number; total_file_count:number; prefilling_current:number; prefilling_total:number; build_completed_time:number; build_started_time:number; storage_total:number; storage_usage:number; status:string } ================================================ FILE: archive/ktransformers/website/src/views/home.vue ================================================ ================================================ FILE: archive/ktransformers/website/tests/unit/example.spec.ts ================================================ import { shallowMount } from '@vue/test-utils' import HelloWorld from '@/components/HelloWorld.vue' describe('HelloWorld.vue', () => { it('renders props.msg when passed', () => { const msg = 'new message' const wrapper = shallowMount(HelloWorld, { props: { msg } }) expect(wrapper.text()).toMatch(msg) }) }) ================================================ FILE: archive/ktransformers/website/tsconfig.json ================================================ { "compilerOptions": { "target": "es5", "module": "esnext", "strict": true, "jsx": "preserve", "importHelpers": true, "moduleResolution": "node", "skipLibCheck": true, "esModuleInterop": true, "allowSyntheticDefaultImports": true, "forceConsistentCasingInFileNames": true, "useDefineForClassFields": true, "sourceMap": true, "allowJs": true, "baseUrl": ".", "types": [ "webpack-env", "jest" ], "paths": { "@/*": [ "src/*" ] }, "lib": [ "esnext", "dom", "dom.iterable", "scripthost" ] }, "include": [ "src/**/*.ts", "src/**/*.tsx", "src/**/*.vue", "tests/**/*.ts", "tests/**/*.tsx", "config.d.ts" ], "exclude": [ "node_modules" ] } ================================================ FILE: archive/ktransformers/website/vue.config.js ================================================ module.exports = { // 配置 webpack-dev-server 行为。 devServer: { open: false, // 编译后默认打开浏览器 host: '0.0.0.0', // 域名 port: 8082, // 端口 https: false, // 是否https proxy: { '/api': { target: 'http://localhost:9016/v1', // 你的后端服务器地址 changeOrigin: true, // 是否允许跨域 pathRewrite: { '/api': '' // 将 '/api' 前缀替换为空,如果你的后端不需要这个前缀 } } } }, publicPath: '/web/', // 基本路径 outputDir: 'dist', // 构建时的输出目录 assetsDir: 'static', // 放置静态资源的目录 indexPath: 'index.html', // html 的输出路径 filenameHashing: true, // 文件名哈希值 lintOnSave: false, // 是否在保存的时候使用 `eslint-loader` 进行检查。 // 组件是如何被渲染到页面中的? (ast:抽象语法树;vDom:虚拟DOM) // template ---> ast ---> render ---> vDom ---> 真实的Dom ---> 页面 // runtime-only:将template在打包的时候,就已经编译为render函数 // runtime-compiler:在运行的时候才去编译template runtimeCompiler: false, transpileDependencies: [], // babel-loader 默认会跳过 node_modules 依赖。 productionSourceMap: false, // 是否为生产环境构建生成 source map //调整内部的 webpack 配置 configureWebpack: () => {}, chainWebpack: () => {}, } ================================================ FILE: archive/merge_tensors/merge_safetensor_gguf.py ================================================ # this script targets to merge the fp8 safe tensor and the gguf quantized tensors. import os # insert the path of the project import sys # sys.path.insert(0, "/home/azure/ktransformers") import argparse import torch from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf from safetensors import safe_open from safetensors.torch import save_file import re from collections import defaultdict def read_safetensor_keys_from_folder(folder_path)->dict: """ :param folder_path: folder path :return: key_to_file_map """ # check if the folder path is exist if not os.path.exists(folder_path): raise FileNotFoundError(f"GGUF dir not found: {folder_path}") if os.path.isfile(folder_path): folder_path = os.path.dirname(folder_path) key_to_file_map = {} found_safetensor = False for root, dirs, files in os.walk(folder_path): # sort files files = sorted(files) for file in files: if file.endswith(".safetensors"): found_safetensor = True file_path = os.path.join(root, file) try: with safe_open(file_path, framework="pt") as f: for key in f.keys(): if "model.layers.61" in key: # skip MTP layer continue # try: # if int(key.split('.')[2]) > 4: # continue # except: # pass key_to_file_map[key] = file_path except Exception as e: print(f"Error reading Safetensor file {file_path}: {e}") if not found_safetensor: raise FileNotFoundError(f"No Safetensor files found in {folder_path}") return key_to_file_map tensor_from_gguf = [] # todo: add keys in gguf that should be used in the final tensor def translate_name(name:str)->str: """ :param name: name of the tensor :return: translated name """ name = translate_name_to_gguf(name) name = name.replace(".up_proj.", ".ffn_up_exps.") name = name.replace(".down_proj.", ".ffn_down_exps.") name = name.replace(".gate_proj.", ".ffn_gate_exps.") name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias") return name def combine_tensor_sources(safetensor_path:str, gguf_path:str): gguf_loader = GGUFLoader(gguf_path) gguf_tensor_file_map = gguf_loader.tensor_file_map safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path) # build a map for the key to the tensor # according to the key, we can get the tensor from the file target_tensor_map = {} for key in safetensor_tensor_file_map.keys(): # for all experts, we use the gguf tensor if ".mlp.experts." in key: if '.weight_scale_inv' in key: continue key = '.'.join(key.split('.')[:5]+key.split('.')[-2:]) translated_key = translate_name(key) target_tensor_map[key] = gguf_tensor_file_map[translated_key] continue if any(target_key in key for target_key in tensor_from_gguf): target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)] else: target_tensor_map[key] = safetensor_tensor_file_map[key] return target_tensor_map, gguf_loader def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader): # Ensure output directory exists os.makedirs(output_path, exist_ok=True) # Cache for safetensor file handles and GGUF loaders safetensors_cache = {} gguf_cache = {} # Group tensors by layer layer_groups = defaultdict(list) non_layer_keys = [] layer_pattern = re.compile(r'\.layers\.(\d+)\.') for key in target_tensor_map: match = layer_pattern.search(key) if match: layer_num = int(match.group(1)) layer_groups[layer_num].append(key) else: non_layer_keys.append(key) # Calculate total shards total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1 if total_shards == 0: raise ValueError("No tensors to save") shard_idx = 0 # Save non-layer tensors to the first shard if they exist if non_layer_keys: tensors = {} for key in non_layer_keys: file_path = target_tensor_map[key] tensor = None ggml_type = None if file_path.endswith('.safetensors'): if file_path not in safetensors_cache: safetensors_cache[file_path] = safe_open(file_path, framework='pt') f = safetensors_cache[file_path] tensor = f.get_tensor(key) elif file_path.endswith('.gguf'): gguf_name = translate_name(key) tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) else: raise ValueError(f"Unsupported file format: {file_path}") tensors[translate_name(key)] = tensor if ggml_type: ggml_type = torch.tensor(ggml_type) ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" tensors[ggml_key] = ggml_type output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") print(f"Saving non-layer tensors to {output_file}") save_file(tensors, output_file) print(tensors.keys()) shard_idx += 1 # Save each layer's tensors to subsequent shards for layer_num in sorted(layer_groups.keys()): layer_keys = layer_groups[layer_num] tensors = {} for key in layer_keys: file_path = target_tensor_map[key] tensor = None ggml_type = None if file_path.endswith('.safetensors'): if file_path not in safetensors_cache: safetensors_cache[file_path] = safe_open(file_path, framework='pt') f = safetensors_cache[file_path] tensor = f.get_tensor(key) tensor_info = tensor.shape elif file_path.endswith('.gguf'): gguf_name = translate_name(key) tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) # tensor_info = gguf_loader.tensor_info[gguf_name] # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type'] else: raise ValueError(f"Unsupported file format: {file_path}") tensors[translate_name(key)] = tensor if ggml_type: ggml_type = torch.tensor(ggml_type) ggml_key = translate_name(key)[:-7] + ".ggml_type" if translate_name(key).endswith(".weight") else translate_name(key) + ".ggml_type" tensors[ggml_key] = ggml_type output_file = os.path.join(output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors") print(f"Saving layer {layer_num} to {output_file}") # print(tensors.keys()) save_file(tensors, output_file) shard_idx += 1 return def main(): # 创建命令行参数解析器 parser = argparse.ArgumentParser(description="Read parameters from Safetensor and GGUF files") parser.add_argument("--safetensor_path", type=str, help="Path to the Safetensor file", default="/mnt/data/model/DeepSeek-V3") parser.add_argument("--gguf_path", type=str, help="Path to the GGUF file", default="/mnt/data/model/DeepseekV3-q4km-gguf") parser.add_argument("--output_path", type=str, help="Path to the output file", default="/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8") # print all the arguments print("All the arguments:") print(parser.parse_args()) # 解析命令行参数 args = parser.parse_args() safetensor_path = args.safetensor_path gguf_path = args.gguf_path output_path = args.output_path target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) write_combined_tensor(target_tensor_map, output_path, gguf_loader) return if __name__ == "__main__": main() ================================================ FILE: archive/merge_tensors/merge_safetensor_gguf_for_qwen3.py ================================================ # coding=utf-8 # Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys import argparse import torch from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf from safetensors import safe_open from safetensors.torch import save_file import re from collections import defaultdict def read_safetensor_keys_from_folder(folder_path) -> dict: if not os.path.exists(folder_path): raise FileNotFoundError(f"Safetensors dir not found: {folder_path}") if os.path.isfile(folder_path): folder_path = os.path.dirname(folder_path) key_to_file_map = {} found_safetensor = False for root, dirs, files in os.walk(folder_path): files = sorted(files) for file in files: if not file.endswith(".safetensors"): continue found_safetensor = True file_path = os.path.join(root, file) try: with safe_open(file_path, framework="pt") as f: for key in f.keys(): key_to_file_map[key] = file_path except Exception as e: print(f"Error reading Safetensor file {file_path}: {e}") if not found_safetensor: raise FileNotFoundError(f"No Safetensor files found in {folder_path}") return key_to_file_map # 可选:如果你希望对某些非 MoE tensor 也用 GGUF,可以把关键子串填到下面这个列表里 tensor_from_gguf = [] # e.g. ["self_attn.q_proj.weight"] def translate_name(name: str) -> str: name = translate_name_to_gguf(name) name = name.replace(".up_proj.", ".ffn_up_exps.") name = name.replace(".down_proj.", ".ffn_down_exps.") name = name.replace(".gate_proj.", ".ffn_gate_exps.") name = name.replace(".ffn_gate_inp.e_score_correction_bias", ".exp_probs_b.bias") return name def combine_tensor_sources(safetensor_path: str, gguf_path: str): gguf_loader = GGUFLoader(gguf_path) gguf_tensor_file_map = gguf_loader.tensor_file_map safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path) target_tensor_map = {} for key, st_file in safetensor_tensor_file_map.items(): if ".mlp.experts." in key and key.endswith(".weight"): parts = key.split(".") if len(parts) < 8: raise ValueError(f"Unexpected MoE expert key format: {key}") norm_key = ".".join(parts[:5] + parts[-2:]) gguf_name = translate_name(norm_key) if gguf_name not in gguf_tensor_file_map: raise KeyError( f"[MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}" ) target_tensor_map[norm_key] = gguf_tensor_file_map[gguf_name] continue if any(tag in key for tag in tensor_from_gguf): gguf_name = translate_name(key) if gguf_name not in gguf_tensor_file_map: raise KeyError( f"[Non-MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}" ) target_tensor_map[key] = gguf_tensor_file_map[gguf_name] else: target_tensor_map[key] = st_file return target_tensor_map, gguf_loader def write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader): os.makedirs(output_path, exist_ok=True) safetensors_cache = {} layer_groups = defaultdict(list) non_layer_keys = [] layer_pattern = re.compile(r"\.layers\.(\d+)\.") for key in target_tensor_map: m = layer_pattern.search(key) if m: layer_num = int(m.group(1)) layer_groups[layer_num].append(key) else: non_layer_keys.append(key) total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1 if total_shards <= 0: raise ValueError("No tensors to save") shard_idx = 0 if non_layer_keys: tensors = {} for key in non_layer_keys: file_path = target_tensor_map[key] tensor = None ggml_type = None if file_path.endswith(".safetensors"): if file_path not in safetensors_cache: safetensors_cache[file_path] = safe_open(file_path, framework="pt") f = safetensors_cache[file_path] tensor = f.get_tensor(key) elif file_path.endswith(".gguf"): gguf_name = translate_name(key) tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) else: raise ValueError(f"Unsupported file format: {file_path}") out_key = translate_name(key) tensors[out_key] = tensor if ggml_type is not None: ggml_type = torch.tensor(ggml_type) if out_key.endswith(".weight"): ggml_key = out_key[:-7] + ".ggml_type" else: ggml_key = out_key + ".ggml_type" tensors[ggml_key] = ggml_type output_file = os.path.join( output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors" ) print(f"[WRITE] Saving non-layer tensors to {output_file}") save_file(tensors, output_file) shard_idx += 1 for layer_num in sorted(layer_groups.keys()): layer_keys = layer_groups[layer_num] tensors = {} for key in layer_keys: file_path = target_tensor_map[key] tensor = None ggml_type = None if file_path.endswith(".safetensors"): if file_path not in safetensors_cache: safetensors_cache[file_path] = safe_open(file_path, framework="pt") f = safetensors_cache[file_path] tensor = f.get_tensor(key) elif file_path.endswith(".gguf"): gguf_name = translate_name(key) tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name) else: raise ValueError(f"Unsupported file format: {file_path}") out_key = translate_name(key) tensors[out_key] = tensor if ggml_type is not None: ggml_type = torch.tensor(ggml_type) if out_key.endswith(".weight"): ggml_key = out_key[:-7] + ".ggml_type" else: ggml_key = out_key + ".ggml_type" tensors[ggml_key] = ggml_type output_file = os.path.join( output_path, f"model-{shard_idx:05}-of-{total_shards:05}.safetensors" ) print(f"[WRITE] Saving layer {layer_num} to {output_file}") save_file(tensors, output_file) shard_idx += 1 def main(): parser = argparse.ArgumentParser( description="Merge FP8 safetensors and GGUF tensors for Qwen3-30B-A3B" ) parser.add_argument( "--safetensor_path", type=str, help="Path to the FP8 Safetensor folder", default="/mnt/data/model/Qwen3-30B-A3B-FP8", ) parser.add_argument( "--gguf_path", type=str, help="Path to the GGUF file or folder", default="/mnt/data/model/Qwen3-30B-A3B-GGUF", ) parser.add_argument( "--output_path", type=str, help="Path to the output safetensors folder", default="/mnt/data/model/ktrans-safetensors/Qwen3-30B-A3B-q4km-fp8", ) args = parser.parse_args() print("[ARGS]", args) safetensor_path = args.safetensor_path gguf_path = args.gguf_path output_path = args.output_path target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path) write_combined_tensor(target_tensor_map, output_path, gguf_loader) if __name__ == "__main__": main() ================================================ FILE: archive/pyproject.toml ================================================ [build-system] requires = [ "setuptools", "torch >= 2.3.0", "ninja", "packaging", "cpufeature" ] build-backend = "setuptools.build_meta" [project] name = "ktransformers" dynamic = ["version"] dependencies = [ "torch >= 2.3.0", "transformers", "fastapi >= 0.111.0", "uvicorn >= 0.30.1", "langchain >= 0.2.0", "blessed >= 1.20.0", "accelerate >= 0.31.0", "sentencepiece >= 0.1.97", "setuptools", "ninja", "wheel", "colorlog", "build", "fire", "protobuf", ] requires-python = ">=3.10" authors = [ {name = "KVCache.AI", email = "zhang.mingxing@outlook.com"} ] maintainers = [ {name = "james0zan", email = "zhang.mingxing@outlook.com"}, {name = "awake", email = "awake@approaching.ai"}, {name = "unicorn chan", email = "nl@approaching.ai"} ] description = "KTransformers, pronounced as Quick Transformers, is designed to enhance your Transformers experience with advanced kernel optimizations and placement/parallelism strategies." readme = "README.md" license = {file = "LICENSE"} keywords = ["ktransformers", "llm"] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12" ] [project.urls] Homepage = "https://kvcache.ai" Repository = "https://github.com/kvcache-ai/ktransformers.git" Issues = "https://github.com/kvcache-ai/ktransformers/issues" [project.scripts] ktransformers = "ktransformers.server.main:main" [tool.setuptools.packages.find] where = ["./", ] include = ["ktransformers","ktransformers.*"] [tool.black] line-length = 120 preview = true unstable = true ================================================ FILE: archive/requirements-local_chat.txt ================================================ fire transformers numpy torch>=2.3.0 packaging cpufeature; sys_platform == 'win32' or sys_platform == 'Windows' protobuf tiktoken blobfile ================================================ FILE: archive/setup.py ================================================ #!/usr/bin/env python # coding=utf-8 ''' Description : Author : chenxl Date : 2024-07-27 16:15:27 Version : 1.0.0 LastEditors : chenxl LastEditTime : 2024-08-14 16:36:19 Adapted from: https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py Copyright (c) 2023, Tri Dao. Copyright (c) 2024 by KVCache.AI, All Rights Reserved. ''' import os import sys import re import ast from collections import deque import subprocess import select import time import platform import shutil from typing import List, Optional, Literal import http.client import urllib.request import urllib.error from pathlib import Path from packaging.version import parse import torch import torch.version from wheel.bdist_wheel import bdist_wheel as _bdist_wheel from setuptools import setup, Extension from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME try: from torch_musa.utils.simple_porting import SimplePorting from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME except ImportError: MUSA_HOME=None KTRANSFORMERS_BUILD_XPU = torch.xpu.is_available() try: import torch_npu KTRANSFORMERS_BUILD_NPU = torch_npu.npu.is_available() except: KTRANSFORMERS_BUILD_NPU = False # 检测 DEV_BACKEND 环境变量 dev_backend = os.environ.get("DEV_BACKEND", "").lower() if dev_backend == "xpu": triton_dep = [ "pytorch-triton-xpu==3.3.0" ] else: triton_dep = ["triton>=3.2"] with_balance = os.environ.get("USE_BALANCE_SERVE", "0") == "1" class CpuInstructInfo: CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE") FANCY = "FANCY" AVX512 = "AVX512" AVX2 = "AVX2" CMAKE_NATIVE = "-DLLAMA_NATIVE=ON" CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON" CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON" CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON" class VersionInfo: THIS_DIR = os.path.dirname(os.path.abspath(__file__)) PACKAGE_NAME = "ktransformers" BASE_WHEEL_URL:str = ( "https://github.com/kvcache-ai/ktransformers/releases/download/{tag_name}/{wheel_filename}" ) FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE" def get_musa_bare_metal_version(self, musa_dir): raw_output = subprocess.run( [musa_dir + "/bin/mcc", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8") output = raw_output.split() release_idx = output.index("version") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}" return musa_version def get_rocm_bare_metal_version(self, rocm_dir): """ Get the ROCm version from the ROCm installation directory. Args: rocm_dir: Path to the ROCm installation directory Returns: A string representation of the ROCm version (e.g., "63" for ROCm 6.3) """ try: # Try using rocm_agent_enumerator to get version info raw_output = subprocess.check_output( [rocm_dir + "/bin/rocminfo", "--version"], universal_newlines=True, stderr=subprocess.STDOUT) # Extract version number from output match = re.search(r'(\d+\.\d+)', raw_output) if match: version_str = match.group(1) version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version except (subprocess.CalledProcessError, FileNotFoundError): # If rocminfo --version fails, try alternative methods pass try: # Try reading version from release file with open(os.path.join(rocm_dir, "share/doc/hip/version.txt"), "r") as f: version_str = f.read().strip() version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version except (FileNotFoundError, IOError): pass # If all else fails, try to extract from directory name dir_name = os.path.basename(os.path.normpath(rocm_dir)) match = re.search(r'rocm-(\d+\.\d+)', dir_name) if match: version_str = match.group(1) version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version # Fallback to extracting from hipcc version try: raw_output = subprocess.check_output( [rocm_dir + "/bin/hipcc", "--version"], universal_newlines=True, stderr=subprocess.STDOUT) match = re.search(r'HIP version: (\d+\.\d+)', raw_output) if match: version_str = match.group(1) version = parse(version_str) rocm_version = f"{version.major}{version.minor}" return rocm_version except (subprocess.CalledProcessError, FileNotFoundError): pass # If we still can't determine the version, raise an error raise ValueError(f"Could not determine ROCm version from directory: {rocm_dir}") def get_cuda_bare_metal_version(self, cuda_dir): raw_output = subprocess.check_output( [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}" return cuda_version def get_cuda_version_of_torch(self): if KTRANSFORMERS_BUILD_NPU: return 'aarch64' torch_cuda_version = parse(torch.version.cuda) cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" return cuda_version def get_platform(self,): """ Returns the platform name as used in wheel filenames. """ if sys.platform.startswith("linux"): return f'linux_{platform.uname().machine}' elif sys.platform == "win32": return "win_amd64" else: raise ValueError("Unsupported platform: {}".format(sys.platform)) def get_cpu_instruct(self,): if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY: return "fancy" elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512: return "avx512" elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2: return "avx2" else: print("Using native cpu instruct") if sys.platform.startswith("linux"): if KTRANSFORMERS_BUILD_NPU: return 'aarch64' with open('/proc/cpuinfo', 'r', encoding="utf-8") as cpu_f: cpuinfo = cpu_f.read() flags_line = [line for line in cpuinfo.split( '\n') if line.startswith('flags')][0] flags = flags_line.split(':')[1].strip().split(' ') # fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI for flag in flags: if 'avx512bw' in flag: return 'fancy' for flag in flags: if 'avx512' in flag: return 'avx512' for flag in flags: if 'avx2' in flag: return 'avx2' raise ValueError( "Unsupported cpu Instructions: {}".format(flags_line)) elif sys.platform == "win32": from cpufeature.extension import CPUFeature if CPUFeature.get("AVX512bw", False): return 'fancy' if CPUFeature.get("AVX512f", False): return 'avx512' if CPUFeature.get("AVX2", False): return 'avx2' raise ValueError( "Unsupported cpu Instructions: {}".format(str(CPUFeature))) else: raise ValueError("Unsupported platform: {}".format(sys.platform)) def get_torch_version(self,): torch_version_raw = parse(torch.__version__) torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}" return torch_version def get_flash_version(self,): version_file = os.path.join( Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py") with open(version_file, "r", encoding="utf-8") as f: version_match = re.search( r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) flash_version = ast.literal_eval(version_match.group(1)) return flash_version def get_package_version(self, full_version=False): flash_version = str(self.get_flash_version()) torch_version = self.get_torch_version() cpu_instruct = self.get_cpu_instruct() backend_version = "" if CUDA_HOME is not None: backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}" elif MUSA_HOME is not None: backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}" elif ROCM_HOME is not None: backend_version = f"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}" elif torch.xpu.is_available(): backend_version = f"xpu" elif KTRANSFORMERS_BUILD_NPU: backend_version = f"npu{torch_npu.__version__}" else: raise ValueError("Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.") package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}" if full_version: return package_version if not VersionInfo.FORCE_BUILD: return flash_version return package_version class BuildWheelsCommand(_bdist_wheel): def get_wheel_name(self,): version_info = VersionInfo() package_version = version_info.get_package_version(full_version=True) flash_version = version_info.get_flash_version() python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" wheel_filename = f"{VersionInfo.PACKAGE_NAME}-{package_version}-{python_version}-{python_version}-{version_info.get_platform()}.whl" wheel_url = VersionInfo.BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_filename=wheel_filename) return wheel_filename, wheel_url def run(self): if VersionInfo.FORCE_BUILD: super().run() return wheel_filename, wheel_url = self.get_wheel_name() print("Guessing wheel URL: ", wheel_url) try: urllib.request.urlretrieve(wheel_url, wheel_filename) # Make the archive # Lifted from the root wheel processing command # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 if not os.path.exists(self.dist_dir): os.makedirs(self.dist_dir) impl_tag, abi_tag, plat_tag = self.get_tag() archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) shutil.move(wheel_filename, wheel_path) except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected): print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source super().run() ANSI_ESCAPE = re.compile( r'\033[@-Z\\-_\[\]P]|\033\[[0-?]*[ -/]*[@-~]|\033][^\007\033]*\007|[\000-\037]' ) def colored(text, color=None, bold=False): fmt = [] if color== 'red': fmt.append('31') elif color == 'green': fmt.append('32') if bold: fmt.append('1') return f"\033[{';'.join(fmt)}m{text}\033[0m" def split_line(text: str) -> List[str]: """Split text into lines based on terminal width.""" term_width = shutil.get_terminal_size().columns or 80 if not text.strip(): return [] # Split by explicit newlines and wrap long lines lines = [] for line in text.split('\n'): while len(line) > term_width: lines.append(line[:term_width]) line = line[term_width:] if line: lines.append(line) return lines ANSI_ESCAPE = re.compile( r'\033[@-Z\\-_\[\]P]|\033\[[0-?]*[ -/]*[@-~]|\033][^\007\033]*\007|[\000-\037]' ) def colored(text, color=None, bold=False): fmt = [] if color== 'red': fmt.append('31') elif color == 'green': fmt.append('32') if bold: fmt.append('1') return f"\033[{';'.join(fmt)}m{text}\033[0m" def split_line(text: str) -> List[str]: """Split text into lines based on terminal width.""" term_width = shutil.get_terminal_size().columns or 80 if not text.strip(): return [] # Split by explicit newlines and wrap long lines lines = [] for line in text.split('\n'): while len(line) > term_width: lines.append(line[:term_width]) line = line[term_width:] if line: lines.append(line) return lines def run_command_with_live_tail(ext: str, command: List[str], output_lines: int = 20, refresh_rate: float = 0.1, cwd: Optional[str] = None): """ Execute a script-like command with real-time output of the last `output_lines` lines. - during execution: displays the last `output_lines` lines of output in real-time. - On success: Clears the displayed output. - On failure: Prints the full command output. Args: ext (str): the name of the native extension currently building. command (List[str]): The command to execute, as a list of arguments. output_lines (int, optional): Number of terminal lines to display during live output. Defaults to 20. refresh_rate (float, optional): Time in seconds between output refreshes. Defaults to 0.1. cwd (Optional[str], optional): Working directory to run the command in. Defaults to current directory. """ # Dump all subprocess output without any buffering if stdout is not a terminal if not sys.stdout.isatty(): return subprocess.run(command, cwd=cwd, check=True) # Start time for elapsed time calculation start = time.time() # Buffer for all output all_output = [] write_buffer = deque(maxlen=output_lines) # Current number of lines from sub process displayed current_lines = 0 # ANSI escape codes for terminal control CLEAR_LINE = '\033[K' MOVE_UP = '\033[1A' SAVE_CURSOR = '\0337' RESTORE_CURSOR = '\0338' CLEAR_REMAINING = '\033[J' def write_progress(status: Literal['RUNNING', 'SUCCEED', 'FAILED'] = 'RUNNING', new_line: Optional[str] = None): """Update terminal display with latest output""" nonlocal current_lines, process sys.stdout.write(SAVE_CURSOR) sys.stdout.write(MOVE_UP * current_lines) banner = f"ext={ext} pid={process.pid} status={status.upper()} elapsed=({time.time()-start:.2f}S)\n" if status != 'FAILED': banner = colored(banner, 'green', bold=True) else: banner = colored(banner, 'red', bold=True) sys.stdout.write(CLEAR_LINE + banner) if new_line is not None: all_output.append(new_line) write_buffer.extend(split_line(ANSI_ESCAPE.sub('', new_line).rstrip())) elif status == 'RUNNING': sys.stdout.write(RESTORE_CURSOR) sys.stdout.flush() return sys.stdout.write(CLEAR_REMAINING) if status == 'RUNNING': current_lines = 1 + len(write_buffer) for text in write_buffer: sys.stdout.write(text + '\n') elif status == 'FAILED': for text in all_output: sys.stdout.write(text) sys.stdout.flush() # Start subprocess sys.stdout.write(colored(f'ext={ext} command={" ".join(str(c) for c in command)}\n', bold=True)) sys.stdout.flush() process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, text=True, bufsize=1 ) try: write_progress() poll_obj = select.poll() poll_obj.register(process.stdout, select.POLLIN) while process.poll() is None: poll_result = poll_obj.poll(refresh_rate * 1000) if poll_result: write_progress(new_line=process.stdout.readline()) else: write_progress() # Get any remaining output while True: line = process.stdout.readline() if not line: break write_progress(new_line=line) except BaseException as e: process.terminate() raise e finally: exit_code = process.wait() write_progress(status='SUCCEED' if exit_code == 0 else 'FAILED') # Convert distutils Windows platform specifiers to CMake -A arguments PLAT_TO_CMAKE = { "win32": "Win32", "win-amd64": "x64", "win-arm32": "ARM", "win-arm64": "ARM64", } class CMakeExtension(Extension): def __init__(self, name: str, sourcedir: str) -> None: super().__init__(name, sources=[]) print(name, sourcedir) self.sourcedir = sourcedir def get_cmake_abi_args(cmake_args): if torch.compiled_with_cxx11_abi(): cmake_args.append("-D_GLIBCXX_USE_CXX11_ABI=1") else: cmake_args.append("-D_GLIBCXX_USE_CXX11_ABI=0") return cmake_args class CMakeBuild(BuildExtension): def build_extension(self, ext) -> None: if not isinstance(ext, CMakeExtension): super().build_extension(ext) return ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) extdir = ext_fullpath.parent.resolve() # Using this requires trailing slash for auto-detection & inclusion of # auxiliary "native" libs debug = int(os.environ.get("DEBUG", 0) ) if self.debug is None else self.debug cfg = "Debug" if debug else "Release" # CMake lets you override the generator - we need to check this. # Can be set with Conda-Build, for example. cmake_generator = os.environ.get("CMAKE_GENERATOR", "") # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code # from Python. cmake_args = [ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", f"-DPYTHON_EXECUTABLE={sys.executable}", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm ] if CUDA_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"] elif MUSA_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"] elif ROCM_HOME is not None: cmake_args += ["-DKTRANSFORMERS_USE_ROCM=ON"] elif KTRANSFORMERS_BUILD_XPU: cmake_args += ["-DKTRANSFORMERS_USE_XPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"] elif KTRANSFORMERS_BUILD_NPU: cmake_args += ["-DKTRANSFORMERS_USE_NPU=ON", "-DKTRANSFORMERS_USE_CUDA=OFF"] else: raise ValueError("Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.") cmake_args = get_cmake_abi_args(cmake_args) # log cmake_args print("CMake args:", cmake_args) build_args = [] if "CMAKE_ARGS" in os.environ: cmake_args += [ item for item in os.environ["CMAKE_ARGS"].split(" ") if item] if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY: cpu_args = CpuInstructInfo.CMAKE_FANCY elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512: cpu_args = CpuInstructInfo.CMAKE_AVX512 elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2: cpu_args = CpuInstructInfo.CMAKE_AVX2 else: cpu_args = CpuInstructInfo.CMAKE_NATIVE cmake_args += [ item for item in cpu_args.split(" ") if item ] # In this example, we pass in the version to C++. You might not need to. cmake_args += [ f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] if self.compiler.compiler_type != "msvc": if not cmake_generator or cmake_generator == "Ninja": pass # try: # import ninja # ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" # cmake_args += [ # "-GNinja", # f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", # ] # except ImportError: # pass else: # Single config generators are handled "normally" single_config = any( x in cmake_generator for x in {"NMake", "Ninja"}) # CMake allows an arch-in-generator style for backward compatibility contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) if not single_config and not contains_arch and cmake_generator: cmake_args += ["-A", PLAT_TO_CMAKE[self.plat_name]] # Multi-config generators have a different way to specify configs if not single_config: cmake_args += [ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}" ] build_args += ["--config", cfg] if sys.platform.startswith("darwin"): # Cross-compile support for macOS - respect ARCHFLAGS if set archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) if archs: cmake_args += [ "-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: cpu_count = os.cpu_count() if cpu_count is None: cpu_count = 1 if hasattr(self, "parallel") and self.parallel: build_args += [f"--parallel={self.parallel}"] else: build_args += [f"--parallel={cpu_count}"] print("CMake args:", cmake_args) build_temp = Path(ext.sourcedir) / "build" print("build_temp:", build_temp) if not build_temp.exists(): build_temp.mkdir(parents=True) run_command_with_live_tail(ext.name, ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp ) run_command_with_live_tail(ext.name, ["cmake", "--build", build_temp, "--verbose", *build_args], cwd=build_temp ) if CUDA_HOME is not None or ROCM_HOME is not None: ops_module = CUDAExtension('KTransformersOps', [ 'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu', 'csrc/ktransformers_ext/cuda/binding.cpp', 'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu' ], extra_compile_args={ 'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'], 'nvcc': [ '-O3', # '--use_fast_math', '-Xcompiler', '-fPIC', '-DKTRANSFORMERS_USE_CUDA', ] } ) elif MUSA_HOME is not None: SimplePorting(cuda_dir_path="csrc/ktransformers_ext/cuda", mapping_rule={ # Common rules "at::cuda": "at::musa", "#include ": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"", "#include ": "#include \"torch_musa/csrc/core/MUSAGuard.h\"", "nv_bfloat16": "mt_bfloat16", }).run() ops_module = MUSAExtension('KTransformersOps', [ 'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu', 'csrc/ktransformers_ext/cuda_musa/binding.cpp', # TODO: Add Marlin support for MUSA. # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu' ], extra_compile_args={ 'cxx': ['force_mcc'], 'mcc': [ '-O3', '-DKTRANSFORMERS_USE_MUSA', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', ] } ) elif torch.xpu.is_available(): #XPUExtension is not available now. ops_module = None elif KTRANSFORMERS_BUILD_NPU: pass else: raise ValueError("Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.") if not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU: ext_modules = [ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), ops_module, CUDAExtension( 'vLLMMarlin', [ 'csrc/custom_marlin/binding.cpp', 'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu', 'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu', ], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': ['-O3', '-Xcompiler', '-fPIC'], }, ) ] if with_balance: print("using balance_serve") ext_modules.append( CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) ) setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), install_requires=triton_dep, cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=ext_modules ) elif torch.xpu.is_available(): ext_modules = [ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), ] setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), install_requires=triton_dep, cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=ext_modules ) elif KTRANSFORMERS_BUILD_NPU: ext_modules = [ CMakeExtension("cpuinfer_ext", os.fspath(Path("").resolve() / "csrc" / "ktransformers_ext")), ] if with_balance: print("using balance_serve") ext_modules.append( CMakeExtension("balance_serve", os.fspath(Path("").resolve()/ "csrc"/ "balance_serve")) ) setup( name=VersionInfo.PACKAGE_NAME, version=VersionInfo().get_package_version(), cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild}, ext_modules=ext_modules ) ================================================ FILE: archive/third_party/llamafile/README.md ================================================ The code in this folder is copied from [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile). Special thanks to the Mozilla-Ocho team. ================================================ FILE: archive/third_party/llamafile/bench.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/bench.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi #pragma once #include #include "micros.h" #define BENCH(x) \ do { \ x; \ __asm__ volatile("" ::: "memory"); \ long long start = micros(); \ for (int i = 0; i < ITERATIONS; ++i) { \ __asm__ volatile("" ::: "memory"); \ x; \ __asm__ volatile("" ::: "memory"); \ } \ printf("%9lld us %s\n", (micros() - start + ITERATIONS - 1) / ITERATIONS, #x); \ } while (0) ================================================ FILE: archive/third_party/llamafile/flags.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #include "flags.h" bool FLAG_precise = false; ================================================ FILE: archive/third_party/llamafile/flags.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #pragma once extern bool FLAG_precise; ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc // Copyrigth 2024 Iwan Kawrakow - Apache 2.0 Licens // with additions from // https://github.com/ikawrakow/ik_llama.cpp/blob/main/ggml/src/iqk/iqk_mul_mat.cpp // Copyrigth 2024-2025 Iwan Kawrakow - MIT Licens // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp fenc=utf-8 :vi // // Copyright 2024 Iwan Kawrakow // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // Copyright (C) 2024-2025 Iwan Kawrakow // MIT license // SPDX-License-Identifier: MIT // #if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU // use ARM version #include "iqk_mul_mat_arm.inc" #else // use x86 version #include "iqk_mul_mat_x86.inc" #endif ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_avx2.cpp // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #include "iqk_mul_mat.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_zen4.cpp // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define iqk_mul_mat iqk_mul_mat_zen4 #define iqk_mul_mat_moe iqk_mul_mat_moe_zen4 #include "iqk_mul_mat.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat_arm.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp fenc=utf-8 :vi // // Copyright 2024 Iwan Kawrakow // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #if defined __x86_64__ || defined __aarch64__ || defined(_M_X64) #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "sgemm.h" // For i-quants, I had to explicitely specify which // functions to inline / not inline (at least for some // of the functions), else performance would be significantly // lower. This is worrysome as things can change with, // e.g., a different compiler version or running on a different // CPU. #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) #define IQK_ALWAYS_INLINE inline #else #define IQK_NOINLINE __attribute__((__noinline__)) #define IQK_ALWAYS_INLINE __attribute__((always_inline)) #endif #define GGML_COMMON_IMPL_C #include "llama.cpp/ggml-common.h" // clang-format off // This matrix - vector and matrix - matrix multiplication implementation // for legacy quants, k-quants and i-quants makes prompt processing 150-200% // (legacy and k-quants) or 250-400% (i-quants) faster. // compared to mainline llama.cpp (and llamafile). // It provides implementations for ARM_NEON (all quants) and AVX2 // (all quants except sub-4 bit i-quants). // // Main idea is that unpacking the quants and the block scales to // be ready for dot products with the corresponding Q8_Y quants // takes time (here 'Y' stands for K, 0, or 1, depending on quantization type). // Hence, if we are performing a QX x Q8_Y matrix matrix // multiplication (as needed for prompt processing), we can get // a significant speedup by reusing the unpacked QX quants and scales // for multiplication with several Q8_K columns. We also achieve fewer // loads from memory, which is the main purpose of tiling in general // purpose matrix multiplication packages. #include #include #endif constexpr ggml_type GGML_TYPE_Q8_0_X4 = static_cast(98); constexpr ggml_type GGML_TYPE_Q8_1_X4 = static_cast(99); namespace { #define GEMV_Q4K #define GEMV_Q6K #define GEMM_Q4K_Q6K typedef struct { int32_t i1; int32_t i2; } mmid_row_mapping; struct DataInfo { float * s; const char * cy; size_t bs; size_t by; int cur_y = 0; int ne11; const mmid_row_mapping * row_mapping = nullptr; size_t bs2 = 0; inline const char * src1_row(int iy) const { if (!row_mapping) return cy + (cur_y + iy)*by; int i11 = row_mapping[cur_y + iy].i1 % ne11; int i12 = row_mapping[cur_y + iy].i2; return cy + (i11 + i12*ne11)*by; } inline void store(int ix, int iy, float result) const { *(dst_row(iy) + ix) = result; //dst_row(iy)[ix] = result; } inline float* ptr(int ix, int iy) const { return dst_row(iy) + ix; } inline float * dst_row(int iy) const { if (!row_mapping) return s + (cur_y + iy)*bs; int i12 = row_mapping[cur_y + iy].i2; int i1 = row_mapping[cur_y + iy].i1; int i2 = i12; return s + i1*bs + i2*bs2; } }; /* moonll change param for set_mul_mat add func16 */ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); typedef void (*mul_mat_t_v2)(int m, int n, int k, const void *vx, size_t bx, const DataInfo& info); struct MulMat { std::array funcs = {}; mul_mat_t func16 = nullptr; mul_mat_t_v2 funcs_v2; //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile size are small) if (func16 && nrc_y >= 16) { int n_step = (nrc_y - info.cur_y)/16; for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); this_info.cur_y += 16; } } info.cur_y += 16 * n_step; if (info.cur_y == nrc_y) return; } int n_step = (nrc_y - info.cur_y)/funcs.size(); if (n_step > 0) { for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); this_info.cur_y += funcs.size(); } } info.cur_y += funcs.size() * n_step; } int n_left = nrc_y - info.cur_y; if (n_left > 0) { funcs[n_left-1](n, vx, bx, info, nrc_x); } } #if defined __x86_64__ || defined(_M_X64) static IQK_NOINLINE bool set_mul_mat(int typeA, int typeB,int ne00, MulMat& mm, int Ny); #else IQK_NOINLINE void mul_mat_NxM_v2(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { funcs_v2(nrc_x, nrc_y, n, vx, bx, info); return; } static IQK_NOINLINE bool set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny); #endif private: template static IQK_NOINLINE void set_functions(MulMat& m); }; inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { const uint16_t * scales = (const uint16_t *)scales8; const uint32_t a0 = scales[0] | (scales[1] << 16); const uint32_t a1 = scales[2] | (scales[3] << 16); const uint32_t a2 = scales[4] | (scales[5] << 16); aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030); aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030); aux32[2] = a1 & 0x3f3f3f3f; aux32[0] = a0 & 0x3f3f3f3f; } /* moonll decoding tables */ #ifdef __AVX2__ static const uint64_t iq1s_grid_us[2048] = { 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200, 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000, 0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101, 0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101, 0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202, 0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200, 0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001, 0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202, 0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201, 0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001, 0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101, 0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101, 0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202, 0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200, 0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201, 0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002, 0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101, 0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200, 0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102, 0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101, 0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001, 0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100, 0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200, 0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101, 0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100, 0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000, 0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202, 0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200, 0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101, 0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201, 0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002, 0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001, 0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001, 0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002, 0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000, 0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101, 0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000, 0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101, 0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202, 0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201, 0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000, 0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100, 0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102, 0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002, 0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000, 0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101, 0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101, 0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200, 0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002, 0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101, 0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101, 0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101, 0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102, 0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100, 0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002, 0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100, 0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000, 0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101, 0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101, 0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001, 0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102, 0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201, 0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202, 0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001, 0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001, 0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101, 0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102, 0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200, 0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101, 0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101, 0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000, 0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201, 0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101, 0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202, 0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102, 0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101, 0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100, 0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002, 0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201, 0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101, 0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002, 0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202, 0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101, 0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000, 0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100, 0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102, 0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102, 0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101, 0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101, 0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001, 0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201, 0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002, 0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001, 0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100, 0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101, 0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001, 0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101, 0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000, 0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001, 0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101, 0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101, 0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000, 0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001, 0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001, 0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102, 0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102, 0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101, 0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201, 0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202, 0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202, 0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101, 0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001, 0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000, 0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101, 0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200, 0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100, 0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100, 0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202, 0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102, 0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201, 0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202, 0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002, 0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001, 0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001, 0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101, 0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202, 0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201, 0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102, 0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200, 0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001, 0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101, 0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201, 0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001, 0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002, 0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000, 0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202, 0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201, 0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201, 0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101, 0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100, 0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000, 0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101, 0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202, 0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101, 0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202, 0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202, 0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201, 0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002, 0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102, 0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102, 0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000, 0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000, 0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101, 0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101, 0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202, 0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200, 0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102, 0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101, 0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100, 0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001, 0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100, 0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101, 0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001, 0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200, 0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101, 0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101, 0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100, 0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101, 0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101, 0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101, 0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202, 0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100, 0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201, 0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202, 0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102, 0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200, 0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201, 0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000, 0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002, 0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100, 0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000, 0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100, 0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000, 0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102, 0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100, 0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002, 0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001, 0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201, 0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202, 0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100, 0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001, 0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002, 0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001, 0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201, 0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001, 0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101, 0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101, 0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101, 0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101, 0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102, 0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100, 0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001, 0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000, 0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001, 0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101, 0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100, 0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000, 0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202, 0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101, 0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100, 0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100, 0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200, 0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100, 0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101, 0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101, 0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201, 0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001, 0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201, 0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201, 0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001, 0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200, 0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100, 0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201, 0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200, 0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101, 0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001, 0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102, 0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001, 0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201, 0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100, 0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000, 0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102, 0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001, 0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202, 0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102, 0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101, 0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201, 0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101, 0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102, 0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101, 0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100, 0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202, 0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101, 0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202, 0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101, 0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200, 0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101, 0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100, 0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002, 0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201, 0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100, 0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202, 0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102, 0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002, 0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200, 0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002, 0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200, 0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001, 0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200, 0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100, 0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000, 0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102, 0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100, 0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000, 0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102, 0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100, 0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000, 0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101, 0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001, 0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201, 0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002, 0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200, 0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100, 0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101, 0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202, 0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002, 0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201, 0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201, 0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001, 0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202, 0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102, 0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002, 0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201, 0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200, 0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002, 0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100, 0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101, 0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102, 0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002, 0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200, 0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100, 0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001, 0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100, 0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201, 0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101, 0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102, 0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201, 0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200, 0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200, 0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002, 0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202, 0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102, 0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000, 0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202, 0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201, 0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001, 0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002, 0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102, 0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001, 0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101, 0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202, 0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102, 0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201, 0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101, 0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101, 0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001, 0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202, 0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000, 0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202, 0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102, 0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002, 0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201, 0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101, 0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001, 0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200, 0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102, 0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102, 0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100, 0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001, 0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201, 0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001, 0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202, 0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200, 0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000, 0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000, 0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001, 0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200, 0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200, 0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202, 0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201, 0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202, 0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001, 0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001, 0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200, 0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000, 0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102, 0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101, 0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100, 0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000, 0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100, 0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100, 0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102, 0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201, 0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202, 0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102, 0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102, 0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202, 0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202, 0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100, 0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000, 0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101, 0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202, 0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102, 0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100, 0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101, 0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100, 0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201, 0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101, 0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202, 0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200, 0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201, 0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200, 0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002, 0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201, 0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101, 0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201, 0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201, 0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102, 0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101, 0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101, 0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101, 0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001, 0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000, 0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102, 0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101, 0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202, 0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202, 0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101, 0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000, 0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101, 0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202, 0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100, 0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000, 0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101, 0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202, 0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100, 0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100, 0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002, 0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100, 0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101, 0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202, 0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200, 0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100, 0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200, 0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002, 0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001, 0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101, 0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101, 0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202, 0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102, 0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100, 0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101, 0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100, 0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101, 0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101, 0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101, 0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101, 0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102, 0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100, 0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102, 0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101, 0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101, 0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001, 0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101, 0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202, 0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102, 0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001, 0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102, 0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200, 0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101, 0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001, 0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201, 0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202, 0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102, 0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002, 0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200, 0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100, 0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001, 0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002, 0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201, 0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101, 0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100, 0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000, 0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200, 0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101, 0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200, 0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202, 0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100, 0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102, 0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102, 0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102, 0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101, 0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101, 0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000, 0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202, 0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102, 0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200, 0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101, 0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101, 0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100, 0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202, 0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101, 0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201, 0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001, 0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101, 0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200, 0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002, 0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001, 0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000, 0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101, 0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202, 0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100, 0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102, 0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200, 0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101, 0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201, 0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000, 0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202, 0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201, 0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200, 0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002, 0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101, 0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100, 0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001, 0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201, 0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000, 0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102, 0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001, 0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201, 0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100, 0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002, 0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001, 0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101, 0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002, 0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000, 0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101, 0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100, 0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200, 0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200, 0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102, 0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200, 0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002, 0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100, 0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001, 0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001, 0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102, 0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202, 0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202, 0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000, 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, }; #else static const uint32_t iq1s_grid_us[2048] = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; #endif #ifndef HAVE_FANCY_SIMD const uint64_t keven_signs[128] = { 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, }; #endif } /* moonll change mulmat add typeB and strideB }*/ bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth) { MulMat mm; #if defined __x86_64__ || defined(_M_X64) if (!MulMat::set_mul_mat(typeA, typeB, (int)ne00, mm, Ny)) { return false; } #else int row_size_q8; if (!MulMat::set_mul_mat(typeA, (int)ne00, mm, row_size_q8, Ny)) { return false; } #endif size_t row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); size_t row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); auto nrc_x = (Nx + nth - 1)/nth; auto first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; #ifdef __ARM_NEON #ifdef GEMM_Q4K_Q6K if (Ny >= 8 && (typeA == GGML_TYPE_Q4_K || typeA == GGML_TYPE_Q6_K)) { mm.mul_mat_NxM_v2(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); } else #endif #endif { mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); } return true; } bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; assert(row_mapping != nullptr); MulMat mm; int row_size_q8; /* moonll if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { return false; }*/ int row_size_qx = ggml_row_size((ggml_type)typeA, ne00); int nrc_x = (Nx + nth - 1)/nth; int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); return true; } #if defined __x86_64__ || defined(_M_X64) #if defined HAVE_FANCY_SIMD #undef HAVE_FANCY_SIMD #endif #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) #define HAVE_FANCY_SIMD #endif //#define HAVE_FANCY_SIMD namespace { inline float hsum_float_4(__m128 x) { x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); return _mm_cvtss_f32(x); } inline float hsum_float_8(__m256 x) { return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); } #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) template struct Q8 { constexpr static int nrc_y = nrc; Q8(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } #ifdef HAVE_FANCY_SIMD inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); } #endif inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); } inline float scale(int iy, int i) const { return y[iy][i].d; } const block_q8 * y[nrc_y]; }; // Handles q4_K and q5_K scales/mins struct Scales8K { template inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { make_q4_scales(data, utmp); const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1); accum_mins(mins128, q8, i, c, accd); const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); return MM256_SET_M128I(sc128, sc128); } #ifdef HAVE_FANCY_SIMD template inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { auto scales = process_mins_and_scales(data, c, i, q8, accd); return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1); } #endif template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0])); for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i q8s = q8.load_bsums(iy, i); const __m256i prod = _mm256_madd_epi16(mins, q8s); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } } #ifdef HAVE_FANCY_SIMD const __m512i shuffles512[2] = { _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100), _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) }; #endif const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100), _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)}; uint32_t utmp[4]; }; template inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i)); accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); } } inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) { const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); scales[0] = MM256_SET_M128I(l_scales, l_scales); scales[1] = MM256_SET_M128I(h_scales, h_scales); } struct ScaleQ3 { inline __m128i make_scales(const uint16_t * s8) const { const uint16_t * scales16 = (const uint16_t *)s8; uint32_t aux0 = scales16[0] | (scales16[1] << 16); uint32_t aux1 = scales16[2] | (scales16[3] << 16); uint32_t aux2 = scales16[4] | (scales16[5] << 16); __m128i scales128 = _mm_set_epi32( ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030), ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030), (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030), (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030)); return _mm_add_epi8(scales128, m32); } const __m128i m32 = _mm_set1_epi8(-32); }; struct ScaleIQ4XS { inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) { uint32_t tmp32 = scales_h | (scales_h << 14); const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4); const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask); return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32); } const __m128i hshift = _mm_set_epi32(12, 8, 4, 0); const __m128i lshift = _mm_set_epi32(4, 0, 4, 0); const __m128i hmask = _mm_set1_epi16(0x03); const __m128i lmask = _mm_set1_epi8(0xf); const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400); const __m128i m32 = _mm_set1_epi16(-32); }; struct Scales8KBase { template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0])); for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i q8s = q8.load_bsums(iy, i); const __m256i prod = _mm256_madd_epi16(mins, q8s); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } } inline __m256i shuffle(__m128i mins) const { return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0])); } const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100), _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)}; }; template struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} inline void new_row(int ix) { x = (const Block *)((const char *)vx + bx*ix); } const void * vx; size_t bx; const Block * x; float d; }; __m128i inline load_iq4nl_values_128() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); } __m256i inline load_iq4nl_values_256() { auto val128 = load_iq4nl_values_128(); return MM256_SET_M128I(val128, val128); } #ifdef HAVE_FANCY_SIMD //====================================== Zen4 ================================================== struct BlockPermuter { const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); }; struct Q4Bits { inline void prepare(const uint8_t * q4) { auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); auto tmp1 = _mm512_and_si512(q4bits, ml); auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); tmp1 = _mm512_and_si512(q4bits, ml); tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); } inline void prepare64(const uint8_t * q4) { auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); values[0] = _mm512_and_si512(q4bits, ml); values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); values[2] = _mm512_and_si512(q4bits, ml); values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); } __m512i values[4]; const __m512i ml = _mm512_set1_epi8(0xf); BlockPermuter perm; }; struct Q2Bits { inline void prepare(const uint8_t * q2) { auto q2bits = _mm512_loadu_si512((const __m512i*)q2); auto tmp = _mm512_srli_epi16(q2bits, 2); values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp); values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp); values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml); values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml); values[0] = _mm512_and_si512(values[0], ml); values[2] = _mm512_and_si512(values[2], ml); } __m512i values[4]; const __m512i ml = _mm512_set1_epi8(0x03); BlockPermuter perm; }; struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); } Q4Bits bits; Scales8K s8k; }; /* moonll DequantizerIQ4XS */ __m512i inline load_iq4nl_values_512() { auto val256 = load_iq4nl_values_256(); return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); } struct DequantizerIQ4XS final : public BaseDequantizer { DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); prepare(x[i].qs); auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); s8k.accum_mins(scales128, q8, i, -128.f*d, accd); auto scales256 = MM256_SET_M128I(scales128, scales128); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); } inline void prepare(const uint8_t * q4) { bits.prepare64(q4); // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 // etc. auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); bits.values[0] = _mm512_shuffle_epi8(values, tmp); tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); bits.values[2] = _mm512_shuffle_epi8(values, tmp); } Q4Bits bits; Scales8KBase s8k; ScaleIQ4XS siq4; const __m512i values; const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); const __m512i shuffles[4] = { _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), }; }; struct HighBit5 { inline void apply(const uint8_t * h, Q4Bits& bits) { auto hbits256 = _mm256_loadu_si256((const __m256i *)h); auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); } const __m512i mh = _mm512_set1_epi8(0x10); }; struct HighBit3 { inline void apply(const uint8_t * h, Q2Bits& bits) { auto hbits256 = _mm256_loadu_si256((const __m256i *)h); auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh)); } const __m512i mh = _mm512_set1_epi8(0x04); }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); hbits.apply(x[i].qh, bits); auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); } Q4Bits bits; HighBit5 hbits; Scales8K s8k; }; struct Scale16 { inline void make_scales(const __m128i& scales8, __m512i * scales) const { auto all_scales8 = MM256_SET_M128I(scales8, scales8); auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1); auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2); scales[0] = _mm512_cvtepi8_epi16(scales1); scales[1] = _mm512_cvtepi8_epi16(scales2); } template inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm); make_scales(scales8, scales); } const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202, 0x05050505, 0x01010101, 0x04040404, 0x00000000); const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a, 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808); }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales); } Q2Bits bits; Scale16 sc16; const __m128i m4 = _mm_set1_epi8(0xf); }; struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); hbits.apply(x[i].hmask, bits); auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales); sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales); } Q2Bits bits; HighBit3 hbits; ScaleQ3 sc3; Scale16 sc16; const __m128i m4 = _mm_set1_epi8(0xf); const __m128i m32 = _mm_set1_epi8(-32); }; struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare64(x[i].ql); add_high_bits(x[i].qh, bits); auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales); sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales); } inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const { auto hbits = _mm512_loadu_si512((const __m512i *)qh); auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh); auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); tmp1 = _mm512_and_si512(hbits, mh); tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); } Q4Bits bits; HighBit3 hbits; Scale16 sc16; const __m512i mh = _mm512_set1_epi8(0x30); }; template static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[2]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } template inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } template static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[2]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } template static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[4]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(), p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } template static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; constexpr int k_nx = 2; Q8<1> q8(info); Dequantizer deq1(vx, bx); Dequantizer deq2(vx, bx); Dequantizer * deq[k_nx]; deq[0] = &deq1; deq[1] = &deq2; __m512i scales[2*k_nx]; for (int ix = 0; ix < nrc_x; ++ix) { auto accd = _mm512_setzero_ps(); auto accm = _mm256_setzero_ps(); for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix); for (int i = 0; i < nb/k_nx; ++i) { for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); for (int kx = 0; kx < k_nx; ++kx) { compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); } } if (2*(nb/2) < nb) { int i0 = 2*(nb/2); deq[0]->new_block(i0, q8, &accm, scales); compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); } auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); } } #else // ===================================== Vanilla AVX2 ===================================== struct Q4Bits { inline void prepare(const uint8_t * q4, int j) { auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); values[0] = _mm256_and_si256(q4bits, ml); values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); values[2] = _mm256_and_si256(q4bits, ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); } inline void prepare64(const uint8_t * q4, int j) { auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); values[0] = _mm256_and_si256(q4bits, ml); values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); values[1] = _mm256_and_si256(q4bits, ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); } inline void prepare16(const uint8_t * q4, int j) { values[0] = dequant16(q4 + 64*j + 0); values[1] = dequant16(q4 + 64*j + 16); values[2] = dequant16(q4 + 64*j + 32); values[3] = dequant16(q4 + 64*j + 48); } inline __m256i dequant16(const uint8_t * qs) const { const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128); return _mm256_and_si256(ml, aux256); }; __m256i values[4]; const __m256i ml = _mm256_set1_epi8(0xf); }; struct Q2Bits { inline void prepare(const uint8_t * q2, int j) { auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j); values[0] = _mm256_and_si256(q2bits, ml); values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); } __m256i values[4]; const __m256i ml = _mm256_set1_epi8(0x03); }; struct HighBit5 { inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } inline void apply(Q4Bits& bits, bool do_shift) { bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); if (do_shift) { hbits = _mm256_srli_epi16(hbits, 4); } } const __m256i mh = _mm256_set1_epi8(0x10); __m256i hbits; }; struct HighBit3 { inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } inline void apply(Q2Bits& bits, bool do_shift) { bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); if (do_shift) { hbits = _mm256_srli_epi16(hbits, 4); } } const __m256i mh = _mm256_set1_epi8(0x04); __m256i hbits; }; /* template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); } } else { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); } } }*/ struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); } Q4Bits bits; Scales8K s8k; }; struct DequantizerIQ4XS final : public BaseDequantizer { DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); s8k.accum_mins(scales128, q8, i, -128.f*d, accd); return MM256_SET_M128I(scales128, scales128); } inline void prepare(int i, int j) { bits.prepare16(x[i].qs, j); bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); } static __m256i load_values() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); return MM256_SET_M128I(val128, val128); } Q4Bits bits; Scales8K s8k; ScaleIQ4XS siq4; const __m256i values; }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); hbits.load(x[i].qh); return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); hbits.apply(bits, j == 0); } Q4Bits bits; HighBit5 hbits; Scales8K s8k; }; template inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d, __m256 * accm, __m256i * scales) { const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); process_mins_16(all_scales, q8, i, d, accm); prepare_scales_16(all_scales, scales); } struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); hbits.load(x[i].hmask); process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); hbits.apply(bits, j == 0); } Q2Bits bits; HighBit3 hbits; ScaleQ3 sc3; const __m128i m32 = _mm_set1_epi8(-32); }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm); prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); } Q2Bits bits; const __m128i m4 = _mm_set1_epi8(0xf); }; struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales); } inline void prepare(int i, int j) { bits.prepare64(x[i].ql, j); auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh)); } Q4Bits bits; const __m256i mh = _mm256_set1_epi8(0x30); }; inline __m256i get_scale_shuffle_8(int i); inline void set_scales_8(const __m256i& all_scales, int j, __m256i* scales); inline __m256i get_scale_shuffle_16(int i); inline void set_scales_16(const __m256i& all_scales, __m256i* scales); template static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; Q8 q8(info); __m256i all_scales[2]; __m256i scales[4]; __m256 accd[nrc_y]; Dequantizer deq(vx, bx); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accd, all_scales); __m256i sumi[nrc_y]; for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); set_scales_16(all_scales[j], scales); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } template static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accd[nrc_y]; __m256i scales[4]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { auto all_scales = deq.new_block(i, q8, accd); __m256i sumi[nrc_y]; for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); set_scales_8(all_scales, j, scales); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } #endif // Zen4 or vanilla AVX2 // // ============================== Legacy quants // struct DotHelper { const __m256i m1 = _mm256_set1_epi16(1); #if defined(__AVX512VNNI__) && defined(__AVX512VL__) inline __m256i dot(__m256i x, __m256i y) const { return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y); } #else inline __m256i dot(__m256i x, __m256i y) const { return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y)); } #endif }; struct SignedDot { DotHelper helper; inline __m256i compute(__m256i x, __m256i y) const { return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)); } }; struct UnsignedDot { DotHelper helper; inline __m256i compute(__m256i x, __m256i y) const { return helper.dot(x, y); } }; template struct Sum4 { Dot dot; inline __m256i compute(const __m256i * qx, const Q8 * y) const { const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs)); const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs)); const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs)); const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs)); const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1 const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3 return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3 } }; struct Sum4_Q8 { SignedDot dot; static inline __m256i add1(__m256i a, __m256i b) { return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b)); } static inline __m256i add2(__m256i a, __m256i b) { return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); } inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const { const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs)); const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs)); const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs)); const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs)); const __m256i p01 = add1(p0, p1); // 0,1, 0,1, 0,1, 0,1 const __m256i p23 = add1(p2, p3); // 2,3, 2,3, 2,3, 2,3 return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3 } }; struct ScaleHelperQ_0 { ggml_half scales8[4]; template inline __m128 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); } template inline __m128 prepare4(__m128 other_scales, const Q * y) { return _mm_mul_ps(other_scales, prepare4(y)); } template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } }; template struct ScaleHelperQ_0_1 { ggml_half scales8[4]; template inline __m256 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); return _mm256_set_m128(_mm_mul_ps(s4, min), s4); } template inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm_mul256_ps(other_scales, prepare4(y)); } template inline std::pair prepare1(const Q * y) const { float d = GGML_FP16_TO_FP32(y->d); return std::make_pair(d, -d*float(min_value)); } std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); } const __m128 min = _mm_set1_ps(float(-min_value)); }; struct ScaleHelperQ_1 { uint32_t scales8[4]; const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); template inline __m256 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) { // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers // complain that this breaks strict-aliasing rules. memcpy(scales8 + j, &y[j].d, sizeof(uint32_t)); } return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle)); } template inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm256_mul_ps(other_scales, prepare4(y)); } template inline std::pair prepare1(const Q * y) const { return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); } template inline std::pair prepare1(const std::pair& dm, const Q * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); } std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); } }; struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } }; template struct MinusType1 { __m128 accm[nrc_y]; MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); } inline __m256 compute(__m256 dm, int iy) { const __m128 d = _mm256_castps256_ps128(dm); const __m128 m = _mm256_extractf128_ps(dm, 1); accm[iy] = _mm_add_ps(accm[iy], m); return _mm256_set_m128(d, d); } inline float compute(const std::pair& dm, int iy) { accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f)); return dm.first; } inline float result(__m256 acc, int iy) const { const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } }; template struct AccumT { __m256 acc[nrc_y]; Minus accm; AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); } template inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) { auto qx = unp.quants(); __m256 dall[nrc_y]; for (int i = 0; i < nb/4; ++i) { auto other_scales = unp.set_block_4(i); for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); dall[iy] = accm.compute(s12, iy); } for (int iy = 0; iy < nrc_y; ++iy) { auto pall = sum.compute(qx, y[iy] + 4*i); acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); } } if (!is_multiple_of_4) { for (int i = 4*(nb/4); i < nb; ++i) { auto other_scales = unp.set_block(i); for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare1(other_scales, y[iy] + i); auto d = accm.compute(s12, iy); const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); } } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, accm.result(acc[iy], iy)); //s[iy*bs] = accm.result(acc[iy], iy); } } }; template using AccumType0 = AccumT; template using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; using Sum4Type0 = Sum4; using Sum4Type1 = Sum4; template void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { Unpacker unp(vx, bx); Sum4Type sum4; Scales scales; for (int ix = 0; ix < nrc_x; ++ix) { unp.set_row(ix); AccumType accum; accum.compute(nb, unp, scales, sum4, y, info, ix); } } template void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } template void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_1, block_q8_1, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_1, block_q8_1, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4); } }; struct Q8_0_Dequantizer { inline __m256i dequant(const block_q8_0 * x) const { return _mm256_loadu_si256((const __m256i *)x->qs); } }; struct Q8_0_1_Dequantizer { inline __m256i dequant(const block_q8_0 * x) const { return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs)); } }; struct Q4_0_Dequantizer { Dequantizer4bit b4; const __m256i m8 = _mm256_set1_epi8(-8); inline __m256i dequant(const block_q4_0 * x) const { return _mm256_add_epi8(b4.dequant(x->qs), m8); } }; struct Q4_1_Dequantizer { Dequantizer4bit b4; inline __m256i dequant(const block_q4_1 * x) const { return b4.dequant(x->qs); } }; struct HBitDequantizer { const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); const __m256i minus1 = _mm256_set1_epi64x(-1); inline __m256i to_bytes(const uint8_t * bits) const { // Note: Data in all ggml quants is at least 2-byte aligned. // => we can cast to uint16_t and use or on two consecutive entries // which is faster than memcpy const uint16_t * aux16 = (const uint16_t *)bits; const uint32_t aux32 = aux16[0] | (aux16[1] << 16); //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t)); __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle); bytes = _mm256_or_si256(bytes, mask); return _mm256_cmpeq_epi8(bytes, minus1); } }; struct Q5_0_Dequantizer { Dequantizer4bit b4; HBitDequantizer hbit; const __m256i mh = _mm256_set1_epi8((char)0xF0); inline __m256i dequant(const block_q5_0 * x) const { const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh); return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; struct Q5_1_Dequantizer { Dequantizer4bit b4; HBitDequantizer hbit; const __m256i mh = _mm256_set1_epi8(0x10); inline __m256i dequant(const block_q5_1 * x) const { const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh); return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; template struct Q_Unpacker { Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {} const char * cx_0; const Q * x; size_t bx; Scales scales; Dequantizer deq; __m256i qx[4]; inline const __m256i* quants() const { return qx; } inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); } inline auto set_block_4(int i) { for (int j = 0; j < 4; ++j) { qx[j] = deq.dequant(x + 4*i + j); } return scales.prepare4(x + 4*i); } inline auto set_block(int i) { qx[0] = deq.dequant(x + i); return scales.prepare1(x + i); } }; struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_0; } }; struct Q8_0_1_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} // using Sum4T = Sum4TypeQ81; inline static int block_size() { return QK8_0; } }; struct Q4_0_Unpacker final : public Q_Unpacker { Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_0; } }; struct Q5_0_Unpacker final : public Q_Unpacker { Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK5_0; } }; struct Q4_1_Unpacker final : public Q_Unpacker { Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_1; } }; struct Q5_1_Unpacker final : public Q_Unpacker { Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_1; } }; template void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Q8_0_Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Q8_0_Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } /* moonll add some structs for DequantizerIQ2XXS SimpleBits EvenSignHelper */ struct SimpleBits { __m256i values[4]; }; // fix for #829: 添加对 AVX512VPOPCNTDQ 的检测 #if defined(HAVE_FANCY_SIMD) && defined(__AVX512VPOPCNTDQ__) #define HAVE_AVX512_POPCNT 1 #else #define HAVE_AVX512_POPCNT 0 #endif struct EvenSignHelper { #if defined HAVE_FANCY_SIMD // #pragma message("Using AVX512VPOPCNTDQ in even sign helper") union sbits_t { __m128i vec; __mmask32 mask[4]; }; IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const { aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask); // fix for #829: 兼容Intel Cascade Lake架构的CPU,如果不支持AVX512VPOPCNTDQ扩展,则使用替代实现 #if HAVE_AVX512_POPCNT auto pcnt = _mm256_popcnt_epi32(aux); #else // 提供替代实现,使用标准的位计数方法 __m256i pcnt; int* pcnt_ptr = reinterpret_cast(&pcnt); int* aux_ptr = reinterpret_cast(&aux); // 直接获取 aux 的地址,避免不必要的复制 #pragma unroll 8 // 提示编译器展开循环,提高 SIMD 计算吞吐量 for (int i = 0; i < 8; i++) { pcnt_ptr[i] = __builtin_popcount(aux_ptr[i]); // 使用编译器内置 popcount } #endif sbits_t sbits; sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]); values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]); //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); //const __mmask32 * m32 = (const __mmask32 *)&sign_bits; //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]); //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]); } const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); const __m256i mask = _mm256_set1_epi32(127); const __m256i mone = _mm256_set1_epi32(1); #else inline void sign_value(uint32_t aux32, __m256i& value) const { auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]); value = _mm256_sign_epi8(value, signs); } #endif }; /* moonll ad multiply_add for mul_mat_qX_K_q8_K_IQ_1 add func get_scale_shuffle_8 get_scale_shuffle_16 set_scales_16 */ inline __m256i get_scale_shuffle_8(int i) { return _mm256_set1_epi16((2*i) | ((2*i+1) << 8)); } inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) { scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0)); scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1)); scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2)); scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3)); } inline __m256i get_scale_shuffle_16(int i) { static const uint8_t k_shuffle[128] = { 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, }; return _mm256_loadu_si256((const __m256i*)k_shuffle + i); } inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0)); scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1)); scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2)); scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3)); } template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { #ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); } #else for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); } #endif } else { #ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); } #else for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); } #endif } } /* moonll ad multiply_add_1 for mul_mat_qX_K_q8_K_IQ_1 add func set_scales_8_iq set_scales_16_iq add MUL_MAT mul_mat_qX_K_q8_K_IQ_1 mul_mat_qX_K_q8_K_IQ_N mul_mat_qX_K_q8_K_IQ */ template inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { #ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2)); sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4)); #else const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); sumi[0] = _mm256_add_epi32(p1, p3); sumi[1] = _mm256_add_epi32(p2, p4); #endif } else { #ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2)); sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4)); #else const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3)); sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4)); #endif } } inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) { //#ifdef HAVE_FANCY_SIMD auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100) : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908); scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4))); //#else // set_scales_8(all_scales, j, scales); //#endif } inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) { #ifdef HAVE_FANCY_SIMD auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100); scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8))); #else set_scales_16(all_scales, scales); #endif } template static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_K; Q8<1> q8(info); Dequantizer deq(vx, bx); __m256i scales[2]; __m256i q8_quants[4]; for (int ix = 0; ix < nrc_x; ++ix) { __m256 accd = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { __m256i sumi[2], all_scales[Dequantizer::num_blocks/8]; deq.new_block(i, all_scales); for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j, q8, q8_quants); if constexpr (Dequantizer::num_blocks == 8) { set_scales_8_iq(j, all_scales[0], scales); } else { set_scales_16_iq(all_scales[j], scales); } multiply_add_1(j, deq.bits, scales, q8_quants, sumi); } accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd); } info.store(ix, 0, hsum_float_8(accd)); } } template static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256i scales[4]; __m256 accd[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8]; //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256(); __m256i mins; float dmin = deq.new_block(i, all_scales, mins); for (int iy = 0; iy < nrc_y; ++iy) { auto bsums = q8.load_bsums(iy, i); auto prod = _mm256_madd_epi16(mins, bsums); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); if constexpr (Dequantizer::num_blocks == 8) { set_scales_8(all_scales[0], j, scales); } else { set_scales_16(all_scales[j], scales); } //multiply_add_iq(deq.bits, scales, j, i, q8, sumi); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } template static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); #ifdef HAVE_FANCY_SIMD if constexpr (nrc_y == 1) { mul_mat_qX_K_q8_K_IQ_1(n, vx, bx, info, nrc_x); } else { mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); } #else mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); #endif } /* moonll iq1s core func for iq1s mul_mat_iq1_s_q8_K */ template static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(n%QK_K == 0); Q8 q8(info); __m256i qx[8]; __m256i scales[4]; __m256 acc[nrc_y] = {}; auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000 __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100); for (int ix = 0; ix < nrc_x; ++ix) { auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); for (int ibl = 0; ibl < n/QK_K; ++ibl) { float d = GGML_FP16_TO_FP32(iq1s[ibl].d); auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh); auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7)); scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1)); #ifdef HAVE_FANCY_SIMD auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask); auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9)); #else auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask); auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7))); #endif deltas128 = _mm_mullo_epi16(scales128, deltas128); scales128 = _mm_slli_epi16(scales128, 3); auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128); auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128); auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7 auto all_scales = MM256_SET_M128I(scales128, scales128); auto shuffle = shuffle0; for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle); shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)); } const uint8_t * qs = iq1s[ibl].qs; const uint16_t * qh = iq1s[ibl].qh; for (int ib = 0; ib < QK_K/32; ib += 2) { qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)], iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]); qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)], iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]); qs += 8; } for (int iy = 0; iy < nrc_y; ++iy) { auto bsums = q8.load_bsums(iy, ibl); auto sumi = _mm256_setzero_si256(); for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0); auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1); #ifdef HAVE_FANCY_SIMD auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1); auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2); sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2)); #else auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1); auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2); auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2)); sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot)); #endif } #ifdef HAVE_FANCY_SIMD sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas); #else sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas)); #endif acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, 0.125f*hsum_float_8(acc[iy])); acc[iy] = _mm256_setzero_ps(); } } } /* moonll iq1s DequantizerIQ2XXS DequantizerIQ2XXS is important Dequantizer for DequantizerIQ1_S */ struct DequantizerIQ2XXS final : public BaseDequantizer { DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} constexpr static int num_blocks = 8; union Data { __m256i vec; uint32_t val[8]; }; inline __m128i load_scales(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); const uint16_t * a16 = (const uint16_t *)x[i].qs; auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12); return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1)); } inline void new_block(int i, __m256i * scales) { auto sc16 = load_scales(i); scales[0] = MM256_SET_M128I(sc16, sc16); } inline float new_block(int i, __m256i * scales, __m256i& mins) { auto sc16 = load_scales(i); mins = scb.shuffle(sc16); scales[0] = MM256_SET_M128I(sc16, sc16); return -d*minv; } inline static void make4(const uint32_t * aux32, __m256i * values) { const uint8_t * aux8 = (const uint8_t *)aux32; values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]); values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]); values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]); values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]); } IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { #ifdef HAVE_FANCY_SIMD esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); #else esh.sign_value(aux32[1], values[0]); esh.sign_value(aux32[3], values[1]); esh.sign_value(aux32[5], values[2]); esh.sign_value(aux32[7], values[3]); #endif } inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const { make4(aux32, values); sign_values(aux32, values); for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value); } inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const { make4(aux32, values); sign_values(aux32, q8); } inline void prepare(int i, int j) { Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); make4_signed(data.val, min_value, bits.values); } inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); make4(data.val, bits.values, q8_quants); } constexpr static int minv = 43; SimpleBits bits; Scales8KBase scb; EvenSignHelper esh; const __m256i min_value = _mm256_set1_epi8(minv); const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1); }; /* moonll add Q8_0_Unpacker && DequantizerIQ2XXS support add func mul_mat_qX_K_q8_K_IQ */ template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0_T; m.funcs[1] = mul_mat_qX_0_q8_0_T; m.funcs[2] = mul_mat_qX_0_q8_0_T; m.funcs[3] = mul_mat_qX_0_q8_0_T; m.funcs[4] = mul_mat_qX_0_q8_0_T; m.funcs[5] = mul_mat_qX_0_q8_0_T; m.funcs[6] = mul_mat_qX_0_q8_0_T; m.funcs[7] = mul_mat_qX_0_q8_0_T; } else if constexpr (std::is_same_v || std::is_same_v|| std::is_same_v) { m.funcs[0] = mul_mat_qX_1_q8_1_T; m.funcs[1] = mul_mat_qX_1_q8_1_T; m.funcs[2] = mul_mat_qX_1_q8_1_T; m.funcs[3] = mul_mat_qX_1_q8_1_T; m.funcs[4] = mul_mat_qX_1_q8_1_T; m.funcs[5] = mul_mat_qX_1_q8_1_T; m.funcs[6] = mul_mat_qX_1_q8_1_T; m.funcs[7] = mul_mat_qX_1_q8_1_T; } else if constexpr (std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ; m.funcs[1] = mul_mat_qX_K_q8_K_IQ; m.funcs[2] = mul_mat_qX_K_q8_K_IQ; m.funcs[3] = mul_mat_qX_K_q8_K_IQ; m.funcs[4] = mul_mat_qX_K_q8_K_IQ; m.funcs[5] = mul_mat_qX_K_q8_K_IQ; m.funcs[6] = mul_mat_qX_K_q8_K_IQ; m.funcs[7] = mul_mat_qX_K_q8_K_IQ; } else { #ifdef HAVE_FANCY_SIMD if constexpr (std::is_same_v) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512; } else { m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1; m.funcs[1] = mul_mat_qX_K_q8_K_AVX512; m.funcs[2] = mul_mat_qX_K_q8_K_AVX512; m.funcs[3] = mul_mat_qX_K_q8_K_AVX512; m.funcs[4] = mul_mat_qX_K_q8_K_AVX512; m.funcs[5] = mul_mat_qX_K_q8_K_AVX512; m.funcs[6] = mul_mat_qX_K_q8_K_AVX512; m.funcs[7] = mul_mat_qX_K_q8_K_AVX512; } #else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qY_K_q8_K_T; m.funcs[1] = mul_mat_qY_K_q8_K_T; m.funcs[2] = mul_mat_qY_K_q8_K_T; m.funcs[3] = mul_mat_qY_K_q8_K_T; m.funcs[4] = mul_mat_qY_K_q8_K_T; m.funcs[5] = mul_mat_qY_K_q8_K_T; m.funcs[6] = mul_mat_qY_K_q8_K_T; m.funcs[7] = mul_mat_qY_K_q8_K_T; } else { m.funcs[0] = mul_mat_qX_K_q8_K_T; m.funcs[1] = mul_mat_qX_K_q8_K_T; m.funcs[2] = mul_mat_qX_K_q8_K_T; m.funcs[3] = mul_mat_qX_K_q8_K_T; m.funcs[4] = mul_mat_qX_K_q8_K_T; m.funcs[5] = mul_mat_qX_K_q8_K_T; m.funcs[6] = mul_mat_qX_K_q8_K_T; m.funcs[7] = mul_mat_qX_K_q8_K_T; } #endif } } struct QFBase { #ifdef __AVX512F__ constexpr static int k_step = 16; using Data = __m512; using Acc = __m512; static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); } static inline Data load(const float * x) { return _mm512_loadu_ps(x); } static inline Data load(const ggml_bf16_t * x) { return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16)); } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm512_fmadd_ps(y, x, prev); } static inline Acc acc_first(const Data& y, const Data& x) { return _mm512_mul_ps(y, x); } static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); } static inline float hsum(Acc acc) { return _mm512_reduce_add_ps(acc); } template static inline Data load4Floats(const Float * x) { return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc); acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_r4_first(const Data * xv, const Data& yv) { auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00)); acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline __m128 hsum_r4(Acc acc) { auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3)); return _mm_add_ps(sum1, sum2); } #else constexpr static int k_step = 8; using Data = __m256; using Acc = __m256; static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); } static inline Data load(const float * x) { return _mm256_loadu_ps(x); } static inline Data load(const ggml_bf16_t * x) { return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16)); } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); } static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_r4_first(const Data * xv, const Data& yv) { auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00)); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_first(const Data& y, const Data& x) { return _mm256_mul_ps(y, x); } static inline float hsum(Acc acc) { return hsum_float_8(acc); } static inline __m128 hsum_r4(Acc acc) { return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); } template static inline Data load4Floats(const Float * x) { return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0); } #endif static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); } static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); } static inline __m128 load128(const ggml_bf16_t * x) { return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16)); } }; template struct QFT final : public QFBase { constexpr static int nrc = nrc_in; QFT(const DataInfo& info) { for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy); } QFT(const char * cx, size_t bx) { for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx); } IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); } IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const { xv[0] = load1(ix+0, i); xv[1] = load1(ix+1, i); xv[2] = load1(ix+2, i); xv[3] = load1(ix+3, i); #ifdef __AVX512F__ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]); auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]); xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); #else auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]); auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]); xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); #endif } const Float * y[nrc]; }; template IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { int nb = n/QFBase::k_step; int nb4 = n/4; Qy y(info); Qx x(cx + ix0*bx, bx); QFBase::Data xv[Qx::nrc]; QFBase::Acc acc[Qx::nrc*Qy::nrc]; auto yv = y.load1(0, 0); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, 0); acc[ix] = QFBase::acc_first(yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, 0); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]); } for (int i = 1; i < nb; ++i) { yv = y.load1(0, i); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, i); acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, i); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); } } for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) { yv = y.load_tail(0, i); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load_tail(ix, i); acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load_tail(iy, i); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); } } for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const char * cx = (const char *)vx; // TBD if we want this //if constexpr (nrc_y == 1) { // constexpr int k_nx = 2; // for (int ix = 0; ix < nrc_x/k_nx; ++ix) { // mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, ix*k_nx, info); // } // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) { // int nx = nrc_x - lastx; // switch (nx) { // case 1: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // case 2: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // case 3: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // } // //mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); // } // return; //} #ifdef __AVX512F__ constexpr int k_nx = 5; #else constexpr int k_nx = nrc_y == 1 ? 4 : 2; #endif for (int ix = 0; ix < nrc_x/k_nx; ++ix) { mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); } int last_x = k_nx*(nrc_x/k_nx); if (last_x == nrc_x) return; int nx = nrc_x - last_x; #ifdef __AVX512F__ switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } #else if constexpr (nrc_y == 1) { switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } } else { switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } } #endif } template void set_mul_mat_f(MulMat& mm) { for (auto& f : mm.funcs) f = nullptr; mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>; mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>; mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>; mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>; mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>; #ifndef __AVX512F__ mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>; #endif } /* moonll add typeb TO compare return not expected type of weight matrix add IQ2XSS add IQ1_S add GGML_TYPE_IQ4_XS */ bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { (void)Ny; auto expected_typeB = GGML_TYPE_Q8_K; switch (typeA) { case GGML_TYPE_Q2_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q3_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q4_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q5_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q6_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_IQ4_XS: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_IQ2_XXS: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); #ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; #else MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0_X4; #endif break; case GGML_TYPE_IQ1_S: mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; mm.funcs[1] = mul_mat_iq1_s_q8_K<2>; mm.funcs[2] = mul_mat_iq1_s_q8_K<3>; mm.funcs[3] = mul_mat_iq1_s_q8_K<4>; mm.funcs[4] = mul_mat_iq1_s_q8_K<5>; mm.funcs[5] = mul_mat_iq1_s_q8_K<6>; mm.funcs[6] = mul_mat_iq1_s_q8_K<7>; mm.funcs[7] = mul_mat_iq1_s_q8_K<8>; #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_s_q8_K<16>; #endif // row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); expected_typeB = GGML_TYPE_Q8_K; break; default: { printf("case:%d",typeA); return false; } } return ggml_type(typeB) == expected_typeB; } } // namespace /* iq1_s is not support for arm */ #else // __aarch64__ #include namespace { template struct Q8 { constexpr static int nrc_y = nrc; Q8(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); } inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); } inline int16x8_t load_bsums8(int iy, int i) const { auto q8s = vld1q_s16_x2(y[iy][i].bsums); return vpaddq_s16(q8s.val[0], q8s.val[1]); } inline float scale(int iy, int i) const { return y[iy][i].d; } const block_q8 * y[nrc_y]; }; template struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } const void * vx; const block_q * x; const size_t bx; const int nrc; }; struct Q4bits { const uint8x16_t m4b = vdupq_n_u8(0xf); uint8x16x4_t b1, b2; inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { b.val[0] = vandq_u8(val[0], m4b); b.val[2] = vshrq_n_u8(val[0], 4); b.val[1] = vandq_u8(val[1], m4b); b.val[3] = vshrq_n_u8(val[1], 4); } inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { b.val[0] = vandq_u8(val[0], m4b); b.val[1] = vshrq_n_u8(val[0], 4); b.val[2] = vandq_u8(val[1], m4b); b.val[3] = vshrq_n_u8(val[1], 4); } inline void prepare(const uint8_t * qs) { auto q4bits = vld1q_u8_x2(qs); prepare4(b1, q4bits.val); q4bits = vld1q_u8_x2(qs+32); prepare4(b2, q4bits.val); } inline void prepare_v2(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); prepare4(b1, q4bits.val+0); prepare4(b2, q4bits.val+2); } inline void prepare64(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); b1.val[0] = vandq_u8(q4bits.val[0], m4b); b1.val[1] = vandq_u8(q4bits.val[1], m4b); b1.val[2] = vandq_u8(q4bits.val[2], m4b); b1.val[3] = vandq_u8(q4bits.val[3], m4b); b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); } inline void prepare16(const uint8_t * qs) { auto q4bits = vld1q_u8_x2(qs); prepare4_16(b1, q4bits.val); q4bits = vld1q_u8_x2(qs+32); prepare4_16(b2, q4bits.val); } inline void prepare16_v2(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); prepare4_16(b1, q4bits.val+0); prepare4_16(b2, q4bits.val+2); } }; struct Scales8 { uint32_t utmp[4]; const uint8_t * sc8 = (const uint8_t *)utmp; template inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) { make_q4_scales(x.scales, utmp); int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8)); accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin)); uint8x8_t scales8 = vld1_u8(sc8); uint16x8_t scales16 = vmovl_u8(scales8); int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))), vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))}; return scales; } }; struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); return s8.process_scales_mins(x[i], q8, i, acc); } inline void prepare(int i, int j) { if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); else bits.prepare(x[i].qs+64*j); } Q4bits bits; Scales8 s8; float d; }; struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); } inline void prepare(int i, int j) { auto hbits = vld1q_u8_x2(x[i].qh + 32*j); bits.prepare64(x[i].ql+64*j); bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); } Q4bits bits; const uint8x16_t mhb = vdupq_n_u8(0x30); float d; }; template struct BlockQxK { inline BlockQxK(const int maxn, const int maxk): maxn(maxn), maxk(maxk) { values = (int8_t*)aligned_alloc(256, maxn * maxk * sizeof(int8_t)); scales = (int*)aligned_alloc(256, maxn * maxk / SS * sizeof(int)); ds = (float*)aligned_alloc(256, maxn * maxk / QK * sizeof(int)); if constexpr (NeedSum) { dmins = (float*)aligned_alloc(256, maxn * maxk / QK * sizeof(int)); scalems = (int16_t*)aligned_alloc(256, maxn * maxk / SS * sizeof(int16_t)); } } inline ~BlockQxK() { free(values); free(scales); free(ds); if constexpr (NeedSum) { free(dmins); free(scalems); } } inline int FromDequantizer(const void * vx, size_t bx, int idx, int n_, int k_) { n = n_; k = k_; bn = n / BS; bk = k / QK; Dequantizer deq(vx, bx, 1); for (int i = 0; i < n; i += BS) { for (int j = 0; j < BS; j ++) { deq.new_row(j + i + idx); for (int x = 0; x < bk; x ++) { { int8x16_t base = NeedSum ? vdupq_n_s8(0) : vdupq_n_s8(32); int32_t *dst = (int32_t*)(values + i*k + j*4 + x*QK*BS); deq.prepare(x, 0); int8x16_t v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[0]), base); int8x16_t v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[1]), base); int8x16_t v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[2]), base); int8x16_t v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[3]), base); *(dst + (0 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0); *(dst + (1 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1); *(dst + (2 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2); *(dst + (3 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3); *(dst + (0 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0); *(dst + (1 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1); *(dst + (2 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2); *(dst + (3 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3); *(dst + (0 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0); *(dst + (1 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1); *(dst + (2 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2); *(dst + (3 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3); *(dst + (0 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0); *(dst + (1 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1); *(dst + (2 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2); *(dst + (3 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3); v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[0]), base); v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[1]), base); v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[2]), base); v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[3]), base); *(dst + (0 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0); *(dst + (1 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1); *(dst + (2 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2); *(dst + (3 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3); *(dst + (0 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0); *(dst + (1 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1); *(dst + (2 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2); *(dst + (3 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3); *(dst + (0 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0); *(dst + (1 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1); *(dst + (2 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2); *(dst + (3 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3); *(dst + (0 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0); *(dst + (1 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1); *(dst + (2 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2); *(dst + (3 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3); deq.prepare(x, 1); v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[0]), base); v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[1]), base); v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[2]), base); v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[3]), base); *(dst + (0 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0); *(dst + (1 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1); *(dst + (2 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2); *(dst + (3 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3); *(dst + (0 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0); *(dst + (1 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1); *(dst + (2 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2); *(dst + (3 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3); *(dst + (0 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0); *(dst + (1 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1); *(dst + (2 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2); *(dst + (3 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3); *(dst + (0 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0); *(dst + (1 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1); *(dst + (2 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2); *(dst + (3 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3); v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[0]), base); v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[1]), base); v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[2]), base); v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[3]), base); *(dst + (0 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0); *(dst + (1 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1); *(dst + (2 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2); *(dst + (3 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3); *(dst + (0 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0); *(dst + (1 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1); *(dst + (2 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2); *(dst + (3 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3); *(dst + (0 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0); *(dst + (1 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1); *(dst + (2 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2); *(dst + (3 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3); *(dst + (0 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0); *(dst + (1 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1); *(dst + (2 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2); *(dst + (3 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3); } if constexpr (std::is_same_v) { int32_t *dst = (int32_t*)(scales + i*(k/SS) + j + x*QK/SS*BS); int8x16_t ss = vld1q_s8(deq.x[x].scales); int16x8_t s16_0 = vmovl_s8(vget_low_s8(ss)); int16x8_t s16_1 = vmovl_s8(vget_high_s8(ss)); int32x4_t s32_0 = vmovl_s16(vget_low_s16(s16_0)); int32x4_t s32_1 = vmovl_s16(vget_high_s16(s16_0)); int32x4_t s32_2 = vmovl_s16(vget_low_s16(s16_1)); int32x4_t s32_3 = vmovl_s16(vget_high_s16(s16_1)); *(dst + (0+0*4)*BS) = vgetq_lane_s32(s32_0, 0); *(dst + (1+0*4)*BS) = vgetq_lane_s32(s32_0, 1); *(dst + (2+0*4)*BS) = vgetq_lane_s32(s32_0, 2); *(dst + (3+0*4)*BS) = vgetq_lane_s32(s32_0, 3); *(dst + (0+1*4)*BS) = vgetq_lane_s32(s32_1, 0); *(dst + (1+1*4)*BS) = vgetq_lane_s32(s32_1, 1); *(dst + (2+1*4)*BS) = vgetq_lane_s32(s32_1, 2); *(dst + (3+1*4)*BS) = vgetq_lane_s32(s32_1, 3); *(dst + (0+2*4)*BS) = vgetq_lane_s32(s32_2, 0); *(dst + (1+2*4)*BS) = vgetq_lane_s32(s32_2, 1); *(dst + (2+2*4)*BS) = vgetq_lane_s32(s32_2, 2); *(dst + (3+2*4)*BS) = vgetq_lane_s32(s32_2, 3); *(dst + (0+3*4)*BS) = vgetq_lane_s32(s32_3, 0); *(dst + (1+3*4)*BS) = vgetq_lane_s32(s32_3, 1); *(dst + (2+3*4)*BS) = vgetq_lane_s32(s32_3, 2); *(dst + (3+3*4)*BS) = vgetq_lane_s32(s32_3, 3); } if constexpr (std::is_same_v) { int32_t *dst = (int32_t*)(scales + i*(k/SS) + j + x*QK/SS*BS); int16_t *dst2 = (int16_t*)(scalems + i*(k/SS) + j + x*QK/SS*BS); uint32_t utmp[4]; const uint8_t * sc8 = (const uint8_t *)utmp; make_q4_scales(deq.x[x].scales, utmp); int8x16_t ss = vld1q_s8((const int8_t *)sc8); int16x8_t scale = vmovl_s8(vget_low_s8(ss)); int16x8_t scale_min = vmovl_high_s8(ss); int32x4_t s32_0 = vmovl_s16(vget_low_s16(scale)); int32x4_t s32_1 = vmovl_s16(vget_high_s16(scale)); *(dst + (0+0*4)*BS) = vgetq_lane_s32(s32_0, 0); *(dst + (1+0*4)*BS) = vgetq_lane_s32(s32_0, 1); *(dst + (2+0*4)*BS) = vgetq_lane_s32(s32_0, 2); *(dst + (3+0*4)*BS) = vgetq_lane_s32(s32_0, 3); *(dst + (0+1*4)*BS) = vgetq_lane_s32(s32_1, 0); *(dst + (1+1*4)*BS) = vgetq_lane_s32(s32_1, 1); *(dst + (2+1*4)*BS) = vgetq_lane_s32(s32_1, 2); *(dst + (3+1*4)*BS) = vgetq_lane_s32(s32_1, 3); *(dst2 + 0*BS) = vgetq_lane_s16(scale_min, 0); *(dst2 + 1*BS) = vgetq_lane_s16(scale_min, 1); *(dst2 + 2*BS) = vgetq_lane_s16(scale_min, 2); *(dst2 + 3*BS) = vgetq_lane_s16(scale_min, 3); *(dst2 + 4*BS) = vgetq_lane_s16(scale_min, 4); *(dst2 + 5*BS) = vgetq_lane_s16(scale_min, 5); *(dst2 + 6*BS) = vgetq_lane_s16(scale_min, 6); *(dst2 + 7*BS) = vgetq_lane_s16(scale_min, 7); } { float *dst = ds + i*bk + j + x*BS; *dst = GGML_FP16_TO_FP32(deq.x[x].d); } if constexpr (std::is_same_v) { float *dst = dmins + i*bk + j + x*BS; *dst = - GGML_FP16_TO_FP32(deq.x[x].dmin); } } } } return 0; } int8_t *values; // [bn][k/4][BS][4] int *scales; // [bn][k/SS][BS] float *ds; // [bn][bk][BS] float *dmins; // [bn][bk][BS] int16_t *scalems; // [bn][k/SS][BS] static constexpr int BS = 8; static constexpr int QK = 256; static constexpr int SS = std::is_same_v ? 16 : 32; static constexpr int NeedSum = std::is_same_v ? 0 : 1; const int maxn; const int maxk; int n; int k; int bn; int bk; }; template IQK_NOINLINE void matmul_v2_kernel(const Dequantizer *a, const block_q8_K *y[BN], const DataInfo &info, int idx, int idy) { constexpr int BS = a->BS; constexpr int QK = a->QK; constexpr int SS = a->SS; for (int s = 0; s < a->n; s += BS) { float32x4_t cc[BN][BS/4]; for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cc[i][j] = vdupq_n_f32(0); } } const int8_t *a_ptr = a->values + s*a->k; const int8_t *b_ptr[BN]; for (int k = 0; k < a->bk; k ++) { for (int i = 0; i < BN; i ++) { b_ptr[i] = y[i][k].qs; } int32x4_t cci[BN][BS/4]; if constexpr (BN == 4 && SS == 16) { int64_t length = QK/SS; auto ap = a_ptr; auto sp = a->scales + s*a->k/SS + (k*QK/SS)*BS; // asm volatile ( asm volatile ( " eor %[c00].16b, %[c00].16b, %[c00].16b \n" " eor %[c10].16b, %[c10].16b, %[c10].16b \n" " eor %[c20].16b, %[c20].16b, %[c20].16b \n" " eor %[c30].16b, %[c30].16b, %[c30].16b \n" " eor %[c01].16b, %[c01].16b, %[c01].16b \n" " eor %[c11].16b, %[c11].16b, %[c11].16b \n" " eor %[c21].16b, %[c21].16b, %[c21].16b \n" " eor %[c31].16b, %[c31].16b, %[c31].16b \n" " loop_%=: \n" " subs %[len], %[len], #1 \n" " ld1 {v12.16b}, [%[bp0]], #16 \n" " ld1 {v13.16b}, [%[bp1]], #16 \n" " ld1 {v14.16b}, [%[bp2]], #16 \n" " ld1 {v15.16b}, [%[bp3]], #16 \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " eor v0.16b, v0.16b, v0.16b \n" " eor v1.16b, v1.16b, v1.16b \n" " eor v2.16b, v2.16b, v2.16b \n" " eor v3.16b, v3.16b, v3.16b \n" " eor v4.16b, v4.16b, v4.16b \n" " eor v5.16b, v5.16b, v5.16b \n" " eor v6.16b, v6.16b, v6.16b \n" " eor v7.16b, v7.16b, v7.16b \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[0] \n" " sdot v1.4s, v8.16b, v13.4b[0] \n" " sdot v2.4s, v8.16b, v14.4b[0] \n" " sdot v3.4s, v8.16b, v15.4b[0] \n" " sdot v4.4s, v9.16b, v12.4b[0] \n" " sdot v5.4s, v9.16b, v13.4b[0] \n" " sdot v6.4s, v9.16b, v14.4b[0] \n" " sdot v7.4s, v9.16b, v15.4b[0] \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[1] \n" " sdot v1.4s, v10.16b, v13.4b[1] \n" " sdot v2.4s, v10.16b, v14.4b[1] \n" " sdot v3.4s, v10.16b, v15.4b[1] \n" " sdot v4.4s, v11.16b, v12.4b[1] \n" " sdot v5.4s, v11.16b, v13.4b[1] \n" " sdot v6.4s, v11.16b, v14.4b[1] \n" " sdot v7.4s, v11.16b, v15.4b[1] \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[2] \n" " sdot v1.4s, v8.16b, v13.4b[2] \n" " sdot v2.4s, v8.16b, v14.4b[2] \n" " sdot v3.4s, v8.16b, v15.4b[2] \n" " sdot v4.4s, v9.16b, v12.4b[2] \n" " sdot v5.4s, v9.16b, v13.4b[2] \n" " sdot v6.4s, v9.16b, v14.4b[2] \n" " sdot v7.4s, v9.16b, v15.4b[2] \n" " ld1 {v8.4s}, [%[sp]], #16 \n" " ld1 {v9.4s}, [%[sp]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[3] \n" " sdot v1.4s, v10.16b, v13.4b[3] \n" " sdot v2.4s, v10.16b, v14.4b[3] \n" " sdot v3.4s, v10.16b, v15.4b[3] \n" " sdot v4.4s, v11.16b, v12.4b[3] \n" " sdot v5.4s, v11.16b, v13.4b[3] \n" " sdot v6.4s, v11.16b, v14.4b[3] \n" " sdot v7.4s, v11.16b, v15.4b[3] \n" " mla %[c00].4s, v0.4s, v8.4s \n" " mla %[c10].4s, v1.4s, v8.4s \n" " mla %[c20].4s, v2.4s, v8.4s \n" " mla %[c30].4s, v3.4s, v8.4s \n" " mla %[c01].4s, v4.4s, v9.4s \n" " mla %[c11].4s, v5.4s, v9.4s \n" " mla %[c21].4s, v6.4s, v9.4s \n" " mla %[c31].4s, v7.4s, v9.4s \n" " bne loop_%= \n" " exit_%=:\n" : [len] "+r" (length) , [ap] "+r" (ap) , [bp0] "+r" (b_ptr[0]) , [bp1] "+r" (b_ptr[1]) , [bp2] "+r" (b_ptr[2]) , [bp3] "+r" (b_ptr[3]) , [sp] "+r" (sp) , [c00] "+w" (cci[0][0]) , [c10] "+w" (cci[1][0]) , [c20] "+w" (cci[2][0]) , [c30] "+w" (cci[3][0]) , [c01] "+w" (cci[0][1]) , [c11] "+w" (cci[1][1]) , [c21] "+w" (cci[2][1]) , [c31] "+w" (cci[3][1]) : : "v0", "v1", "v2", "v3" , "v4", "v5", "v6", "v7" , "v8", "v9", "v10", "v11" , "v12", "v13", "v14", "v15" , "memory", "cc" ); a_ptr += BS * QK; } else if (BN == 4 && SS == 32) { int64_t length = QK/SS; auto ap = a_ptr; auto sp = a->scales + s*a->k/SS + (k*QK/SS)*BS; // asm volatile ( asm volatile ( " eor %[c00].16b, %[c00].16b, %[c00].16b \n" " eor %[c10].16b, %[c10].16b, %[c10].16b \n" " eor %[c20].16b, %[c20].16b, %[c20].16b \n" " eor %[c30].16b, %[c30].16b, %[c30].16b \n" " eor %[c01].16b, %[c01].16b, %[c01].16b \n" " eor %[c11].16b, %[c11].16b, %[c11].16b \n" " eor %[c21].16b, %[c21].16b, %[c21].16b \n" " eor %[c31].16b, %[c31].16b, %[c31].16b \n" " loop_%=: \n" " subs %[len], %[len], #1 \n" " ld1 {v12.16b}, [%[bp0]], #16 \n" " ld1 {v13.16b}, [%[bp1]], #16 \n" " ld1 {v14.16b}, [%[bp2]], #16 \n" " ld1 {v15.16b}, [%[bp3]], #16 \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " eor v0.16b, v0.16b, v0.16b \n" " eor v1.16b, v1.16b, v1.16b \n" " eor v2.16b, v2.16b, v2.16b \n" " eor v3.16b, v3.16b, v3.16b \n" " eor v4.16b, v4.16b, v4.16b \n" " eor v5.16b, v5.16b, v5.16b \n" " eor v6.16b, v6.16b, v6.16b \n" " eor v7.16b, v7.16b, v7.16b \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[0] \n" " sdot v1.4s, v8.16b, v13.4b[0] \n" " sdot v2.4s, v8.16b, v14.4b[0] \n" " sdot v3.4s, v8.16b, v15.4b[0] \n" " sdot v4.4s, v9.16b, v12.4b[0] \n" " sdot v5.4s, v9.16b, v13.4b[0] \n" " sdot v6.4s, v9.16b, v14.4b[0] \n" " sdot v7.4s, v9.16b, v15.4b[0] \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[1] \n" " sdot v1.4s, v10.16b, v13.4b[1] \n" " sdot v2.4s, v10.16b, v14.4b[1] \n" " sdot v3.4s, v10.16b, v15.4b[1] \n" " sdot v4.4s, v11.16b, v12.4b[1] \n" " sdot v5.4s, v11.16b, v13.4b[1] \n" " sdot v6.4s, v11.16b, v14.4b[1] \n" " sdot v7.4s, v11.16b, v15.4b[1] \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[2] \n" " sdot v1.4s, v8.16b, v13.4b[2] \n" " sdot v2.4s, v8.16b, v14.4b[2] \n" " sdot v3.4s, v8.16b, v15.4b[2] \n" " sdot v4.4s, v9.16b, v12.4b[2] \n" " sdot v5.4s, v9.16b, v13.4b[2] \n" " sdot v6.4s, v9.16b, v14.4b[2] \n" " sdot v7.4s, v9.16b, v15.4b[2] \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[3] \n" " sdot v1.4s, v10.16b, v13.4b[3] \n" " sdot v2.4s, v10.16b, v14.4b[3] \n" " sdot v3.4s, v10.16b, v15.4b[3] \n" " sdot v4.4s, v11.16b, v12.4b[3] \n" " sdot v5.4s, v11.16b, v13.4b[3] \n" " sdot v6.4s, v11.16b, v14.4b[3] \n" " sdot v7.4s, v11.16b, v15.4b[3] \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " ld1 {v12.16b}, [%[bp0]], #16 \n" " ld1 {v13.16b}, [%[bp1]], #16 \n" " ld1 {v14.16b}, [%[bp2]], #16 \n" " ld1 {v15.16b}, [%[bp3]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[0] \n" " sdot v1.4s, v8.16b, v13.4b[0] \n" " sdot v2.4s, v8.16b, v14.4b[0] \n" " sdot v3.4s, v8.16b, v15.4b[0] \n" " sdot v4.4s, v9.16b, v12.4b[0] \n" " sdot v5.4s, v9.16b, v13.4b[0] \n" " sdot v6.4s, v9.16b, v14.4b[0] \n" " sdot v7.4s, v9.16b, v15.4b[0] \n" " prfm pldl1strm, [%[ap], #256] \n" " ld1 {v8.16b}, [%[ap]], #16 \n" " ld1 {v9.16b}, [%[ap]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[1] \n" " sdot v1.4s, v10.16b, v13.4b[1] \n" " sdot v2.4s, v10.16b, v14.4b[1] \n" " sdot v3.4s, v10.16b, v15.4b[1] \n" " sdot v4.4s, v11.16b, v12.4b[1] \n" " sdot v5.4s, v11.16b, v13.4b[1] \n" " sdot v6.4s, v11.16b, v14.4b[1] \n" " sdot v7.4s, v11.16b, v15.4b[1] \n" " ld1 {v10.16b}, [%[ap]], #16 \n" " ld1 {v11.16b}, [%[ap]], #16 \n" " sdot v0.4s, v8.16b, v12.4b[2] \n" " sdot v1.4s, v8.16b, v13.4b[2] \n" " sdot v2.4s, v8.16b, v14.4b[2] \n" " sdot v3.4s, v8.16b, v15.4b[2] \n" " sdot v4.4s, v9.16b, v12.4b[2] \n" " sdot v5.4s, v9.16b, v13.4b[2] \n" " sdot v6.4s, v9.16b, v14.4b[2] \n" " sdot v7.4s, v9.16b, v15.4b[2] \n" " ld1 {v8.4s}, [%[sp]], #16 \n" " ld1 {v9.4s}, [%[sp]], #16 \n" " sdot v0.4s, v10.16b, v12.4b[3] \n" " sdot v1.4s, v10.16b, v13.4b[3] \n" " sdot v2.4s, v10.16b, v14.4b[3] \n" " sdot v3.4s, v10.16b, v15.4b[3] \n" " sdot v4.4s, v11.16b, v12.4b[3] \n" " sdot v5.4s, v11.16b, v13.4b[3] \n" " sdot v6.4s, v11.16b, v14.4b[3] \n" " sdot v7.4s, v11.16b, v15.4b[3] \n" " mla %[c00].4s, v0.4s, v8.4s \n" " mla %[c10].4s, v1.4s, v8.4s \n" " mla %[c20].4s, v2.4s, v8.4s \n" " mla %[c30].4s, v3.4s, v8.4s \n" " mla %[c01].4s, v4.4s, v9.4s \n" " mla %[c11].4s, v5.4s, v9.4s \n" " mla %[c21].4s, v6.4s, v9.4s \n" " mla %[c31].4s, v7.4s, v9.4s \n" " bne loop_%= \n" " exit_%=:\n" : [len] "+r" (length) , [ap] "+r" (ap) , [bp0] "+r" (b_ptr[0]) , [bp1] "+r" (b_ptr[1]) , [bp2] "+r" (b_ptr[2]) , [bp3] "+r" (b_ptr[3]) , [sp] "+r" (sp) , [c00] "+w" (cci[0][0]) , [c10] "+w" (cci[1][0]) , [c20] "+w" (cci[2][0]) , [c30] "+w" (cci[3][0]) , [c01] "+w" (cci[0][1]) , [c11] "+w" (cci[1][1]) , [c21] "+w" (cci[2][1]) , [c31] "+w" (cci[3][1]) : : "v0", "v1", "v2", "v3" , "v4", "v5", "v6", "v7" , "v8", "v9", "v10", "v11" , "v12", "v13", "v14", "v15" , "memory", "cc" ); a_ptr += BS * QK; } else { for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cci[i][j] = vdupq_n_s32(0); } } for (int k0 = 0; k0 < QK/SS; k0 ++) { int32x4_t ccv[BN][BS/4]; for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { ccv[i][j] = vdupq_n_s32(0); } } #pragma unroll for (int k2 = 0; k2 < SS; k2 += 16) { const int OFFSET = 256; __builtin_prefetch((a_ptr + OFFSET + 0*64), 0, 0); __builtin_prefetch((a_ptr + OFFSET + 1*64), 0, 0); int8x16_t bb[BN]; int8x16_t aa[BS/4]; for (int i = 0; i < BN; i ++) { bb[i] = vld1q_s8(b_ptr[i]); b_ptr[i] += 16; } for (int k1 = 0; k1 < 4; k1 ++) { for (int i = 0; i < BS/4; i ++) { aa[i] = vld1q_s8(a_ptr); a_ptr += 16; } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { ccv[i][j] = vdotq_laneq_s32(ccv[i][j], aa[j], bb[i], k1); } } } } int32x4_t scal[BS/4]; for (int i = 0; i < BS/4; i ++) { scal[i] = vld1q_s32(a->scales + s*a->k/SS + (k*QK/SS+k0)*BS + i*4); } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cci[i][j] = vmlaq_s32(cci[i][j], ccv[i][j], scal[j]); } } } } float32x4_t scalf[BS/4]; for (int i = 0; i < BS/4; i ++) { scalf[i] = vld1q_f32(a->ds + s*a->bk + k*BS + i*4); } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cc[i][j] = vfmaq_f32(cc[i][j], vcvtq_f32_s32(cci[i][j]), vmulq_n_f32(scalf[j], y[i][k].d)); } } } if constexpr (a->NeedSum) { const int16_t *a_ptr = a->scalems + s*a->k/SS; const int16_t *b_ptr[BN]; for (int k = 0; k < a->bk; k ++) { for (int i = 0; i < BN; i ++) { b_ptr[i] = y[i][k].bsums; } int32x4_t cci[BN][BS/4]; for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cci[i][j] = vdupq_n_s32(0); } } for (int k0 = 0; k0 < QK/SS/4; k0 ++) { int16x8_t bb[BN]; int16x8_t aa[BS/8]; for (int i = 0; i < BN; i ++) { bb[i] = vld1q_s16(b_ptr[i]); b_ptr[i] += 8; } for (int k1 = 0; k1 < 4; k1 ++) { for (int i = 0; i < BS/8; i ++) { aa[i] = vld1q_s16(a_ptr); a_ptr += 8; } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/8; j ++) { cci[i][2*j+0] = vmlal_laneq_s16(cci[i][2*j+0], vget_low_s16(aa[j]), bb[i], 2*k1+0); cci[i][2*j+1] = vmlal_high_laneq_s16(cci[i][2*j+1], aa[j], bb[i], 2*k1+0); cci[i][2*j+0] = vmlal_laneq_s16(cci[i][2*j+0], vget_low_s16(aa[j]), bb[i], 2*k1+1); cci[i][2*j+1] = vmlal_high_laneq_s16(cci[i][2*j+1], aa[j], bb[i], 2*k1+1); } } } } float32x4_t scalf[BS/4]; for (int i = 0; i < BS/4; i ++) { scalf[i] = vld1q_f32(a->dmins + s*a->bk + k*BS + i*4); } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { cc[i][j] = vfmaq_f32(cc[i][j], vcvtq_f32_s32(cci[i][j]), vmulq_n_f32(scalf[j], y[i][k].d)); } } } } for (int i = 0; i < BN; i ++) { for (int j = 0; j < BS/4; j ++) { vst1q_f32(info.ptr(j*4+s+idx, i), cc[i][j]); } } } return; } template IQK_NOINLINE void mul_mat_qX_K_q8_K_T_v2(int m, int n, int k, const void * vx, size_t bx, const DataInfo& info) { constexpr int m_step = 64; constexpr int n_step = 4; assert(m%m_step == 0); int n2 = n - (n%n_step); int left = n%n_step; BlockQxK xx(m_step, k); for (int i = 0; i < m; i += m_step) { auto this_info = info; int bm = (m - i) < m_step ? (m - i) : m_step; xx.FromDequantizer(vx, bx, i, bm, k); for (int j = 0; j < n2; j += n_step) { Q8 q8(this_info); matmul_v2_kernel, n_step>(&xx, q8.y, this_info, i, j); this_info.cur_y += n_step; } if (left) { switch (left) { case 1: { Q8<1, block_q8_K> q8(this_info); matmul_v2_kernel, 1>(&xx, q8.y, this_info, i, n2); this_info.cur_y += 1; break; } case 2: { Q8<2, block_q8_K> q8(this_info); matmul_v2_kernel, 2>(&xx, q8.y, this_info, i, n2); this_info.cur_y += 2; break; } case 3: { Q8<3, block_q8_K> q8(this_info); matmul_v2_kernel, 3>(&xx, q8.y, this_info, i, n2); this_info.cur_y += 3; break; } } } } return; } template IQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val; const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val; auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2 auto p12 = vpaddq_s32(p1, p2); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3 auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4 auto p34 = vpaddq_s32(p3, p4); auto pall = vpaddq_s32(p12, p34); sumi = vmlaq_s32(sumi, scales.val[j], pall); } template IQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8, const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2 auto p12 = vpaddq_s32(p1, p2); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3 auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4 auto p34 = vpaddq_s32(p3, p4); auto pall = vpaddq_s32(p12, p34); sumi = vmlaq_s32(sumi, scales, pall); } template IQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1, auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4, auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3 sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5, auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7, auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7 sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34); } template IQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); if constexpr (Dequantizer::num_blocks() == 8) { auto scales = deq.new_block(i); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else if constexpr (Dequantizer::num_blocks() == 16) { auto scales = deq.new_block(i); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else { GGML_ASSERT(false); } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } template inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8s = q8.load_bsums8(iy, i); int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s)); int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s)); float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2)); acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); } } template inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8s = q8.load_bsums(iy, i); int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); } } struct Q2bits { const uint8x16_t m4b = vdupq_n_u8(0x03); uint8x16x4_t b1, b2; inline void prepare(const uint8_t * qs) { auto q2bits = vld1q_u8_x2(qs); b1.val[0] = vandq_u8(q2bits.val[0], m4b); b1.val[1] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b1.val[2] = vandq_u8(q2bits.val[0], m4b); b1.val[3] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b2.val[0] = vandq_u8(q2bits.val[0], m4b); b2.val[1] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b2.val[2] = vandq_u8(q2bits.val[0], m4b); b2.val[3] = vandq_u8(q2bits.val[1], m4b); } }; struct HighBit5 { const uint8x16_t mhb = vdupq_n_u8(0x10); uint8x16x2_t bits; inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); if (do_shift) { bits.val[0] = vshrq_n_u8(bits.val[0], 4); bits.val[1] = vshrq_n_u8(bits.val[1], 4); } } }; struct HighBit3 { const uint8x16_t mhb = vdupq_n_u8(0x04); uint8x16x2_t bits; inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb)); b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); if (do_shift) { bits.val[0] = vshrq_n_u8(bits.val[0], 4); bits.val[1] = vshrq_n_u8(bits.val[1], 4); } } }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].qh); return s8.process_scales_mins(x[i], q8, i, acc); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+64*j); h.apply(bits.b1, bits.b2, j == 0); } Q4bits bits; HighBit5 h; Scales8 s8; uint8x16x2_t hbits; float d; }; inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { int32x4x4_t scales = { vmovl_s16(vget_low_s16 (scales16.val[0])), vmovl_s16(vget_high_s16(scales16.val[0])), vmovl_s16(vget_low_s16 (scales16.val[1])), vmovl_s16(vget_high_s16(scales16.val[1])), }; return scales; } template inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) { int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); accum_mins_16(scales16, q8, acc, i, c); return make_wider(scales16); } struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].hmask); const uint16_t * sc16 = (const uint16_t *)x[i].scales; uint32_t aux0 = sc16[0] | (sc16[1] << 16); uint32_t aux1 = sc16[2] | (sc16[3] << 16); uint32_t aux2 = sc16[4] | (sc16[5] << 16); aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); h.apply(bits.b1, bits.b2, j == 0); } uint32_t aux32[4]; Q2bits bits; HighBit3 h; float d; }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return true; } template inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); auto scales_and_mins = vld1q_u8(x[i].scales); auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4)); int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { process_scales(i, q8, acc); int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); return make_wider(scales16); } template inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { auto m1 = vdupq_n_u8(1); auto shuffle = vdupq_n_u8(8*j); bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8b_1 = q8.load_quants(iy, i, 4*j+0); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); auto q8b_2 = q8.load_quants(iy, i, 4*j+1); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); auto q8b_4 = q8.load_quants(iy, i, 4*j+3); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); } } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); } uint32_t aux32[4]; uint8x16_t scales8; Q2bits bits; float d; }; IQK_ALWAYS_INLINE void fusion_mul_mat_qX_K_q8_K_T_y1_d6k( float32x4_t &acc, const uint8_t *x_ql, // [128] 4bit const uint8_t *x_qh, // [64] 2bit const int8_t *x_scale, // [16] 8bit float x_d, const int8_t *y_qs, // [256] 8bit const int16_t *y_bsums, // [16] 16bit float y_d) { float c0 = x_d * y_d; float c1 = -32.0f * c0; const int OFFSET = 1024; __builtin_prefetch((x_ql + OFFSET + 0*64), 0, 0); __builtin_prefetch((x_ql + OFFSET + 1*64), 0, 0); __builtin_prefetch((x_ql + OFFSET + 2*64), 0, 0); int16x8_t scale16_0, scale16_1; { int8x16_t tmp = vld1q_s8(x_scale); scale16_0 = vmovl_s8(vget_low_s8(tmp)); scale16_1 = vmovl_high_s8(tmp); } { int16x8_t q8s0 = vld1q_s16(y_bsums + 0); int16x8_t q8s1 = vld1q_s16(y_bsums + 8); int32x4_t b0 = vmull_s16(vget_low_s16(scale16_0), vget_low_s16(q8s0)); b0 = vmlal_high_s16(b0, scale16_0, q8s0); b0 = vmlal_s16(b0, vget_low_s16(scale16_1), vget_low_s16(q8s1)); b0 = vmlal_high_s16(b0, scale16_1, q8s1); acc = vfmaq_n_f32(acc, vcvtq_f32_s32(b0), c1); } uint8x16_t x0, x1, x2, x3, x4, x5, x6, x7; int32x4_t sumi = vdupq_n_s32(0); { const uint8x16_t m0 = vdupq_n_u8(0x3f); const uint8x16_t m1 = vdupq_n_u8(0x30); const uint8x16_t m2 = vdupq_n_u8(0x0f); x0 = vld1q_u8(x_ql + 0*16 + 0*64); x1 = vld1q_u8(x_ql + 1*16 + 0*64); x2 = vld1q_u8(x_ql + 2*16 + 0*64); x3 = vld1q_u8(x_ql + 3*16 + 0*64); uint8x16_t hbits0 = vld1q_u8(x_qh + 0*16 + 0*32); uint8x16_t hbits1 = vld1q_u8(x_qh + 1*16 + 0*32); x4 = vandq_u8(hbits0, m0); x4 = vsriq_n_u8(x4, x0, 4); x5 = vandq_u8(hbits1, m0); x5 = vsriq_n_u8(x5, x1, 4); x6 = vshrq_n_u8(hbits0, 2); x6 = vsriq_n_u8(x6, x2, 4); x7 = vshrq_n_u8(hbits1, 2); x7 = vsriq_n_u8(x7, x3, 4); x0 = vsliq_n_u8(x0, hbits0, 4); x0 = vandq_u8(x0, m0); x1 = vsliq_n_u8(x1, hbits1, 4); x1 = vandq_u8(x1, m0); hbits0 = vshlq_n_u8(hbits0, 2); hbits0 = vandq_u8(hbits0, m1); x2 = vandq_u8(x2, m2); x2 = vorrq_u8(x2, hbits0); hbits1 = vshlq_n_u8(hbits1, 2); hbits1 = vandq_u8(hbits1, m1); x3 = vandq_u8(x3, m2); x3 = vorrq_u8(x3, hbits1); } { int8x16_t base = vdupq_n_s8(32); int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 0*128); int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 0*128); int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 0*128); int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 0*128); int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 0*128); int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 0*128); int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 0*128); int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 0*128); int32x4_t p00 = vdupq_n_s32(0); int32x4_t p01 = vdupq_n_s32(0); int32x4_t p10 = vdupq_n_s32(0); int32x4_t p11 = vdupq_n_s32(0); int32x4_t p20 = vdupq_n_s32(0); int32x4_t p21 = vdupq_n_s32(0); int32x4_t p30 = vdupq_n_s32(0); int32x4_t p31 = vdupq_n_s32(0); p00 = vdotq_s32(p00, vreinterpretq_s8_u8(x0), y0); p01 = vdotq_s32(p01, vreinterpretq_s8_u8(x1), y1); p10 = vdotq_s32(p10, vreinterpretq_s8_u8(x2), y2); p11 = vdotq_s32(p11, vreinterpretq_s8_u8(x3), y3); p20 = vdotq_s32(p20, vreinterpretq_s8_u8(x4), y4); p21 = vdotq_s32(p21, vreinterpretq_s8_u8(x5), y5); p30 = vdotq_s32(p30, vreinterpretq_s8_u8(x6), y6); p31 = vdotq_s32(p31, vreinterpretq_s8_u8(x7), y7); // p00 = vdotq_s32(p00, vsubq_s8(vreinterpretq_s8_u8(x0), base), y0); // p01 = vdotq_s32(p01, vsubq_s8(vreinterpretq_s8_u8(x1), base), y1); // p10 = vdotq_s32(p10, vsubq_s8(vreinterpretq_s8_u8(x2), base), y2); // p11 = vdotq_s32(p11, vsubq_s8(vreinterpretq_s8_u8(x3), base), y3); // p20 = vdotq_s32(p20, vsubq_s8(vreinterpretq_s8_u8(x4), base), y4); // p21 = vdotq_s32(p21, vsubq_s8(vreinterpretq_s8_u8(x5), base), y5); // p30 = vdotq_s32(p30, vsubq_s8(vreinterpretq_s8_u8(x6), base), y6); // p31 = vdotq_s32(p31, vsubq_s8(vreinterpretq_s8_u8(x7), base), y7); p00 = vpaddq_s32(p00, p01); p10 = vpaddq_s32(p10, p11); p20 = vpaddq_s32(p20, p21); p30 = vpaddq_s32(p30, p31); p00 = vpaddq_s32(p00, p10); p20 = vpaddq_s32(p20, p30); sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale16_0)), p00); sumi = vmlaq_s32(sumi, vmovl_high_s16(scale16_0), p20); } { const uint8x16_t m0 = vdupq_n_u8(0x3f); const uint8x16_t m1 = vdupq_n_u8(0x30); const uint8x16_t m2 = vdupq_n_u8(0x0f); x0 = vld1q_u8(x_ql + 0*16 + 1*64); x1 = vld1q_u8(x_ql + 1*16 + 1*64); x2 = vld1q_u8(x_ql + 2*16 + 1*64); x3 = vld1q_u8(x_ql + 3*16 + 1*64); uint8x16_t hbits0 = vld1q_u8(x_qh + 0*16 + 1*32); uint8x16_t hbits1 = vld1q_u8(x_qh + 1*16 + 1*32); x4 = vandq_u8(hbits0, m0); x4 = vsriq_n_u8(x4, x0, 4); x5 = vandq_u8(hbits1, m0); x5 = vsriq_n_u8(x5, x1, 4); x6 = vshrq_n_u8(hbits0, 2); x6 = vsriq_n_u8(x6, x2, 4); x7 = vshrq_n_u8(hbits1, 2); x7 = vsriq_n_u8(x7, x3, 4); x0 = vsliq_n_u8(x0, hbits0, 4); x0 = vandq_u8(x0, m0); x1 = vsliq_n_u8(x1, hbits1, 4); x1 = vandq_u8(x1, m0); hbits0 = vshlq_n_u8(hbits0, 2); hbits0 = vandq_u8(hbits0, m1); x2 = vandq_u8(x2, m2); x2 = vorrq_u8(x2, hbits0); hbits1 = vshlq_n_u8(hbits1, 2); hbits1 = vandq_u8(hbits1, m1); x3 = vandq_u8(x3, m2); x3 = vorrq_u8(x3, hbits1); } { int8x16_t base = vdupq_n_s8(32); int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 1*128); int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 1*128); int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 1*128); int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 1*128); int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 1*128); int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 1*128); int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 1*128); int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 1*128); int32x4_t p00 = vdupq_n_s32(0); int32x4_t p01 = vdupq_n_s32(0); int32x4_t p10 = vdupq_n_s32(0); int32x4_t p11 = vdupq_n_s32(0); int32x4_t p20 = vdupq_n_s32(0); int32x4_t p21 = vdupq_n_s32(0); int32x4_t p30 = vdupq_n_s32(0); int32x4_t p31 = vdupq_n_s32(0); p00 = vdotq_s32(p00, vreinterpretq_s8_u8(x0), y0); p01 = vdotq_s32(p01, vreinterpretq_s8_u8(x1), y1); p10 = vdotq_s32(p10, vreinterpretq_s8_u8(x2), y2); p11 = vdotq_s32(p11, vreinterpretq_s8_u8(x3), y3); p20 = vdotq_s32(p20, vreinterpretq_s8_u8(x4), y4); p21 = vdotq_s32(p21, vreinterpretq_s8_u8(x5), y5); p30 = vdotq_s32(p30, vreinterpretq_s8_u8(x6), y6); p31 = vdotq_s32(p31, vreinterpretq_s8_u8(x7), y7); // p00 = vdotq_s32(p00, vsubq_s8(vreinterpretq_s8_u8(x0), base), y0); // p01 = vdotq_s32(p01, vsubq_s8(vreinterpretq_s8_u8(x1), base), y1); // p10 = vdotq_s32(p10, vsubq_s8(vreinterpretq_s8_u8(x2), base), y2); // p11 = vdotq_s32(p11, vsubq_s8(vreinterpretq_s8_u8(x3), base), y3); // p20 = vdotq_s32(p20, vsubq_s8(vreinterpretq_s8_u8(x4), base), y4); // p21 = vdotq_s32(p21, vsubq_s8(vreinterpretq_s8_u8(x5), base), y5); // p30 = vdotq_s32(p30, vsubq_s8(vreinterpretq_s8_u8(x6), base), y6); // p31 = vdotq_s32(p31, vsubq_s8(vreinterpretq_s8_u8(x7), base), y7); p00 = vpaddq_s32(p00, p01); p10 = vpaddq_s32(p10, p11); p20 = vpaddq_s32(p20, p21); p30 = vpaddq_s32(p30, p31); p00 = vpaddq_s32(p00, p10); p20 = vpaddq_s32(p20, p30); sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale16_1)), p00); sumi = vmlaq_s32(sumi, vmovl_high_s16(scale16_1), p20); } { acc = vfmaq_n_f32(acc, vcvtq_f32_s32(sumi), c0); } return; } IQK_ALWAYS_INLINE void fusion_mul_mat_qX_K_q8_K_T_y1_d4k( float32x4_t &acc, const uint8_t *x_scale, // [12] 8*2*6bits const uint8_t *x_qs, // [128] 256*4bits float x_d, float x_dmin, const int8_t *y_qs, // [256] 8bit const int16_t *y_bsums, // [16] 16bit float y_d) { float c0 = x_d * y_d; float c1 = -x_dmin * y_d; const int OFFSET = 1024; __builtin_prefetch((x_scale + OFFSET + 0*64), 0, 0); __builtin_prefetch((x_scale + OFFSET + 1*64), 0, 0); int16x8_t scale_min; int16x8_t scale; { uint32_t utmp[4]; const uint8_t * sc8 = (const uint8_t *)utmp; make_q4_scales(x_scale, utmp); int8x16_t ss = vld1q_s8((const int8_t *)sc8); scale = vmovl_s8(vget_low_s8(ss)); scale_min = vmovl_high_s8(ss); } { int16x8_t q8s0 = vld1q_s16(y_bsums + 0); int16x8_t q8s1 = vld1q_s16(y_bsums + 8); q8s0 = vpaddq_s16(q8s0, q8s1); int32x4_t b0 = vmull_s16(vget_low_s16(scale_min), vget_low_s16(q8s0)); b0 = vmlal_high_s16(b0, scale_min, q8s0); acc = vfmaq_n_f32(acc, vcvtq_f32_s32(b0), c1); } int32x4_t sumi = vdupq_n_s32(0); const uint8x16_t m4b = vdupq_n_u8(0x0f); uint8x16_t x0, x1, x2, x3, x4, x5, x6, x7; { x0 = vld1q_u8(x_qs + 0*16 + 0*64); x1 = vld1q_u8(x_qs + 1*16 + 0*64); x4 = vld1q_u8(x_qs + 2*16 + 0*64); x5 = vld1q_u8(x_qs + 3*16 + 0*64); x2 = vshrq_n_u8(x0, 4); x3 = vshrq_n_u8(x1, 4); x6 = vshrq_n_u8(x4, 4); x7 = vshrq_n_u8(x5, 4); x0 = vandq_u8(x0, m4b); x1 = vandq_u8(x1, m4b); x4 = vandq_u8(x4, m4b); x5 = vandq_u8(x5, m4b); } { int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 0*128); int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 0*128); int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 0*128); int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 0*128); int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 0*128); int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 0*128); int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 0*128); int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 0*128); int32x4_t p0 = vdupq_n_s32(0); int32x4_t p1 = vdupq_n_s32(0); int32x4_t p2 = vdupq_n_s32(0); int32x4_t p3 = vdupq_n_s32(0); p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x0), y0); p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x2), y2); p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x4), y4); p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x6), y6); p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x1), y1); p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x3), y3); p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x5), y5); p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x7), y7); p0 = vpaddq_s32(p0, p1); p2 = vpaddq_s32(p2, p3); p0 = vpaddq_s32(p0, p2); sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale)), p0); } { x0 = vld1q_u8(x_qs + 0*16 + 1*64); x1 = vld1q_u8(x_qs + 1*16 + 1*64); x4 = vld1q_u8(x_qs + 2*16 + 1*64); x5 = vld1q_u8(x_qs + 3*16 + 1*64); x2 = vshrq_n_u8(x0, 4); x3 = vshrq_n_u8(x1, 4); x6 = vshrq_n_u8(x4, 4); x7 = vshrq_n_u8(x5, 4); x0 = vandq_u8(x0, m4b); x1 = vandq_u8(x1, m4b); x4 = vandq_u8(x4, m4b); x5 = vandq_u8(x5, m4b); } { int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 1*128); int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 1*128); int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 1*128); int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 1*128); int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 1*128); int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 1*128); int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 1*128); int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 1*128); int32x4_t p0 = vdupq_n_s32(0); int32x4_t p1 = vdupq_n_s32(0); int32x4_t p2 = vdupq_n_s32(0); int32x4_t p3 = vdupq_n_s32(0); p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x0), y0); p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x2), y2); p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x4), y4); p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x6), y6); p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x1), y1); p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x3), y3); p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x5), y5); p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x7), y7); p0 = vpaddq_s32(p0, p1); p2 = vpaddq_s32(p2, p3); p0 = vpaddq_s32(p0, p2); sumi = vmlaq_s32(sumi, vmovl_high_s16(scale), p0); } { acc = vfmaq_n_f32(acc, vcvtq_f32_s32(sumi), c0); } } template IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); //#pragma GCC unroll 4 for (int i = 0; i < nb; ++i) { #ifdef GEMV_Q4K if constexpr (nrc_y == 1 && std::is_same::value) { fusion_mul_mat_qX_K_q8_K_T_y1_d6k( acc[0], deq.x[i].ql, deq.x[i].qh, deq.x[i].scales, GGML_FP16_TO_FP32(deq.x[i].d), q8.y[0][i].qs, q8.y[0][i].bsums, q8.y[0][i].d); } else #endif #ifdef GEMV_Q6K if constexpr (nrc_y == 1 && std::is_same::value) { fusion_mul_mat_qX_K_q8_K_T_y1_d4k( acc[0], deq.x[i].scales, deq.x[i].qs, GGML_FP16_TO_FP32(deq.x[i].d), GGML_FP16_TO_FP32(deq.x[i].dmin), q8.y[0][i].qs, q8.y[0][i].bsums, q8.y[0][i].d); } else #endif { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { deq.process_scales(i, q8, acc); deq.prepare(i, 0); deq.compute(q8, i, 0, sumi); deq.prepare(i, 1); deq.compute(q8, i, 1, sumi); } else { if constexpr (Dequantizer::num_blocks() == 8) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else if constexpr (Dequantizer::num_blocks() == 16) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else { GGML_ASSERT(false); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } } // ============================= i-quants struct DequantizerIQ4XS final : public BaseDequantizer { static int8x16_t load_values() { static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; return vld1q_s8(iq4nl_values); } DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { (void)q8; (void)acc; d = GGML_FP16_TO_FP32(x[i].d); const uint16_t scales_h = x[i].scales_h; const uint16_t * scales_l = (const uint16_t *)x[i].scales_l; aux32[0] = scales_l[0] | (scales_l[1] << 16); aux32[1] = aux32[0] >> 4; // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7 uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf)); uint16_t * aux16 = (uint16_t *)aux32; aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2; // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7 uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30)); int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32)); // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7 scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff)); int16x8_t scales16 = vmovl_s8(scales8); int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; return scales; } inline void prepare(int i, int j) { bits.prepare16(x[i].qs+64*j); for (int k = 0; k < 4; ++k) { bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k])); bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k])); } } Q4bits bits; const int8x16_t values; uint32_t aux32[2]; constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; float d; }; struct SimpleBits { uint8x16x4_t b1; uint8x16x4_t b2; }; IQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { int32x4x2_t scales; auto one = vdupq_n_u32(1); scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1)); scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1)); return scales; } inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) { auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127)))); auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127)))); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1)); b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2)); } IQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) { return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1)); } struct DequantizerIQ2XXS final : public BaseDequantizer { DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); } inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j)); prepare_all(data, q); return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1])); } private: static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) { const uint8_t * idx = (const uint8_t *)bits; b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); apply_signs_2(b, signs, bits[1]); } inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) { const uint32_t * q2 = (const uint32_t *)data.val; prepare2(quants+0, q2+0, keven_signs); prepare2(quants+2, q2+2, keven_signs); prepare2(quants+4, q2+4, keven_signs); prepare2(quants+6, q2+6, keven_signs); } }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { auto aux = vld1_u8(sc); auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); auto scales_h = vshr_n_u8(aux, 4); auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; return make_wider(scales16); } struct DequantizerIQ2XS final : public BaseDequantizer { DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0); return prepare_4bit_scales16(x[i].scales); } inline void prepare(int i, int j) { if (j == 1) prepare_internal(i, 1); } private: static void make2(const uint16_t * qs, uint8x16_t * b) { auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511)))); auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511)))); auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9)))); b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1)); b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2)); } inline static void make4(const uint16_t * qs, uint8x16_t * b) { make2(qs + 0, b + 0); make2(qs + 4, b + 2); } IQK_ALWAYS_INLINE void prepare_internal(int i, int j) { make4(x[i].qs + 16*j + 0, bits.b1.val); make4(x[i].qs + 16*j + 8, bits.b2.val); } }; // So, I hate to include this table, but with the GCC 12.3 compiler // bundled in the Cosmopolitan tools, loading the unpacked sign bytes // from this table using the packed 8 sign bits as index is faster than // using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to // expand the bits to bytes. static const uint64_t kall_signs[256] = { 0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff, 0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff, 0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff, 0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff, 0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff, 0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff, 0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff, 0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff, 0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff, 0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff, 0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff, 0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff, 0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff, 0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff, 0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff, 0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff, 0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff, 0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff, 0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff, 0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff, 0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff, 0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff, 0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff, 0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff, 0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff, 0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff, 0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff, 0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff, 0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff, 0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff, 0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff, 0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff, 0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff, 0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff, 0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff, 0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff, 0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff, 0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff, 0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff, 0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff, 0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff, 0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff, 0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff, 0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff, 0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff, 0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff, 0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff, 0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff, 0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff, 0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff, 0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff, 0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff, 0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff, 0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff, 0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff, 0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff, 0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff, 0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff, 0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff, 0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff, 0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff, 0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff, 0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff, 0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff, }; struct SignHelper { IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]}); // Normally we would expect this to be faster, but it isn't. // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); } // We would need these two if we weren't loading from the unpacked sign table. //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); //const uint8x16_t m1 = vdupq_n_u8(1); }; struct DequantizerIQ2S final : public BaseDequantizer { DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0, bits); return prepare_4bit_scales16(x[i].scales); } inline void prepare(int i, int j) { if (j == 1) prepare_internal(i, 1, bits); } private: static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { aux32[1] = (qh[k] << 4) | (qh[k] << 18); aux32[0] = (aux32[1] << 4) & 0x03000300; aux32[1] &= 0x03000300; b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2; sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2; } } void prepare_internal(int i, int j, SimpleBits& sb) { const auto * qs = x[i].qs + 16*j; const auto * qh = x[i].qh + 4*j; const auto * sign_bits = qs + QK_K/8; make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val); make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val); } SignHelper sh; }; struct DequantizerIQ3XXS final : public BaseDequantizer { DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); } inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { auto q3data = vld1q_u8_x2(x[i].qs + 32*j); auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q); return prepare_scales_8(gas); } private: inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) { b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) { make2(q3+ 0, signs[0], quants + 0); make2(q3+ 8, signs[1], quants + 2); make2(q3+16, signs[2], quants + 4); make2(q3+24, signs[3], quants + 6); } }; struct DequantizerIQ3S final : public BaseDequantizer { DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x2_t new_block(int i) { d = GGML_FP16_TO_FP32(x[i].d); uint32_t scales32[2]; auto qs = vld1q_u8_x2(x[i].qs); auto signs = vld1q_u8(x[i].signs); prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs); std::memcpy(scales32, x[i].scales, 4); scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); int32x4x2_t scales; scales.val[0] = vmovl_s16(vget_low_s16(scales16)); scales.val[1] = vmovl_s16(vget_high_s16(scales16)); return scales; } inline void prepare(int i, int j) { if (j == 1) { auto qs = vld1q_u8_x2(x[i].qs + 32); auto signs = vld1q_u8(x[i].signs + 16); prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs); } } private: static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); const uint16_t * idx = (const uint16_t *)&vindex; b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); sh.apply_signs_1x(b+0, sign_bits+0); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); sh.apply_signs_1x(b+1, sign_bits+2); } static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, const int16x8_t& hshift, uint8x16_t * b) { auto idx_l = vld1q_u8(qs); make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); } static int16x8_t load_shift() { static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; return vld1q_s16(k_shift); } inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) { auto signs = vld1q_u8(sign_bits); auto s = (const uint8_t *)&signs; make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val); make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val); } SignHelper sh; const int16x8_t hshift = load_shift(); }; template IQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); uint8x16_t qx[8]; int32x4_t sumi[nrc_y]; float32x4_t acc[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb; ++i) { float d = deq.new_block(i); auto scales = deq.unpack(i, 0, qx); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { sumi[iy] = vdupq_n_s32(0); compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]); } scales = deq.unpack(i, 1, qx); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy])); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } // =========================================== Legacy quants template inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) { for (int k = 0; k < 4; ++k) aux[k] = x[k].d; return vld1_f16((const float16_t *)aux); } template inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) { if constexpr (std::is_same_v) { for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; } } else { for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; } } return vld1q_f16((const float16_t *)aux); } struct Q4LegacyBits { template inline void prepare(const Block * x) { for (int i = 0; i < 4; ++i) { auto q4bits = vld1q_u8(x[i].qs); b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); } } inline void prepare1(const uint8_t * qs, int8x16_t * q) const { auto q4bits = vld1q_u8(qs); q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); } inline void prepare1(const uint8_t * qs) { prepare1(qs, b); } const uint8x16_t m4b = vdupq_n_u8(0xf); int8x16_t b[8]; }; // One would think this commented out version would do better than the one below // because it offers more opportunities to execute instructions in parallel. // Instead, it runs significantly slower. Why? If the compiler is running out of vector registers // cannot it just do the sequential version below on its own? //inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { // const auto q8b_1 = vld1q_s8_x2(qs + 0); // auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); // const auto q8b_2 = vld1q_s8_x2(qs + 32); // auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); // auto p1234 = vpaddq_s32(p12, p34); // const auto q8b_3 = vld1q_s8_x2(qs + 64); // auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); // const auto q8b_4 = vld1q_s8_x2(qs + 96); // auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); // return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); //} inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { auto q8b = vld1q_s8_x2(qs + 0); auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); q8b = vld1q_s8_x2(qs + 32); auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); auto p1234 = vpaddq_s32(p12, p34); q8b = vld1q_s8_x2(qs + 64); auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); q8b = vld1q_s8_x2(qs + 96); auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); } typedef struct { ggml_half d[4]; int8_t qs[4*QK8_0]; } block_q8_0_x4; static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding"); template struct Q80 { constexpr static int nrc_y = nrc; Q80(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); } inline const int8_t * quant_data(int iy, int i) const { const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; return y4->qs; } inline float16x4_t load_scales(int iy, int i) const { const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; return vld1_f16((const float16_t *)y4->d); } template inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const { auto qx_scales = deq.new_block(i); for (int iy = 0; iy < nrc; ++iy) { auto q8_scales = load_scales(iy, i); sc16[iy] = vmul_f16(qx_scales, q8_scales); } } template inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { deq.prepare1(i); float d = GGML_FP16_TO_FP32(deq.x[i].d); for (int iy = 0; iy < nrc; ++iy) { auto q8b = vld1q_s8_x2(y[iy][i].qs); auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); } } const block_q8_0 * y[nrc_y]; }; typedef struct { ggml_half d[8]; int8_t qs[4*QK8_1]; } block_q8_1_x4; static_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), "wrong q8_1_x4 block size/padding"); template struct Q81 { constexpr static int nrc_y = nrc; Q81(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); } inline const int8_t * quant_data(int iy, int i) const { const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; return y4->qs; } inline float16x8_t load_scales(int iy, int i) const { const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; return vld1q_f16((const float16_t *)y4->d); } template inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { auto qx_scales = deq.new_block(i); for (int iy = 0; iy < nrc; ++iy) { auto q8_scales = load_scales(iy, i); auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); } } template inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { deq.prepare1(i); float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); for (int iy = 0; iy < nrc; ++iy) { auto q8b = vld1q_s8_x2(y[iy][i].qs); auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); } } const block_q8_1 * y[nrc_y]; }; template struct BaseLegacyDequantizer { BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } Q4LegacyBits bits; const void * vx; const block_q * x; size_t bx; }; struct DequantizerQ40 final : public BaseLegacyDequantizer { DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); q[0] = vaddq_s8(q[0], m8); q[1] = vaddq_s8(q[1], m8); } inline void prepare1(int i) { prepare1(i, bits.b); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; prepare1(4*i+k, bits.b + 2*k); } return vld1_f16((const float16_t *)aux); } const int8x16_t m8 = vdupq_n_s8(-8); //ggml_half aux[4]; }; struct DequantizerQ41 : public BaseLegacyDequantizer { DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i) { bits.prepare1(x[i].qs); } inline float16x8_t new_block(int i) { uint32_t aux32[4]; const uint32_t * s32 = (const uint32_t *)&x[4*i].d; for (int k = 0; k < 4; ++k) { aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; bits.prepare1(x[4*i+k].qs, bits.b + 2*k); } return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); } // Leaving this commented out attempt to be reminded that I already tried this. // It has basically the same performance as the version above. //inline float16x8_t new_block(int i) { // uint32x4_t scales = {}; // const block_q4_1 * xi = x + 4*i; // const uint32_t * s32 = (const uint32_t *)&xi->d; // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[0].qs, bits.b + 0); // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[1].qs, bits.b + 2); // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[2].qs, bits.b + 4); // scales = vsetq_lane_u32(*s32, scales, 3); // bits.prepare1(xi[3].qs, bits.b + 6); // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); //} const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; }; struct HighBit5Legacy { inline uint8x16_t to_bytes(const uint8_t * qh) const { uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); } inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); } const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }; struct DequantizerQ50 final : public BaseLegacyDequantizer { DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); auto qh = x[i].qh; q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); } inline void prepare1(int i) { prepare1(i, bits.b); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; prepare1(4*i+k, bits.b + 2*k); } return vld1_f16((const float16_t *)aux); } HighBit5Legacy hbits; const uint8x16_t mh = vdupq_n_u8(0xf0); }; struct DequantizerQ80 final : public BaseLegacyDequantizer { DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i) { bits.b[0] = vld1q_s8(x[i].qs); bits.b[1] = vld1q_s8(x[i].qs+16); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); } return vld1_f16((const float16_t *)aux); } }; struct DequantizerQ51 final : public BaseLegacyDequantizer { DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); auto qh = x[i].qh; q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0)))); q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2)))); } inline void prepare1(int i) { bits.prepare1(x[i].qs, bits.b); } inline float16x8_t new_block(int i) { uint32_t aux32[4]; const uint32_t * s32 = (const uint32_t *)&x[4*i].d; for (int k = 0; k < 4; ++k) { aux32[k] = *s32; s32 += sizeof(block_q5_1)/4; prepare1(4*i+k, bits.b + 2*k); } return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); } HighBit5Legacy hbits; const uint8x16_t mh = vdupq_n_u8(0x10); const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; }; template inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i)); auto scale = vcvt_f32_f16(sc16[iy]); acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall)); } } template inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) { const int nb = n / QK4_1; float16x4_t sc16[Q8::nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[Q8::nrc_y]; for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb/4; ++i) { q8.process_scales(i, deq, sc16, acc); sum_4(i, deq, q8, sc16, acc); } for (int i = 4*(nb/4); i < nb; ++i) { q8.process_1_block(i, deq, acc); } for (int iy = 0; iy < Q8::nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } template inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { const int nb = n / QK4_1; float16x4_t sc16[2]; for (int ix = 0; ix < nrc_x; ++ix) { deq1.new_row(ix); deq2.new_row(ix); float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) }; for (int i = 0; i < nb/8; ++i) { q8.process_scales(2*i+0, deq1, sc16+0, acc+0); q8.process_scales(2*i+1, deq2, sc16+1, acc+1); sum_4(2*i+0, deq1, q8, sc16+0, acc+0); sum_4(2*i+1, deq2, q8, sc16+1, acc+1); } for (int i = 2*(nb/8); i < nb/4; ++i) { q8.process_scales(i, deq1, sc16, acc); sum_4(i, deq1, q8, sc16, acc); } for (int i = 4*(nb/4); i < nb; ++i) { q8.process_1_block(i, deq1, acc); } info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); } } template static void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q81 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } template static void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q80 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } template static void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q81<1> q8(info); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } template static void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q80<1> q8(info); mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); } template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0; m.funcs[1] = mul_mat_qX_0_q8_0; m.funcs[2] = mul_mat_qX_0_q8_0; m.funcs[3] = mul_mat_qX_0_q8_0; m.funcs[4] = mul_mat_qX_0_q8_0; m.funcs[5] = mul_mat_qX_0_q8_0; m.funcs[6] = mul_mat_qX_0_q8_0; m.funcs[7] = mul_mat_qX_0_q8_0; } else if constexpr (std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_1_q8_1; m.funcs[1] = mul_mat_qX_1_q8_1; m.funcs[2] = mul_mat_qX_1_q8_1; m.funcs[3] = mul_mat_qX_1_q8_1; m.funcs[4] = mul_mat_qX_1_q8_1; m.funcs[5] = mul_mat_qX_1_q8_1; m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } else if constexpr (std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>; } else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; m.funcs_v2 = mul_mat_qX_K_q8_K_T_v2; } } bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); (void)Ny; // Uncommenting out this would disable iqk_mul_mat for matrix x vector multiplications. //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || // typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions(m); break; case GGML_TYPE_Q3_K: MulMat::set_functions(m); break; case GGML_TYPE_Q4_K: MulMat::set_functions(m); break; case GGML_TYPE_Q5_K: MulMat::set_functions(m); break; case GGML_TYPE_Q6_K: MulMat::set_functions(m); break; case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break; case GGML_TYPE_IQ3_S: MulMat::set_functions(m); break; case GGML_TYPE_IQ3_XXS: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_S: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XS: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XXS: MulMat::set_functions(m); break; case GGML_TYPE_Q4_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q4_1: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; case GGML_TYPE_Q5_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q5_1: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; case GGML_TYPE_Q8_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; default: return false; } return true; } } #endif // __x86_64__ or __aarch64__ ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat_arm82.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm82.cpp // Copyrigth 2024 Iwan Kawrakow. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #ifdef __aarch64__ #define iqk_mul_mat iqk_mul_mat_arm82 #define iqk_mul_mat_moe iqk_mul_mat_moe_arm82 #include "iqk_mul_mat.inc" #endif // __aarch64__ ================================================ FILE: archive/third_party/llamafile/iqk_mul_mat_x86.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc // Copyrigth 2024 Iwan Kawrakow - Apache 2.0 Licens // with additions from // https://github.com/ikawrakow/ik_llama.cpp/blob/main/ggml/src/iqk/iqk_mul_mat.cpp // Copyrigth 2024-2025 Iwan Kawrakow - MIT Licens // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp fenc=utf-8 :vi // // Copyright 2024 Iwan Kawrakow // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // Copyright (C) 2024-2025 Iwan Kawrakow // MIT license // SPDX-License-Identifier: MIT // #include #include #if defined __x86_64__ || defined __aarch64__ || defined(_M_X64) #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" #include "sgemm.h" // For i-quants, I had to explicitely specify which // functions to inline / not inline (at least for some // of the functions), else performance would be significantly // lower. This is worrysome as things can change with, // e.g., a different compiler version or running on a different // CPU. #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) #define IQK_ALWAYS_INLINE inline #else #define IQK_NOINLINE __attribute__((__noinline__)) #define IQK_ALWAYS_INLINE __attribute__((always_inline)) #endif #define GGML_COMMON_IMPL_C #include "llama.cpp/ggml-common.h" // clang-format off // This matrix - vector and matrix - matrix multiplication implementation // for legacy quants, k-quants and i-quants makes prompt processing 150-200% // (legacy and k-quants) or 250-400% (i-quants) faster. // compared to mainline llama.cpp (and llamafile). // It provides implementations for ARM_NEON (all quants) and AVX2 // (all quants except sub-4 bit i-quants). // // Main idea is that unpacking the quants and the block scales to // be ready for dot products with the corresponding Q8_Y quants // takes time (here 'Y' stands for K, 0, or 1, depending on quantization type). // Hence, if we are performing a QX x Q8_Y matrix matrix // multiplication (as needed for prompt processing), we can get // a significant speedup by reusing the unpacked QX quants and scales // for multiplication with several Q8_K columns. We also achieve fewer // loads from memory, which is the main purpose of tiling in general // purpose matrix multiplication packages. #include #include #endif constexpr ggml_type GGML_TYPE_Q8_0_X4 = static_cast(98); constexpr ggml_type GGML_TYPE_Q8_1_X4 = static_cast(99); namespace { typedef struct { int32_t i1; int32_t i2; } mmid_row_mapping; struct DataInfo { float * s; const char * cy; size_t bs; size_t by; int cur_y = 0; int ne11; const mmid_row_mapping * row_mapping = nullptr; size_t bs2 = 0; inline const char * src1_row(int iy) const { if (!row_mapping) return cy + (cur_y + iy)*by; int i11 = row_mapping[cur_y + iy].i1 % ne11; int i12 = row_mapping[cur_y + iy].i2; return cy + (i11 + i12*ne11)*by; } inline void store(int ix, int iy, float result) const { *(dst_row(iy) + ix) = result; //dst_row(iy)[ix] = result; } inline float * dst_row(int iy) const { if (!row_mapping) return s + (cur_y + iy)*bs; int i12 = row_mapping[cur_y + iy].i2; int i1 = row_mapping[cur_y + iy].i1; int i2 = i12; return s + i1*bs + i2*bs2; } }; /* moonll change param for set_mul_mat add func16 */ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); struct MulMat { std::array funcs = {}; mul_mat_t func16 = nullptr; //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile size are small) // copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L162 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow if (func16 && nrc_y >= 16) { int n_step = (nrc_y - info.cur_y)/16; for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); this_info.cur_y += 16; } } info.cur_y += 16 * n_step; if (info.cur_y == nrc_y) return; } // end copy int n_step = (nrc_y - info.cur_y)/funcs.size(); if (n_step > 0) { for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); this_info.cur_y += funcs.size(); } } info.cur_y += funcs.size() * n_step; } int n_left = nrc_y - info.cur_y; if (n_left > 0) { funcs[n_left-1](n, vx, bx, info, nrc_x); } } static IQK_NOINLINE bool set_mul_mat(int typeA, int typeB,int ne00, MulMat& mm, int Ny); private: template static IQK_NOINLINE void set_functions(MulMat& m); }; inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { const uint16_t * scales = (const uint16_t *)scales8; const uint32_t a0 = scales[0] | (scales[1] << 16); const uint32_t a1 = scales[2] | (scales[3] << 16); const uint32_t a2 = scales[4] | (scales[5] << 16); aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030); aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030); aux32[2] = a1 & 0x3f3f3f3f; aux32[0] = a0 & 0x3f3f3f3f; } /* moonll decoding tables */ // copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L570 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow #ifdef __AVX2__ static const uint64_t iq1s_grid_us[2048] = { 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200, 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000, 0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101, 0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101, 0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202, 0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200, 0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001, 0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202, 0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201, 0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001, 0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101, 0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101, 0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202, 0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200, 0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201, 0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002, 0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101, 0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200, 0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102, 0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101, 0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001, 0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100, 0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200, 0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101, 0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100, 0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000, 0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202, 0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200, 0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101, 0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201, 0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002, 0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001, 0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001, 0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002, 0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000, 0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101, 0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000, 0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101, 0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202, 0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201, 0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000, 0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100, 0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102, 0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002, 0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000, 0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101, 0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101, 0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200, 0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002, 0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101, 0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101, 0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101, 0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102, 0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100, 0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002, 0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100, 0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000, 0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101, 0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101, 0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001, 0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102, 0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201, 0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202, 0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001, 0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001, 0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101, 0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102, 0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200, 0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101, 0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101, 0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000, 0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201, 0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101, 0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202, 0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102, 0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101, 0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100, 0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002, 0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201, 0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101, 0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002, 0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202, 0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101, 0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000, 0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100, 0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102, 0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102, 0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101, 0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101, 0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001, 0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201, 0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002, 0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001, 0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100, 0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101, 0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001, 0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101, 0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000, 0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001, 0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101, 0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101, 0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000, 0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001, 0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001, 0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102, 0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102, 0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101, 0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201, 0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202, 0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202, 0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101, 0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001, 0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000, 0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101, 0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200, 0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100, 0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100, 0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202, 0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102, 0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201, 0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202, 0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002, 0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001, 0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001, 0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101, 0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202, 0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201, 0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102, 0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200, 0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001, 0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101, 0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201, 0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001, 0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002, 0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000, 0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202, 0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201, 0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201, 0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101, 0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100, 0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000, 0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101, 0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202, 0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101, 0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202, 0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202, 0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201, 0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002, 0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102, 0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102, 0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000, 0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000, 0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101, 0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101, 0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202, 0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200, 0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102, 0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101, 0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100, 0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001, 0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100, 0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101, 0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001, 0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200, 0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101, 0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101, 0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100, 0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101, 0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101, 0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101, 0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202, 0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100, 0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201, 0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202, 0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102, 0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200, 0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201, 0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000, 0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002, 0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100, 0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000, 0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100, 0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000, 0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102, 0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100, 0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002, 0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001, 0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201, 0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202, 0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100, 0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001, 0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002, 0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001, 0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201, 0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001, 0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101, 0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101, 0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101, 0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101, 0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102, 0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100, 0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001, 0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000, 0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001, 0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101, 0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100, 0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000, 0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202, 0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101, 0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100, 0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100, 0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200, 0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100, 0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101, 0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101, 0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201, 0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001, 0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201, 0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201, 0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001, 0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200, 0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100, 0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201, 0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200, 0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101, 0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001, 0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102, 0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001, 0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201, 0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100, 0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000, 0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102, 0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001, 0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202, 0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102, 0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101, 0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201, 0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101, 0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102, 0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101, 0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100, 0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202, 0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101, 0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202, 0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101, 0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200, 0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101, 0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100, 0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002, 0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201, 0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100, 0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202, 0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102, 0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002, 0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200, 0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002, 0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200, 0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001, 0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200, 0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100, 0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000, 0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102, 0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100, 0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000, 0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102, 0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100, 0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000, 0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101, 0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001, 0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201, 0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002, 0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200, 0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100, 0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101, 0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202, 0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002, 0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201, 0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201, 0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001, 0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202, 0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102, 0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002, 0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201, 0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200, 0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002, 0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100, 0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101, 0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102, 0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002, 0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200, 0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100, 0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001, 0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100, 0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201, 0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101, 0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102, 0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201, 0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200, 0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200, 0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002, 0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202, 0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102, 0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000, 0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202, 0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201, 0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001, 0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002, 0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102, 0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001, 0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101, 0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202, 0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102, 0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201, 0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101, 0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101, 0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001, 0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202, 0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000, 0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202, 0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102, 0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002, 0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201, 0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101, 0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001, 0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200, 0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102, 0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102, 0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100, 0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001, 0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201, 0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001, 0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202, 0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200, 0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000, 0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000, 0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001, 0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200, 0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200, 0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202, 0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201, 0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202, 0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001, 0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001, 0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200, 0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000, 0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102, 0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101, 0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100, 0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000, 0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100, 0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100, 0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102, 0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201, 0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202, 0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102, 0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102, 0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202, 0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202, 0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100, 0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000, 0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101, 0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202, 0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102, 0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100, 0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101, 0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100, 0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201, 0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101, 0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202, 0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200, 0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201, 0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200, 0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002, 0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201, 0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101, 0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201, 0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201, 0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102, 0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101, 0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101, 0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101, 0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001, 0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000, 0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102, 0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101, 0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202, 0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202, 0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101, 0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000, 0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101, 0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202, 0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100, 0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000, 0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101, 0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202, 0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100, 0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100, 0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002, 0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100, 0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101, 0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202, 0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200, 0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100, 0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200, 0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002, 0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001, 0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101, 0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101, 0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202, 0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102, 0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100, 0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101, 0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100, 0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101, 0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101, 0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101, 0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101, 0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102, 0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100, 0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102, 0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101, 0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101, 0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001, 0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101, 0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202, 0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102, 0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001, 0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102, 0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200, 0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101, 0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001, 0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201, 0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202, 0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102, 0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002, 0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200, 0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100, 0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001, 0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002, 0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201, 0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101, 0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100, 0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000, 0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200, 0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101, 0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200, 0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202, 0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100, 0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102, 0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102, 0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102, 0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101, 0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101, 0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000, 0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202, 0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102, 0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200, 0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101, 0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101, 0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100, 0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202, 0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101, 0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201, 0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001, 0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101, 0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200, 0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002, 0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001, 0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000, 0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101, 0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202, 0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100, 0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102, 0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200, 0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101, 0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201, 0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000, 0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202, 0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201, 0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200, 0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002, 0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101, 0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100, 0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001, 0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201, 0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000, 0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102, 0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001, 0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201, 0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100, 0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002, 0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001, 0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101, 0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002, 0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000, 0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101, 0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100, 0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200, 0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200, 0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102, 0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200, 0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002, 0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100, 0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001, 0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001, 0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102, 0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202, 0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202, 0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000, 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, }; #else static const uint32_t iq1s_grid_us[2048] = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; #endif // end copy https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L570 #ifndef HAVE_FANCY_SIMD const uint64_t keven_signs[128] = { 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, }; #endif } /* moonll change mulmat add typeB and strideB }*/ // Adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L406 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth) { MulMat mm; if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) { return false; } size_t row_size_qx = strideA*ggml_type_size(ggml_type(typeA)); size_t row_size_qy = strideB*ggml_type_size(ggml_type(typeB)); auto nrc_x = (Nx + nth - 1)/nth; auto first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); return true; } // end adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L406 bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; assert(row_mapping != nullptr); MulMat mm; int row_size_q8; /* moonll if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { return false; }*/ int row_size_qx = ggml_row_size((ggml_type)typeA, ne00); int nrc_x = (Nx + nth - 1)/nth; int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); return true; } #if defined __x86_64__ || defined(_M_X64) #if defined HAVE_FANCY_SIMD #undef HAVE_FANCY_SIMD #endif #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) #define HAVE_FANCY_SIMD #endif //#define HAVE_FANCY_SIMD namespace { inline float hsum_float_4(__m128 x) { x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); return _mm_cvtss_f32(x); } inline float hsum_float_8(__m256 x) { return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); } #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) template struct Q8 { constexpr static int nrc_y = nrc; Q8(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } #ifdef HAVE_FANCY_SIMD inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); } #endif inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); } inline float scale(int iy, int i) const { return y[iy][i].d; } const block_q8 * y[nrc_y]; }; // Handles q4_K and q5_K scales/mins struct Scales8K { template inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { make_q4_scales(data, utmp); const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1); accum_mins(mins128, q8, i, c, accd); const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); return MM256_SET_M128I(sc128, sc128); } #ifdef HAVE_FANCY_SIMD template inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) { auto scales = process_mins_and_scales(data, c, i, q8, accd); return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1); } #endif template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0])); for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i q8s = q8.load_bsums(iy, i); const __m256i prod = _mm256_madd_epi16(mins, q8s); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } } #ifdef HAVE_FANCY_SIMD const __m512i shuffles512[2] = { _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100), _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) }; #endif const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100), _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)}; uint32_t utmp[4]; }; template inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i)); accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]); } } inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) { const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); scales[0] = MM256_SET_M128I(l_scales, l_scales); scales[1] = MM256_SET_M128I(h_scales, h_scales); } struct ScaleQ3 { inline __m128i make_scales(const uint16_t * s8) const { const uint16_t * scales16 = (const uint16_t *)s8; uint32_t aux0 = scales16[0] | (scales16[1] << 16); uint32_t aux1 = scales16[2] | (scales16[3] << 16); uint32_t aux2 = scales16[4] | (scales16[5] << 16); __m128i scales128 = _mm_set_epi32( ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030), ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030), (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030), (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030)); return _mm_add_epi8(scales128, m32); } const __m128i m32 = _mm_set1_epi8(-32); }; struct ScaleIQ4XS { inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) { uint32_t tmp32 = scales_h | (scales_h << 14); const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4); const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask); return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32); } const __m128i hshift = _mm_set_epi32(12, 8, 4, 0); const __m128i lshift = _mm_set_epi32(4, 0, 4, 0); const __m128i hmask = _mm_set1_epi16(0x03); const __m128i lmask = _mm_set1_epi8(0xf); const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400); const __m128i m32 = _mm_set1_epi16(-32); }; // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1455 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct Scales8KBase { template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0])); for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i q8s = q8.load_bsums(iy, i); const __m256i prod = _mm256_madd_epi16(mins, q8s); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } } inline __m256i shuffle(__m128i mins) const { return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0])); } const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100), _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)}; }; // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1455 template struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} inline void new_row(int ix) { x = (const Block *)((const char *)vx + bx*ix); } const void * vx; size_t bx; const Block * x; float d; }; // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1698 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow __m128i inline load_iq4nl_values_128() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); } __m256i inline load_iq4nl_values_256() { auto val128 = load_iq4nl_values_128(); return MM256_SET_M128I(val128, val128); } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1698 #ifdef HAVE_FANCY_SIMD //====================================== Zen4 ================================================== struct BlockPermuter { const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); }; struct Q4Bits { inline void prepare(const uint8_t * q4) { auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); auto tmp1 = _mm512_and_si512(q4bits, ml); auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); tmp1 = _mm512_and_si512(q4bits, ml); tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2); values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2); } inline void prepare64(const uint8_t * q4) { auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0); values[0] = _mm512_and_si512(q4bits, ml); values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1); values[2] = _mm512_and_si512(q4bits, ml); values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml); } __m512i values[4]; const __m512i ml = _mm512_set1_epi8(0xf); BlockPermuter perm; }; struct Q2Bits { inline void prepare(const uint8_t * q2) { auto q2bits = _mm512_loadu_si512((const __m512i*)q2); auto tmp = _mm512_srli_epi16(q2bits, 2); values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp); values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp); values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml); values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml); values[0] = _mm512_and_si512(values[0], ml); values[2] = _mm512_and_si512(values[2], ml); } __m512i values[4]; const __m512i ml = _mm512_set1_epi8(0x03); BlockPermuter perm; }; struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); } Q4Bits bits; Scales8K s8k; }; /* moonll DequantizerIQ4XS */ // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1775 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow __m512i inline load_iq4nl_values_512() { auto val256 = load_iq4nl_values_256(); return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1775 // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1781 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct DequantizerIQ4XS final : public BaseDequantizer { // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1782 DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); prepare(x[i].qs); auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); s8k.accum_mins(scales128, q8, i, -128.f*d, accd); auto scales256 = MM256_SET_M128I(scales128, scales128); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); } inline void prepare(const uint8_t * q4) { bits.prepare64(q4); // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111 // bits.valuse[1]: 16..31, 48...63, 80...95, 112..127 // etc. auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]); bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1])); bits.values[0] = _mm512_shuffle_epi8(values, tmp); tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]); bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3])); bits.values[2] = _mm512_shuffle_epi8(values, tmp); } Q4Bits bits; Scales8KBase s8k; ScaleIQ4XS siq4; const __m512i values; const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); const __m512i shuffles[4] = { _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), }; }; // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1781 struct HighBit5 { inline void apply(const uint8_t * h, Q4Bits& bits) { auto hbits256 = _mm256_loadu_si256((const __m256i *)h); auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); } const __m512i mh = _mm512_set1_epi8(0x10); }; struct HighBit3 { inline void apply(const uint8_t * h, Q2Bits& bits) { auto hbits256 = _mm256_loadu_si256((const __m256i *)h); auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh)); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh)); } const __m512i mh = _mm512_set1_epi8(0x04); }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); hbits.apply(x[i].qh, bits); auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); } Q4Bits bits; HighBit5 hbits; Scales8K s8k; }; struct Scale16 { inline void make_scales(const __m128i& scales8, __m512i * scales) const { auto all_scales8 = MM256_SET_M128I(scales8, scales8); auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1); auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2); scales[0] = _mm512_cvtepi8_epi16(scales1); scales[1] = _mm512_cvtepi8_epi16(scales2); } template inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8, const Q8& q8, __m256 * accm, __m512i * scales) const { process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm); make_scales(scales8, scales); } const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202, 0x05050505, 0x01010101, 0x04040404, 0x00000000); const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a, 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808); }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales); } Q2Bits bits; Scale16 sc16; const __m128i m4 = _mm_set1_epi8(0xf); }; struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare(x[i].qs); hbits.apply(x[i].hmask, bits); auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales); sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales); } Q2Bits bits; HighBit3 hbits; ScaleQ3 sc3; Scale16 sc16; const __m128i m4 = _mm_set1_epi8(0xf); const __m128i m32 = _mm_set1_epi8(-32); }; struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { d = GGML_FP16_TO_FP32(x[i].d); bits.prepare64(x[i].ql); add_high_bits(x[i].qh, bits); auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales); sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales); } inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const { auto hbits = _mm512_loadu_si512((const __m512i *)qh); auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh); auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh); bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); tmp1 = _mm512_and_si512(hbits, mh); tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh); bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2)); bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2)); } Q4Bits bits; HighBit3 hbits; Scale16 sc16; const __m512i mh = _mm512_set1_epi8(0x30); }; template static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[2]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L2408 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow template inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } template static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[2]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2)); sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4)); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } template static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accm[nrc_y]; __m512 accd[nrc_y]; __m512i scales[4]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps(); for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accm, scales); for (int iy = 0; iy < nrc_y; ++iy) { const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0)); const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1)); const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2)); const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3)); auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(), p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]); accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256))); } } } template static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; constexpr int k_nx = 2; Q8<1> q8(info); Dequantizer deq1(vx, bx); Dequantizer deq2(vx, bx); Dequantizer * deq[k_nx]; deq[0] = &deq1; deq[1] = &deq2; __m512i scales[2*k_nx]; for (int ix = 0; ix < nrc_x; ++ix) { auto accd = _mm512_setzero_ps(); auto accm = _mm256_setzero_ps(); for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix); for (int i = 0; i < nb/k_nx; ++i) { for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx); for (int kx = 0; kx < k_nx; ++kx) { compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd); } } if (2*(nb/2) < nb) { int i0 = 2*(nb/2); deq[0]->new_block(i0, q8, &accm, scales); compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd); } auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1)); info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256))); } } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L2408 #else // ===================================== Vanilla AVX2 ===================================== struct Q4Bits { inline void prepare(const uint8_t * q4, int j) { auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); values[0] = _mm256_and_si256(q4bits, ml); values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); values[2] = _mm256_and_si256(q4bits, ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); } inline void prepare64(const uint8_t * q4, int j) { auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0); values[0] = _mm256_and_si256(q4bits, ml); values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1); values[1] = _mm256_and_si256(q4bits, ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml); } inline void prepare16(const uint8_t * q4, int j) { values[0] = dequant16(q4 + 64*j + 0); values[1] = dequant16(q4 + 64*j + 16); values[2] = dequant16(q4 + 64*j + 32); values[3] = dequant16(q4 + 64*j + 48); } inline __m256i dequant16(const uint8_t * qs) const { const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128); return _mm256_and_si256(ml, aux256); }; __m256i values[4]; const __m256i ml = _mm256_set1_epi8(0xf); }; struct Q2Bits { inline void prepare(const uint8_t * q2, int j) { auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j); values[0] = _mm256_and_si256(q2bits, ml); values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); } __m256i values[4]; const __m256i ml = _mm256_set1_epi8(0x03); }; struct HighBit5 { inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } inline void apply(Q4Bits& bits, bool do_shift) { bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); if (do_shift) { hbits = _mm256_srli_epi16(hbits, 4); } } const __m256i mh = _mm256_set1_epi8(0x10); __m256i hbits; }; struct HighBit3 { inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); } inline void apply(Q2Bits& bits, bool do_shift) { bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); if (do_shift) { hbits = _mm256_srli_epi16(hbits, 4); } } const __m256i mh = _mm256_set1_epi8(0x04); __m256i hbits; }; /* template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); } } else { for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); } } }*/ struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); } Q4Bits bits; Scales8K s8k; }; struct DequantizerIQ4XS final : public BaseDequantizer { DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h); s8k.accum_mins(scales128, q8, i, -128.f*d, accd); return MM256_SET_M128I(scales128, scales128); } inline void prepare(int i, int j) { bits.prepare16(x[i].qs, j); bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]); bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]); bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]); bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]); } static __m256i load_values() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); return MM256_SET_M128I(val128, val128); } Q4Bits bits; Scales8K s8k; ScaleIQ4XS siq4; const __m256i values; }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { d = GGML_FP16_TO_FP32(x[i].d); hbits.load(x[i].qh); return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); hbits.apply(bits, j == 0); } Q4Bits bits; HighBit5 hbits; Scales8K s8k; }; template inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d, __m256 * accm, __m256i * scales) { const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); process_mins_16(all_scales, q8, i, d, accm); prepare_scales_16(all_scales, scales); } struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); hbits.load(x[i].hmask); process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); hbits.apply(bits, j == 0); } Q2Bits bits; HighBit3 hbits; ScaleQ3 sc3; const __m128i m32 = _mm_set1_epi8(-32); }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm); prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales); } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); } Q2Bits bits; const __m128i m4 = _mm_set1_epi8(0xf); }; struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { d = GGML_FP16_TO_FP32(x[i].d); process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales); } inline void prepare(int i, int j) { bits.prepare64(x[i].ql, j); auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j); bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)); bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh)); } Q4Bits bits; const __m256i mh = _mm256_set1_epi8(0x30); }; inline __m256i get_scale_shuffle_8(int i); inline void set_scales_8(const __m256i& all_scales, int j, __m256i* scales); inline __m256i get_scale_shuffle_16(int i); inline void set_scales_16(const __m256i& all_scales, __m256i* scales); template static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); const int nb = n/QK_K; Q8 q8(info); __m256i all_scales[2]; __m256i scales[4]; __m256 accd[nrc_y]; Dequantizer deq(vx, bx); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { deq.new_block(i, q8, accd, all_scales); __m256i sumi[nrc_y]; for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); set_scales_16(all_scales[j], scales); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } template static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256 accd[nrc_y]; __m256i scales[4]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { auto all_scales = deq.new_block(i, q8, accd); __m256i sumi[nrc_y]; for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); set_scales_8(all_scales, j, scales); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } #endif // Zen4 or vanilla AVX2 // // ============================== Legacy quants // struct DotHelper { const __m256i m1 = _mm256_set1_epi16(1); #if defined(__AVX512VNNI__) && defined(__AVX512VL__) inline __m256i dot(__m256i x, __m256i y) const { return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y); } #else inline __m256i dot(__m256i x, __m256i y) const { return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y)); } #endif }; struct SignedDot { DotHelper helper; inline __m256i compute(__m256i x, __m256i y) const { return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)); } }; struct UnsignedDot { DotHelper helper; inline __m256i compute(__m256i x, __m256i y) const { return helper.dot(x, y); } }; template struct Sum4 { Dot dot; inline __m256i compute(const __m256i * qx, const Q8 * y) const { const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs)); const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs)); const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs)); const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs)); const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1 const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3 return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3 } }; struct Sum4_Q8 { SignedDot dot; static inline __m256i add1(__m256i a, __m256i b) { return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b)); } static inline __m256i add2(__m256i a, __m256i b) { return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); } inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const { const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs)); const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs)); const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs)); const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs)); const __m256i p01 = add1(p0, p1); // 0,1, 0,1, 0,1, 0,1 const __m256i p23 = add1(p2, p3); // 2,3, 2,3, 2,3, 2,3 return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3 } }; struct ScaleHelperQ_0 { ggml_half scales8[4]; template inline __m128 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); } template inline __m128 prepare4(__m128 other_scales, const Q * y) { return _mm_mul_ps(other_scales, prepare4(y)); } template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } }; // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8187 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow template struct ScaleHelperQ_0_1 { ggml_half scales8[4]; template inline __m256 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); return _mm256_set_m128(_mm_mul_ps(s4, min), s4); } template inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm_mul256_ps(other_scales, prepare4(y)); } template inline std::pair prepare1(const Q * y) const { float d = GGML_FP16_TO_FP32(y->d); return std::make_pair(d, -d*float(min_value)); } std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); } const __m128 min = _mm_set1_ps(float(-min_value)); }; // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8187 struct ScaleHelperQ_1 { uint32_t scales8[4]; const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); template inline __m256 prepare4(const Q * y) { for (int j = 0; j < 4; ++j) { // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers // complain that this breaks strict-aliasing rules. memcpy(scales8 + j, &y[j].d, sizeof(uint32_t)); } return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle)); } template inline __m256 prepare4(__m256 other_scales, const Q * y) { return _mm256_mul_ps(other_scales, prepare4(y)); } template inline std::pair prepare1(const Q * y) const { return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); } template inline std::pair prepare1(const std::pair& dm, const Q * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); } std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); } }; struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } }; template struct MinusType1 { __m128 accm[nrc_y]; MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); } inline __m256 compute(__m256 dm, int iy) { const __m128 d = _mm256_castps256_ps128(dm); const __m128 m = _mm256_extractf128_ps(dm, 1); accm[iy] = _mm_add_ps(accm[iy], m); return _mm256_set_m128(d, d); } inline float compute(const std::pair& dm, int iy) { accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f)); return dm.first; } inline float result(__m256 acc, int iy) const { const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } }; template struct AccumT { __m256 acc[nrc_y]; Minus accm; AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); } template inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) { auto qx = unp.quants(); __m256 dall[nrc_y]; for (int i = 0; i < nb/4; ++i) { auto other_scales = unp.set_block_4(i); for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); dall[iy] = accm.compute(s12, iy); } for (int iy = 0; iy < nrc_y; ++iy) { auto pall = sum.compute(qx, y[iy] + 4*i); acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); } } if (!is_multiple_of_4) { for (int i = 4*(nb/4); i < nb; ++i) { auto other_scales = unp.set_block(i); for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare1(other_scales, y[iy] + i); auto d = accm.compute(s12, iy); const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); } } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, accm.result(acc[iy], iy)); //s[iy*bs] = accm.result(acc[iy], iy); } } }; template using AccumType0 = AccumT; template using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; using Sum4Type0 = Sum4; using Sum4Type1 = Sum4; template void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { Unpacker unp(vx, bx); Sum4Type sum4; Scales scales; for (int ix = 0; ix < nrc_x; ++ix) { unp.set_row(ix); AccumType accum; accum.compute(nb, unp, scales, sum4, y, info, ix); } } template void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } template void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_1, block_q8_1, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_1, block_q8_1, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4); } }; struct Q8_0_Dequantizer { inline __m256i dequant(const block_q8_0 * x) const { return _mm256_loadu_si256((const __m256i *)x->qs); } }; // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8455 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct Q8_0_1_Dequantizer { inline __m256i dequant(const block_q8_0 * x) const { return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs)); } }; // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8455 struct Q4_0_Dequantizer { Dequantizer4bit b4; const __m256i m8 = _mm256_set1_epi8(-8); inline __m256i dequant(const block_q4_0 * x) const { return _mm256_add_epi8(b4.dequant(x->qs), m8); } }; struct Q4_1_Dequantizer { Dequantizer4bit b4; inline __m256i dequant(const block_q4_1 * x) const { return b4.dequant(x->qs); } }; struct HBitDequantizer { const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); const __m256i minus1 = _mm256_set1_epi64x(-1); inline __m256i to_bytes(const uint8_t * bits) const { // Note: Data in all ggml quants is at least 2-byte aligned. // => we can cast to uint16_t and use or on two consecutive entries // which is faster than memcpy const uint16_t * aux16 = (const uint16_t *)bits; const uint32_t aux32 = aux16[0] | (aux16[1] << 16); //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t)); __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle); bytes = _mm256_or_si256(bytes, mask); return _mm256_cmpeq_epi8(bytes, minus1); } }; struct Q5_0_Dequantizer { Dequantizer4bit b4; HBitDequantizer hbit; const __m256i mh = _mm256_set1_epi8((char)0xF0); inline __m256i dequant(const block_q5_0 * x) const { const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh); return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; struct Q5_1_Dequantizer { Dequantizer4bit b4; HBitDequantizer hbit; const __m256i mh = _mm256_set1_epi8(0x10); inline __m256i dequant(const block_q5_1 * x) const { const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh); return _mm256_or_si256(b4.dequant(x->qs), vqh); } }; template struct Q_Unpacker { Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {} const char * cx_0; const Q * x; size_t bx; Scales scales; Dequantizer deq; __m256i qx[4]; inline const __m256i* quants() const { return qx; } inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); } inline auto set_block_4(int i) { for (int j = 0; j < 4; ++j) { qx[j] = deq.dequant(x + 4*i + j); } return scales.prepare4(x + 4*i); } inline auto set_block(int i) { qx[0] = deq.dequant(x + i); return scales.prepare1(x + i); } }; struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_0; } }; // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8574 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct Q8_0_1_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} // using Sum4T = Sum4TypeQ81; inline static int block_size() { return QK8_0; } }; // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8574 struct Q4_0_Unpacker final : public Q_Unpacker { Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_0; } }; struct Q5_0_Unpacker final : public Q_Unpacker { Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK5_0; } }; struct Q4_1_Unpacker final : public Q_Unpacker { Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_1; } }; struct Q5_1_Unpacker final : public Q_Unpacker { Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} inline static int block_size() { return QK4_1; } }; template void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Q8_0_Unpacker::block_size() == 0); Q8 q8(info); int nb = n/Q8_0_Unpacker::block_size(); if (nb%4 == 0) { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } else { mul_mat_qX_q8_Helper, ScaleHelperQ_0, block_q8_0, nrc_y>( nb, vx, bx, info, q8.y, nrc_x ); } } /* moonll add some structs for DequantizerIQ2XXS SimpleBits EvenSignHelper */ struct SimpleBits { __m256i values[4]; }; // fix for #829: Add checks of AVX512VPOPCNTDQ #if defined(HAVE_FANCY_SIMD) && defined(__AVX512VPOPCNTDQ__) #define HAVE_AVX512_POPCNT 1 #else #define HAVE_AVX512_POPCNT 0 #endif // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7736 // with the addition of a branch that handles a missing _mm256_popcnt_epi32 instruction // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct EvenSignHelper { #if defined HAVE_FANCY_SIMD // #pragma message("Using AVX512VPOPCNTDQ in even sign helper") union sbits_t { __m128i vec; __mmask32 mask[4]; }; IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const { aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask); // fix for #829: Compatibility with processors using Intel Cascade Lake architecture // If AVX512VPOPCNTDQ extension is not supported, use alternative implementation #if HAVE_AVX512_POPCNT auto pcnt = _mm256_popcnt_epi32(aux); #else // Alternative implementation: Using standard bit counting method __m256i pcnt; int* pcnt_ptr = reinterpret_cast(&pcnt); int* aux_ptr = reinterpret_cast(&aux); // Get address of aux directly, avoid unnecessary copies #pragma unroll 8 // Hint compiler to unroll loops, increasing throughput of SIMD computing for (int i = 0; i < 8; i++) { pcnt_ptr[i] = __builtin_popcount(aux_ptr[i]); // Use compiler builtin popcount } #endif sbits_t sbits; sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]); values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]); //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7))); //const __mmask32 * m32 = (const __mmask32 *)&sign_bits; //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]); //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]); } const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0); const __m256i mask = _mm256_set1_epi32(127); const __m256i mone = _mm256_set1_epi32(1); #else inline void sign_value(uint32_t aux32, __m256i& value) const { auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127], keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]); value = _mm256_sign_epi8(value, signs); } #endif }; /* moonll ad multiply_add for mul_mat_qX_K_q8_K_IQ_1 add func get_scale_shuffle_8 get_scale_shuffle_16 set_scales_16 */ // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1578 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow inline __m256i get_scale_shuffle_8(int i) { return _mm256_set1_epi16((2*i) | ((2*i+1) << 8)); } inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) { scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0)); scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1)); scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2)); scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3)); } inline __m256i get_scale_shuffle_16(int i) { static const uint8_t k_shuffle[128] = { 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, }; return _mm256_loadu_si256((const __m256i*)k_shuffle + i); } inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0)); scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1)); scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2)); scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3)); } template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { #ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); } #else for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); } #endif } else { #ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); } #else for (int iy = 0; iy < Q8::nrc_y; ++iy) { const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); } #endif } } /* moonll ad multiply_add_1 for mul_mat_qX_K_q8_K_IQ_1 add func set_scales_8_iq set_scales_16_iq add MUL_MAT mul_mat_qX_K_q8_K_IQ_1 mul_mat_qX_K_q8_K_IQ_N mul_mat_qX_K_q8_K_IQ */ template inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { #ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2)); sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4)); #else const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); sumi[0] = _mm256_add_epi32(p1, p3); sumi[1] = _mm256_add_epi32(p2, p4); #endif } else { #ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]); sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2)); sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4)); #else const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0])); const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1])); const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2])); const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3])); sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3)); sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4)); #endif } } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1578 // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7278 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) { //#ifdef HAVE_FANCY_SIMD auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100) : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908); scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4))); //#else // set_scales_8(all_scales, j, scales); //#endif } inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) { #ifdef HAVE_FANCY_SIMD auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100); scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8))); #else set_scales_16(all_scales, scales); #endif } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7278 // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7299 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow template static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_K; Q8<1> q8(info); Dequantizer deq(vx, bx); __m256i scales[2]; __m256i q8_quants[4]; for (int ix = 0; ix < nrc_x; ++ix) { __m256 accd = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { __m256i sumi[2], all_scales[Dequantizer::num_blocks/8]; deq.new_block(i, all_scales); for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j, q8, q8_quants); if constexpr (Dequantizer::num_blocks == 8) { set_scales_8_iq(j, all_scales[0], scales); } else { set_scales_16_iq(all_scales[j], scales); } multiply_add_1(j, deq.bits, scales, q8_quants, sumi); } accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd); } info.store(ix, 0, hsum_float_8(accd)); } } template static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx); __m256i scales[4]; __m256 accd[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); deq.new_row(ix); for (int i = 0; i < nb; ++i) { __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8]; //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256(); __m256i mins; float dmin = deq.new_block(i, all_scales, mins); for (int iy = 0; iy < nrc_y; ++iy) { auto bsums = q8.load_bsums(iy, i); auto prod = _mm256_madd_epi16(mins, bsums); accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]); } for (int j = 0; j < QK_K/128; ++j) { deq.prepare(i, j); if constexpr (Dequantizer::num_blocks == 8) { set_scales_8(all_scales[0], j, scales); } else { set_scales_16(all_scales[j], scales); } //multiply_add_iq(deq.bits, scales, j, i, q8, sumi); multiply_add(deq.bits, scales, j, i, q8, sumi); } for (int iy = 0; iy < nrc_y; ++iy) { const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i)); accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, hsum_float_8(accd[iy])); } } } template static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); #ifdef HAVE_FANCY_SIMD if constexpr (nrc_y == 1) { mul_mat_qX_K_q8_K_IQ_1(n, vx, bx, info, nrc_x); } else { mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); } #else mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); #endif } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7299 /* moonll iq1s core func for iq1s mul_mat_iq1_s_q8_K */ // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L3813 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow template static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(n%QK_K == 0); Q8 q8(info); __m256i qx[8]; __m256i scales[4]; __m256 acc[nrc_y] = {}; auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000 __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100); for (int ix = 0; ix < nrc_x; ++ix) { auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); for (int ibl = 0; ibl < n/QK_K; ++ibl) { float d = GGML_FP16_TO_FP32(iq1s[ibl].d); auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh); auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7)); scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1)); #ifdef HAVE_FANCY_SIMD auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask); auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9)); #else auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask); auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7))); #endif deltas128 = _mm_mullo_epi16(scales128, deltas128); scales128 = _mm_slli_epi16(scales128, 3); auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128); auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128); auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7 auto all_scales = MM256_SET_M128I(scales128, scales128); auto shuffle = shuffle0; for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle); shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)); } const uint8_t * qs = iq1s[ibl].qs; const uint16_t * qh = iq1s[ibl].qh; for (int ib = 0; ib < QK_K/32; ib += 2) { qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)], iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]); qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)], iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]); qs += 8; } for (int iy = 0; iy < nrc_y; ++iy) { auto bsums = q8.load_bsums(iy, ibl); auto sumi = _mm256_setzero_si256(); for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0); auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1); #ifdef HAVE_FANCY_SIMD auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1); auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2); sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2)); #else auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1); auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2); auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2)); sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot)); #endif } #ifdef HAVE_FANCY_SIMD sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas); #else sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas)); #endif acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, 0.125f*hsum_float_8(acc[iy])); acc[iy] = _mm256_setzero_ps(); } } } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L3813 /* moonll iq1s DequantizerIQ2XXS DequantizerIQ2XXS is important Dequantizer for DequantizerIQ1_S */ // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8035 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct DequantizerIQ2XXS final : public BaseDequantizer { DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} constexpr static int num_blocks = 8; union Data { __m256i vec; uint32_t val[8]; }; inline __m128i load_scales(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); const uint16_t * a16 = (const uint16_t *)x[i].qs; auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12); return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1)); } inline void new_block(int i, __m256i * scales) { auto sc16 = load_scales(i); scales[0] = MM256_SET_M128I(sc16, sc16); } inline float new_block(int i, __m256i * scales, __m256i& mins) { auto sc16 = load_scales(i); mins = scb.shuffle(sc16); scales[0] = MM256_SET_M128I(sc16, sc16); return -d*minv; } inline static void make4(const uint32_t * aux32, __m256i * values) { const uint8_t * aux8 = (const uint8_t *)aux32; values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]); values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]); values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]); values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]); } IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { #ifdef HAVE_FANCY_SIMD esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); #else esh.sign_value(aux32[1], values[0]); esh.sign_value(aux32[3], values[1]); esh.sign_value(aux32[5], values[2]); esh.sign_value(aux32[7], values[3]); #endif } inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const { make4(aux32, values); sign_values(aux32, values); for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value); } inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const { make4(aux32, values); sign_values(aux32, q8); } inline void prepare(int i, int j) { Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); make4_signed(data.val, min_value, bits.values); } inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) { for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k); Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j); make4(data.val, bits.values, q8_quants); } constexpr static int minv = 43; SimpleBits bits; Scales8KBase scb; EvenSignHelper esh; const __m256i min_value = _mm256_set1_epi8(minv); const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1); }; /* moonll add Q8_0_Unpacker && DequantizerIQ2XXS support add func mul_mat_qX_K_q8_K_IQ */ // Copied/adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9092 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0_T; m.funcs[1] = mul_mat_qX_0_q8_0_T; m.funcs[2] = mul_mat_qX_0_q8_0_T; m.funcs[3] = mul_mat_qX_0_q8_0_T; m.funcs[4] = mul_mat_qX_0_q8_0_T; m.funcs[5] = mul_mat_qX_0_q8_0_T; m.funcs[6] = mul_mat_qX_0_q8_0_T; m.funcs[7] = mul_mat_qX_0_q8_0_T; } else if constexpr (std::is_same_v || std::is_same_v|| std::is_same_v) { m.funcs[0] = mul_mat_qX_1_q8_1_T; m.funcs[1] = mul_mat_qX_1_q8_1_T; m.funcs[2] = mul_mat_qX_1_q8_1_T; m.funcs[3] = mul_mat_qX_1_q8_1_T; m.funcs[4] = mul_mat_qX_1_q8_1_T; m.funcs[5] = mul_mat_qX_1_q8_1_T; m.funcs[6] = mul_mat_qX_1_q8_1_T; m.funcs[7] = mul_mat_qX_1_q8_1_T; } else if constexpr (std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ; m.funcs[1] = mul_mat_qX_K_q8_K_IQ; m.funcs[2] = mul_mat_qX_K_q8_K_IQ; m.funcs[3] = mul_mat_qX_K_q8_K_IQ; m.funcs[4] = mul_mat_qX_K_q8_K_IQ; m.funcs[5] = mul_mat_qX_K_q8_K_IQ; m.funcs[6] = mul_mat_qX_K_q8_K_IQ; m.funcs[7] = mul_mat_qX_K_q8_K_IQ; } else { #ifdef HAVE_FANCY_SIMD if constexpr (std::is_same_v) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512; } else { m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1; m.funcs[1] = mul_mat_qX_K_q8_K_AVX512; m.funcs[2] = mul_mat_qX_K_q8_K_AVX512; m.funcs[3] = mul_mat_qX_K_q8_K_AVX512; m.funcs[4] = mul_mat_qX_K_q8_K_AVX512; m.funcs[5] = mul_mat_qX_K_q8_K_AVX512; m.funcs[6] = mul_mat_qX_K_q8_K_AVX512; m.funcs[7] = mul_mat_qX_K_q8_K_AVX512; } #else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qY_K_q8_K_T; m.funcs[1] = mul_mat_qY_K_q8_K_T; m.funcs[2] = mul_mat_qY_K_q8_K_T; m.funcs[3] = mul_mat_qY_K_q8_K_T; m.funcs[4] = mul_mat_qY_K_q8_K_T; m.funcs[5] = mul_mat_qY_K_q8_K_T; m.funcs[6] = mul_mat_qY_K_q8_K_T; m.funcs[7] = mul_mat_qY_K_q8_K_T; } else { m.funcs[0] = mul_mat_qX_K_q8_K_T; m.funcs[1] = mul_mat_qX_K_q8_K_T; m.funcs[2] = mul_mat_qX_K_q8_K_T; m.funcs[3] = mul_mat_qX_K_q8_K_T; m.funcs[4] = mul_mat_qX_K_q8_K_T; m.funcs[5] = mul_mat_qX_K_q8_K_T; m.funcs[6] = mul_mat_qX_K_q8_K_T; m.funcs[7] = mul_mat_qX_K_q8_K_T; } #endif } } // end copied/adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9092 // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8622 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow struct QFBase { #ifdef __AVX512F__ constexpr static int k_step = 16; using Data = __m512; using Acc = __m512; static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); } static inline Data load(const float * x) { return _mm512_loadu_ps(x); } static inline Data load(const ggml_bf16_t * x) { return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16)); } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm512_fmadd_ps(y, x, prev); } static inline Acc acc_first(const Data& y, const Data& x) { return _mm512_mul_ps(y, x); } static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); } static inline float hsum(Acc acc) { return _mm512_reduce_add_ps(acc); } template static inline Data load4Floats(const Float * x) { return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc); acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_r4_first(const Data * xv, const Data& yv) { auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00)); acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline __m128 hsum_r4(Acc acc) { auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1)); auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3)); return _mm_add_ps(sum1, sum2); } #else constexpr static int k_step = 8; using Data = __m256; using Acc = __m256; static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); } static inline Data load(const float * x) { return _mm256_loadu_ps(x); } static inline Data load(const ggml_bf16_t * x) { return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16)); } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); } static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_r4_first(const Data * xv, const Data& yv) { auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00)); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); return acc; } static inline Acc acc_first(const Data& y, const Data& x) { return _mm256_mul_ps(y, x); } static inline float hsum(Acc acc) { return hsum_float_8(acc); } static inline __m128 hsum_r4(Acc acc) { return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); } template static inline Data load4Floats(const Float * x) { return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0); } #endif static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); } static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); } static inline __m128 load128(const ggml_bf16_t * x) { return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16)); } }; template struct QFT final : public QFBase { constexpr static int nrc = nrc_in; QFT(const DataInfo& info) { for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy); } QFT(const char * cx, size_t bx) { for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx); } IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); } IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const { xv[0] = load1(ix+0, i); xv[1] = load1(ix+1, i); xv[2] = load1(ix+2, i); xv[3] = load1(ix+3, i); #ifdef __AVX512F__ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]); auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]); xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); #else auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]); auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]); xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); #endif } const Float * y[nrc]; }; template IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { int nb = n/QFBase::k_step; int nb4 = n/4; Qy y(info); Qx x(cx + ix0*bx, bx); QFBase::Data xv[Qx::nrc]; QFBase::Acc acc[Qx::nrc*Qy::nrc]; auto yv = y.load1(0, 0); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, 0); acc[ix] = QFBase::acc_first(yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, 0); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]); } for (int i = 1; i < nb; ++i) { yv = y.load1(0, i); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load1(ix, i); acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load1(iy, i); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); } } for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) { yv = y.load_tail(0, i); for (int ix = 0; ix < Qx::nrc; ++ix) { xv[ix] = x.load_tail(ix, i); acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); } for (int iy = 1; iy < Qy::nrc; ++iy) { yv = y.load_tail(iy, i); for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); } } for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const char * cx = (const char *)vx; // TBD if we want this //if constexpr (nrc_y == 1) { // constexpr int k_nx = 2; // for (int ix = 0; ix < nrc_x/k_nx; ++ix) { // mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, ix*k_nx, info); // } // if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) { // int nx = nrc_x - lastx; // switch (nx) { // case 1: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // case 2: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // case 3: mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); break; // } // //mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, lastx, info); // } // return; //} #ifdef __AVX512F__ constexpr int k_nx = 5; #else constexpr int k_nx = nrc_y == 1 ? 4 : 2; #endif for (int ix = 0; ix < nrc_x/k_nx; ++ix) { mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); } int last_x = k_nx*(nrc_x/k_nx); if (last_x == nrc_x) return; int nx = nrc_x - last_x; #ifdef __AVX512F__ switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } #else if constexpr (nrc_y == 1) { switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } } else { switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; } } #endif } template void set_mul_mat_f(MulMat& mm) { for (auto& f : mm.funcs) f = nullptr; mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>; mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>; mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>; mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>; mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>; #ifndef __AVX512F__ mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>; #endif } // end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8622 /* moonll add typeb TO compare return not expected type of weight matrix add IQ2XSS add IQ1_S add GGML_TYPE_IQ4_XS */ // Modifications extracted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9231 // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow bool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { (void)Ny; auto expected_typeB = GGML_TYPE_Q8_K; switch (typeA) { case GGML_TYPE_Q2_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q3_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q4_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q5_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q6_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_IQ4_XS: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_IQ2_XXS: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); #ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_1_X4; #else MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0_X4; #endif break; case GGML_TYPE_IQ1_S: mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; mm.funcs[1] = mul_mat_iq1_s_q8_K<2>; mm.funcs[2] = mul_mat_iq1_s_q8_K<3>; mm.funcs[3] = mul_mat_iq1_s_q8_K<4>; mm.funcs[4] = mul_mat_iq1_s_q8_K<5>; mm.funcs[5] = mul_mat_iq1_s_q8_K<6>; mm.funcs[6] = mul_mat_iq1_s_q8_K<7>; mm.funcs[7] = mul_mat_iq1_s_q8_K<8>; #ifdef HAVE_FANCY_SIMD mm.func16 = mul_mat_iq1_s_q8_K<16>; #endif // row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); expected_typeB = GGML_TYPE_Q8_K; break; default: { // printf("case:%d",typeA); return false; } } return ggml_type(typeB) == expected_typeB; } // end extracted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9231 } // namespace /* iq1_s is not support for arm */ #else // __aarch64__ namespace { template struct Q8 { constexpr static int nrc_y = nrc; Q8(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); } inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); } inline int16x8_t load_bsums8(int iy, int i) const { auto q8s = vld1q_s16_x2(y[iy][i].bsums); return vpaddq_s16(q8s.val[0], q8s.val[1]); } inline float scale(int iy, int i) const { return y[iy][i].d; } const block_q8 * y[nrc_y]; }; template IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); //#pragma GCC unroll 4 for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { deq.process_scales(i, q8, acc); deq.prepare(i, 0); deq.compute(q8, i, 0, sumi); deq.prepare(i, 1); deq.compute(q8, i, 1, sumi); } else { if constexpr (Dequantizer::num_blocks() == 8) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else if constexpr (Dequantizer::num_blocks() == 16) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else { GGML_ASSERT(false); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } template IQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); if constexpr (Dequantizer::num_blocks() == 8) { auto scales = deq.new_block(i); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else if constexpr (Dequantizer::num_blocks() == 16) { auto scales = deq.new_block(i); deq.prepare(i, 0); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else { GGML_ASSERT(false); } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } template IQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val; const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val; auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2 auto p12 = vpaddq_s32(p1, p2); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3 auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4 auto p34 = vpaddq_s32(p3, p4); auto pall = vpaddq_s32(p12, p34); sumi = vmlaq_s32(sumi, scales.val[j], pall); } template IQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8, const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2 auto p12 = vpaddq_s32(p1, p2); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3 auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4 auto p34 = vpaddq_s32(p3, p4); auto pall = vpaddq_s32(p12, p34); sumi = vmlaq_s32(sumi, scales, pall); } template IQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1, auto q8b_2 = q8.load_quants(iy, i, 4*j+1); auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4, auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3 sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5, auto q8b_4 = q8.load_quants(iy, i, 4*j+3); auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7, auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7 sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34); } template inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8s = q8.load_bsums8(iy, i); int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s)); int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s)); float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2)); acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); } } template inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8s = q8.load_bsums(iy, i); int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); } } struct Scales8 { uint32_t utmp[4]; const uint8_t * sc8 = (const uint8_t *)utmp; template inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) { make_q4_scales(x.scales, utmp); int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8)); accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin)); uint8x8_t scales8 = vld1_u8(sc8); uint16x8_t scales16 = vmovl_u8(scales8); int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))), vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))}; return scales; } }; struct Q4bits { const uint8x16_t m4b = vdupq_n_u8(0xf); uint8x16x4_t b1, b2; inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { b.val[0] = vandq_u8(val[0], m4b); b.val[2] = vshrq_n_u8(val[0], 4); b.val[1] = vandq_u8(val[1], m4b); b.val[3] = vshrq_n_u8(val[1], 4); } inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { b.val[0] = vandq_u8(val[0], m4b); b.val[1] = vshrq_n_u8(val[0], 4); b.val[2] = vandq_u8(val[1], m4b); b.val[3] = vshrq_n_u8(val[1], 4); } inline void prepare(const uint8_t * qs) { auto q4bits = vld1q_u8_x2(qs); prepare4(b1, q4bits.val); q4bits = vld1q_u8_x2(qs+32); prepare4(b2, q4bits.val); } inline void prepare_v2(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); prepare4(b1, q4bits.val+0); prepare4(b2, q4bits.val+2); } inline void prepare64(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); b1.val[0] = vandq_u8(q4bits.val[0], m4b); b1.val[1] = vandq_u8(q4bits.val[1], m4b); b1.val[2] = vandq_u8(q4bits.val[2], m4b); b1.val[3] = vandq_u8(q4bits.val[3], m4b); b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); } inline void prepare16(const uint8_t * qs) { auto q4bits = vld1q_u8_x2(qs); prepare4_16(b1, q4bits.val); q4bits = vld1q_u8_x2(qs+32); prepare4_16(b2, q4bits.val); } inline void prepare16_v2(const uint8_t * qs) { auto q4bits = vld1q_u8_x4(qs); prepare4_16(b1, q4bits.val+0); prepare4_16(b2, q4bits.val+2); } }; struct Q2bits { const uint8x16_t m4b = vdupq_n_u8(0x03); uint8x16x4_t b1, b2; inline void prepare(const uint8_t * qs) { auto q2bits = vld1q_u8_x2(qs); b1.val[0] = vandq_u8(q2bits.val[0], m4b); b1.val[1] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b1.val[2] = vandq_u8(q2bits.val[0], m4b); b1.val[3] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b2.val[0] = vandq_u8(q2bits.val[0], m4b); b2.val[1] = vandq_u8(q2bits.val[1], m4b); q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); b2.val[2] = vandq_u8(q2bits.val[0], m4b); b2.val[3] = vandq_u8(q2bits.val[1], m4b); } }; template struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } const void * vx; const block_q * x; const size_t bx; const int nrc; }; struct DequantizerQ4K final : public BaseDequantizer { DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); return s8.process_scales_mins(x[i], q8, i, acc); } inline void prepare(int i, int j) { if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); else bits.prepare(x[i].qs+64*j); } Q4bits bits; Scales8 s8; float d; }; struct HighBit5 { const uint8x16_t mhb = vdupq_n_u8(0x10); uint8x16x2_t bits; inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); if (do_shift) { bits.val[0] = vshrq_n_u8(bits.val[0], 4); bits.val[1] = vshrq_n_u8(bits.val[1], 4); } } }; struct HighBit3 { const uint8x16_t mhb = vdupq_n_u8(0x04); uint8x16x2_t bits; inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb)); b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); if (do_shift) { bits.val[0] = vshrq_n_u8(bits.val[0], 4); bits.val[1] = vshrq_n_u8(bits.val[1], 4); } } }; struct DequantizerQ5K final : public BaseDequantizer { DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].qh); return s8.process_scales_mins(x[i], q8, i, acc); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+64*j); h.apply(bits.b1, bits.b2, j == 0); } Q4bits bits; HighBit5 h; Scales8 s8; uint8x16x2_t hbits; float d; }; inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { int32x4x4_t scales = { vmovl_s16(vget_low_s16 (scales16.val[0])), vmovl_s16(vget_high_s16(scales16.val[0])), vmovl_s16(vget_low_s16 (scales16.val[1])), vmovl_s16(vget_high_s16(scales16.val[1])), }; return scales; } template inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) { int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); accum_mins_16(scales16, q8, acc, i, c); return make_wider(scales16); } struct DequantizerQ6K final : public BaseDequantizer { DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); } inline void prepare(int i, int j) { auto hbits = vld1q_u8_x2(x[i].qh + 32*j); bits.prepare64(x[i].ql+64*j); bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); } Q4bits bits; const uint8x16_t mhb = vdupq_n_u8(0x30); float d; }; struct DequantizerQ3K final : public BaseDequantizer { DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].hmask); const uint16_t * sc16 = (const uint16_t *)x[i].scales; uint32_t aux0 = sc16[0] | (sc16[1] << 16); uint32_t aux1 = sc16[2] | (sc16[3] << 16); uint32_t aux2 = sc16[4] | (sc16[5] << 16); aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); h.apply(bits.b1, bits.b2, j == 0); } uint32_t aux32[4]; Q2bits bits; HighBit3 h; float d; }; struct DequantizerQ2K final : public BaseDequantizer { DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return true; } template inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); auto scales_and_mins = vld1q_u8(x[i].scales); auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4)); int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); } template inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { process_scales(i, q8, acc); int16x8x2_t scales16; scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); return make_wider(scales16); } template inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { auto m1 = vdupq_n_u8(1); auto shuffle = vdupq_n_u8(8*j); bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto q8b_1 = q8.load_quants(iy, i, 4*j+0); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); auto q8b_2 = q8.load_quants(iy, i, 4*j+1); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); auto q8b_4 = q8.load_quants(iy, i, 4*j+3); sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); } } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); } uint32_t aux32[4]; uint8x16_t scales8; Q2bits bits; float d; }; // ============================= i-quants struct DequantizerIQ4XS final : public BaseDequantizer { static int8x16_t load_values() { static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; return vld1q_s8(iq4nl_values); } DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); } template inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { (void)q8; (void)acc; d = GGML_FP16_TO_FP32(x[i].d); const uint16_t scales_h = x[i].scales_h; const uint16_t * scales_l = (const uint16_t *)x[i].scales_l; aux32[0] = scales_l[0] | (scales_l[1] << 16); aux32[1] = aux32[0] >> 4; // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7 uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf)); uint16_t * aux16 = (uint16_t *)aux32; aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2; // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7 uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30)); int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32)); // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7 scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff)); int16x8_t scales16 = vmovl_s8(scales8); int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; return scales; } inline void prepare(int i, int j) { bits.prepare16(x[i].qs+64*j); for (int k = 0; k < 4; ++k) { bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k])); bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k])); } } Q4bits bits; const int8x16_t values; uint32_t aux32[2]; constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; float d; }; struct SimpleBits { uint8x16x4_t b1; uint8x16x4_t b2; }; IQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { int32x4x2_t scales; auto one = vdupq_n_u32(1); scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1)); scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1)); return scales; } inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) { auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127)))); auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127)))); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1)); b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2)); } IQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) { return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1)); } struct DequantizerIQ2XXS final : public BaseDequantizer { DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); } inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j)); prepare_all(data, q); return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1])); } private: static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) { const uint8_t * idx = (const uint8_t *)bits; b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); apply_signs_2(b, signs, bits[1]); } inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) { const uint32_t * q2 = (const uint32_t *)data.val; prepare2(quants+0, q2+0, keven_signs); prepare2(quants+2, q2+2, keven_signs); prepare2(quants+4, q2+4, keven_signs); prepare2(quants+6, q2+6, keven_signs); } }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { auto aux = vld1_u8(sc); auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); auto scales_h = vshr_n_u8(aux, 4); auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; return make_wider(scales16); } struct DequantizerIQ2XS final : public BaseDequantizer { DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0); return prepare_4bit_scales16(x[i].scales); } inline void prepare(int i, int j) { if (j == 1) prepare_internal(i, 1); } private: static void make2(const uint16_t * qs, uint8x16_t * b) { auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511)))); auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511)))); auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9)))); b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1)); b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2)); } inline static void make4(const uint16_t * qs, uint8x16_t * b) { make2(qs + 0, b + 0); make2(qs + 4, b + 2); } IQK_ALWAYS_INLINE void prepare_internal(int i, int j) { make4(x[i].qs + 16*j + 0, bits.b1.val); make4(x[i].qs + 16*j + 8, bits.b2.val); } }; // So, I hate to include this table, but with the GCC 12.3 compiler // bundled in the Cosmopolitan tools, loading the unpacked sign bytes // from this table using the packed 8 sign bits as index is faster than // using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to // expand the bits to bytes. static const uint64_t kall_signs[256] = { 0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff, 0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff, 0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff, 0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff, 0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff, 0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff, 0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff, 0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff, 0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff, 0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff, 0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff, 0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff, 0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff, 0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff, 0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff, 0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff, 0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff, 0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff, 0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff, 0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff, 0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff, 0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff, 0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff, 0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff, 0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff, 0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff, 0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff, 0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff, 0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff, 0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff, 0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff, 0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff, 0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff, 0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff, 0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff, 0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff, 0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff, 0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff, 0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff, 0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff, 0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff, 0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff, 0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff, 0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff, 0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff, 0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff, 0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff, 0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff, 0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff, 0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff, 0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff, 0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff, 0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff, 0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff, 0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff, 0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff, 0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff, 0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff, 0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff, 0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff, 0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff, 0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff, 0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff, 0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff, }; struct SignHelper { IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]}); // Normally we would expect this to be faster, but it isn't. // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); } // We would need these two if we weren't loading from the unpacked sign table. //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); //const uint8x16_t m1 = vdupq_n_u8(1); }; struct DequantizerIQ2S final : public BaseDequantizer { DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0, bits); return prepare_4bit_scales16(x[i].scales); } inline void prepare(int i, int j) { if (j == 1) prepare_internal(i, 1, bits); } private: static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { aux32[1] = (qh[k] << 4) | (qh[k] << 18); aux32[0] = (aux32[1] << 4) & 0x03000300; aux32[1] &= 0x03000300; b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2; sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2; } } void prepare_internal(int i, int j, SimpleBits& sb) { const auto * qs = x[i].qs + 16*j; const auto * qh = x[i].qh + 4*j; const auto * sign_bits = qs + QK_K/8; make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val); make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val); } SignHelper sh; }; struct DequantizerIQ3XXS final : public BaseDequantizer { DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); } inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { auto q3data = vld1q_u8_x2(x[i].qs + 32*j); auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q); return prepare_scales_8(gas); } private: inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) { b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) { make2(q3+ 0, signs[0], quants + 0); make2(q3+ 8, signs[1], quants + 2); make2(q3+16, signs[2], quants + 4); make2(q3+24, signs[3], quants + 6); } }; struct DequantizerIQ3S final : public BaseDequantizer { DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } SimpleBits bits; float d; inline int32x4x2_t new_block(int i) { d = GGML_FP16_TO_FP32(x[i].d); uint32_t scales32[2]; auto qs = vld1q_u8_x2(x[i].qs); auto signs = vld1q_u8(x[i].signs); prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs); std::memcpy(scales32, x[i].scales, 4); scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); int32x4x2_t scales; scales.val[0] = vmovl_s16(vget_low_s16(scales16)); scales.val[1] = vmovl_s16(vget_high_s16(scales16)); return scales; } inline void prepare(int i, int j) { if (j == 1) { auto qs = vld1q_u8_x2(x[i].qs + 32); auto signs = vld1q_u8(x[i].signs + 16); prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs); } } private: static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); const uint16_t * idx = (const uint16_t *)&vindex; b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); sh.apply_signs_1x(b+0, sign_bits+0); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); sh.apply_signs_1x(b+1, sign_bits+2); } static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, const int16x8_t& hshift, uint8x16_t * b) { auto idx_l = vld1q_u8(qs); make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); } static int16x8_t load_shift() { static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; return vld1q_s16(k_shift); } inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) { auto signs = vld1q_u8(sign_bits); auto s = (const uint8_t *)&signs; make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val); make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val); } SignHelper sh; const int16x8_t hshift = load_shift(); }; template IQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; Q8 q8(info); Dequantizer deq(vx, bx, nrc_y); uint8x16_t qx[8]; int32x4_t sumi[nrc_y]; float32x4_t acc[nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb; ++i) { float d = deq.new_block(i); auto scales = deq.unpack(i, 0, qx); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { sumi[iy] = vdupq_n_s32(0); compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]); } scales = deq.unpack(i, 1, qx); #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy])); } } #pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } // =========================================== Legacy quants template inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) { for (int k = 0; k < 4; ++k) aux[k] = x[k].d; return vld1_f16((const float16_t *)aux); } template inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) { if constexpr (std::is_same_v) { for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; } } else { for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; } } return vld1q_f16((const float16_t *)aux); } struct Q4LegacyBits { template inline void prepare(const Block * x) { for (int i = 0; i < 4; ++i) { auto q4bits = vld1q_u8(x[i].qs); b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); } } inline void prepare1(const uint8_t * qs, int8x16_t * q) const { auto q4bits = vld1q_u8(qs); q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); } inline void prepare1(const uint8_t * qs) { prepare1(qs, b); } const uint8x16_t m4b = vdupq_n_u8(0xf); int8x16_t b[8]; }; // One would think this commented out version would do better than the one below // because it offers more opportunities to execute instructions in parallel. // Instead, it runs significantly slower. Why? If the compiler is running out of vector registers // cannot it just do the sequential version below on its own? //inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { // const auto q8b_1 = vld1q_s8_x2(qs + 0); // auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); // const auto q8b_2 = vld1q_s8_x2(qs + 32); // auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); // auto p1234 = vpaddq_s32(p12, p34); // const auto q8b_3 = vld1q_s8_x2(qs + 64); // auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); // const auto q8b_4 = vld1q_s8_x2(qs + 96); // auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); // return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); //} inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { auto q8b = vld1q_s8_x2(qs + 0); auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); q8b = vld1q_s8_x2(qs + 32); auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); auto p1234 = vpaddq_s32(p12, p34); q8b = vld1q_s8_x2(qs + 64); auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); q8b = vld1q_s8_x2(qs + 96); auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); } template struct Q80 { constexpr static int nrc_y = nrc; Q80(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); } inline const int8_t * quant_data(int iy, int i) const { const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; return y4->qs; } inline float16x4_t load_scales(int iy, int i) const { const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; return vld1_f16((const float16_t *)y4->d); } template inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const { auto qx_scales = deq.new_block(i); for (int iy = 0; iy < nrc; ++iy) { auto q8_scales = load_scales(iy, i); sc16[iy] = vmul_f16(qx_scales, q8_scales); } } template inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { deq.prepare1(i); float d = GGML_FP16_TO_FP32(deq.x[i].d); for (int iy = 0; iy < nrc; ++iy) { auto q8b = vld1q_s8_x2(y[iy][i].qs); auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); } } const block_q8_0 * y[nrc_y]; }; template struct Q81 { constexpr static int nrc_y = nrc; Q81(const DataInfo& info) { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); } inline const int8_t * quant_data(int iy, int i) const { const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; return y4->qs; } inline float16x8_t load_scales(int iy, int i) const { const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; return vld1q_f16((const float16_t *)y4->d); } template inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { auto qx_scales = deq.new_block(i); for (int iy = 0; iy < nrc; ++iy) { auto q8_scales = load_scales(iy, i); auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); } } template inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { deq.prepare1(i); float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); for (int iy = 0; iy < nrc; ++iy) { auto q8b = vld1q_s8_x2(y[iy][i].qs); auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); } } const block_q8_1 * y[nrc_y]; }; template struct BaseLegacyDequantizer { BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } Q4LegacyBits bits; const void * vx; const block_q * x; size_t bx; }; struct DequantizerQ40 final : public BaseLegacyDequantizer { DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); q[0] = vaddq_s8(q[0], m8); q[1] = vaddq_s8(q[1], m8); } inline void prepare1(int i) { prepare1(i, bits.b); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; prepare1(4*i+k, bits.b + 2*k); } return vld1_f16((const float16_t *)aux); } const int8x16_t m8 = vdupq_n_s8(-8); //ggml_half aux[4]; }; struct DequantizerQ41 : public BaseLegacyDequantizer { DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i) { bits.prepare1(x[i].qs); } inline float16x8_t new_block(int i) { uint32_t aux32[4]; const uint32_t * s32 = (const uint32_t *)&x[4*i].d; for (int k = 0; k < 4; ++k) { aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; bits.prepare1(x[4*i+k].qs, bits.b + 2*k); } return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); } // Leaving this commented out attempt to be reminded that I already tried this. // It has basically the same performance as the version above. //inline float16x8_t new_block(int i) { // uint32x4_t scales = {}; // const block_q4_1 * xi = x + 4*i; // const uint32_t * s32 = (const uint32_t *)&xi->d; // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[0].qs, bits.b + 0); // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[1].qs, bits.b + 2); // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; // bits.prepare1(xi[2].qs, bits.b + 4); // scales = vsetq_lane_u32(*s32, scales, 3); // bits.prepare1(xi[3].qs, bits.b + 6); // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); //} const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; }; struct HighBit5Legacy { inline uint8x16_t to_bytes(const uint8_t * qh) const { uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); } inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); } const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }; struct DequantizerQ50 final : public BaseLegacyDequantizer { DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); auto qh = x[i].qh; q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); } inline void prepare1(int i) { prepare1(i, bits.b); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; prepare1(4*i+k, bits.b + 2*k); } return vld1_f16((const float16_t *)aux); } HighBit5Legacy hbits; const uint8x16_t mh = vdupq_n_u8(0xf0); }; struct DequantizerQ80 final : public BaseLegacyDequantizer { DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i) { bits.b[0] = vld1q_s8(x[i].qs); bits.b[1] = vld1q_s8(x[i].qs+16); } inline float16x4_t new_block(int i) { ggml_half aux[4]; for (int k = 0; k < 4; ++k) { aux[k] = x[4*i+k].d; bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); } return vld1_f16((const float16_t *)aux); } }; struct DequantizerQ51 final : public BaseLegacyDequantizer { DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} inline void prepare1(int i, int8x16_t * q) const { bits.prepare1(x[i].qs, q); auto qh = x[i].qh; q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0)))); q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2)))); } inline void prepare1(int i) { bits.prepare1(x[i].qs, bits.b); } inline float16x8_t new_block(int i) { uint32_t aux32[4]; const uint32_t * s32 = (const uint32_t *)&x[4*i].d; for (int k = 0; k < 4; ++k) { aux32[k] = *s32; s32 += sizeof(block_q5_1)/4; prepare1(4*i+k, bits.b + 2*k); } return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); } HighBit5Legacy hbits; const uint8x16_t mh = vdupq_n_u8(0x10); const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; }; template inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { for (int iy = 0; iy < Q8::nrc_y; ++iy) { auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i)); auto scale = vcvt_f32_f16(sc16[iy]); acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall)); } } template inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) { const int nb = n / QK4_1; float16x4_t sc16[Q8::nrc_y]; for (int ix = 0; ix < nrc_x; ++ix) { deq.new_row(ix); float32x4_t acc[Q8::nrc_y]; for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); for (int i = 0; i < nb/4; ++i) { q8.process_scales(i, deq, sc16, acc); sum_4(i, deq, q8, sc16, acc); } for (int i = 4*(nb/4); i < nb; ++i) { q8.process_1_block(i, deq, acc); } for (int iy = 0; iy < Q8::nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } } } template inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { const int nb = n / QK4_1; float16x4_t sc16[2]; for (int ix = 0; ix < nrc_x; ++ix) { deq1.new_row(ix); deq2.new_row(ix); float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) }; for (int i = 0; i < nb/8; ++i) { q8.process_scales(2*i+0, deq1, sc16+0, acc+0); q8.process_scales(2*i+1, deq2, sc16+1, acc+1); sum_4(2*i+0, deq1, q8, sc16+0, acc+0); sum_4(2*i+1, deq2, q8, sc16+1, acc+1); } for (int i = 2*(nb/8); i < nb/4; ++i) { q8.process_scales(i, deq1, sc16, acc); sum_4(i, deq1, q8, sc16, acc); } for (int i = 4*(nb/4); i < nb; ++i) { q8.process_1_block(i, deq1, acc); } info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); } } template static void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q81 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } template static void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q80 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } template static void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q81<1> q8(info); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } template static void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q80<1> q8(info); mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); } template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0; m.funcs[1] = mul_mat_qX_0_q8_0; m.funcs[2] = mul_mat_qX_0_q8_0; m.funcs[3] = mul_mat_qX_0_q8_0; m.funcs[4] = mul_mat_qX_0_q8_0; m.funcs[5] = mul_mat_qX_0_q8_0; m.funcs[6] = mul_mat_qX_0_q8_0; m.funcs[7] = mul_mat_qX_0_q8_0; } else if constexpr (std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_1_q8_1; m.funcs[1] = mul_mat_qX_1_q8_1; m.funcs[2] = mul_mat_qX_1_q8_1; m.funcs[3] = mul_mat_qX_1_q8_1; m.funcs[4] = mul_mat_qX_1_q8_1; m.funcs[5] = mul_mat_qX_1_q8_1; m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } else if constexpr (std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>; } else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; } } bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); (void)Ny; // Uncommenting out this would disable iqk_mul_mat for matrix x vector multiplications. //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || // typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions(m); break; case GGML_TYPE_Q3_K: MulMat::set_functions(m); break; case GGML_TYPE_Q4_K: MulMat::set_functions(m); break; case GGML_TYPE_Q5_K: MulMat::set_functions(m); break; case GGML_TYPE_Q6_K: MulMat::set_functions(m); break; case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break; case GGML_TYPE_IQ3_S: MulMat::set_functions(m); break; case GGML_TYPE_IQ3_XXS: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_S: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XS: MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XXS: MulMat::set_functions(m); break; case GGML_TYPE_Q4_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q4_1: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; case GGML_TYPE_Q5_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q5_1: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; case GGML_TYPE_Q8_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; default: return false; } return true; } } #endif // __x86_64__ or __aarch64__ ================================================ FILE: archive/third_party/llamafile/macros.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/macros.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi #pragma once #define MIN(X, Y) ((Y) > (X) ? (X) : (Y)) #define MAX(X, Y) ((Y) < (X) ? (X) : (Y)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) #define ROUNDUP(X, K) (((X) + (K) - 1) & -(K)) #define ARRAYLEN(A) ((sizeof(A) / sizeof(*(A))) / ((unsigned)!(sizeof(A) % sizeof(*(A))))) ================================================ FILE: archive/third_party/llamafile/micros.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/micros.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi #pragma once #include #ifndef _WIN32 #include #else #include #endif #ifdef _WIN32 static long long GetQueryPerformanceFrequency() { LARGE_INTEGER t; QueryPerformanceFrequency(&t); return t.QuadPart; } static long long GetQueryPerformanceCounter() { LARGE_INTEGER t; QueryPerformanceCounter(&t); return t.QuadPart; } #endif static long long micros(void) { #ifndef _WIN32 struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); return ts.tv_sec * 1000000 + (ts.tv_nsec + 999) / 1000; #else static long long timer_freq = GetQueryPerformanceFrequency(); static long long timer_start = GetQueryPerformanceCounter(); return ((GetQueryPerformanceCounter() - timer_start) * 1000000) / timer_freq; #endif } ================================================ FILE: archive/third_party/llamafile/numba.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/numba.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #pragma once inline int rand32(void) { static unsigned long long lcg = 1; lcg *= 6364136223846793005; lcg += 1442695040888963407; return lcg >> 32; } inline int popcount(unsigned x) { x = x - ((x >> 1) & 0x55555555); x = ((x >> 2) & 0x33333333) + (x & 0x33333333); x = (x + (x >> 4)) & 0x0F0F0F0F; x = (x + (x >> 16)); return (x + (x >> 8)) & 0x0000003F; } inline int hamming(int x, int y) { return popcount(x ^ y); } inline float float01(unsigned x) { // (0,1) return 1.f / 8388608 * ((x >> 9) + .5f); } inline float numba(void) { // (-10,10) return float01(rand32()) * 2.f - 1.f; } template void randomize(T* A, int n) { for (int i = 0; i < n; ++i) A[i] = numba(); } template void randomize(int m, int n, T* A, int lda) { for (int j = 0; j < n; ++j) for (int i = 0; i < m; ++i) A[lda * j + i] = numba(); } template void broadcast(T* A, int n, U x) { for (int i = 0; i < n; ++i) A[i] = x; } template void broadcast(int m, int n, T* A, int lda, U x) { for (int j = 0; j < n; ++j) for (int i = 0; i < m; ++i) A[lda * j + i] = x; } ================================================ FILE: archive/third_party/llamafile/sgemm.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU // use ARM version #include "sgemm_arm.cpp" #else // use x86 version #include "sgemm_x86.cpp" #endif ================================================ FILE: archive/third_party/llamafile/sgemm.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #pragma once #include #include #ifdef __cplusplus extern "C" { #endif struct ggml_tensor; struct ggml_compute_params; /*moonll old add more params typeb... */ bool iqk_mul_mat(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int); bool iqk_mul_mat_zen4(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int); bool iqk_mul_mat_arm82(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int); bool iqk_mul_mat_moe(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); bool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); bool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); bool llamafile_sgemm(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_mixmul(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); size_t llamafile_mixmul_needs(const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*); bool llamafile_sgemm_unsupported(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_avx(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_fma(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_avx2(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_avxvnni(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_avx512f(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_amd_zen4(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_arm80(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_sgemm_arm82(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool llamafile_mixmul_unsupported(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_avx(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_fma(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_avx2(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_avxvnni(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_avx512f(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_amd_zen4(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_arm80(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_arm82(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool llamafile_mixmul_iqk(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); #ifdef __cplusplus } #endif ================================================ FILE: archive/third_party/llamafile/sgemm_arm.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "sgemm.h" // #include // #include // #include #include // #include #include // #include "llamafile.h" static const struct GemmFuncs { bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); // typeof(llamafile_sgemm)* sgemm; // typeof(llamafile_mixmul)* mixmul; // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported; GemmFuncs() { #if defined(__x86_64__) || defined(_M_X64) // if (X86_HAVE(AVX)) { // if (X86_HAVE(FMA)) { // if (X86_HAVE(AVX2)) { // if (X86_HAVE(AVX512F)) { // if (X86_HAVE(AVX512VL) && // // X86_HAVE(AVX512BW) && // // X86_HAVE(AVX512DQ) && // // X86_HAVE(AVX512_VNNI) && // // X86_HAVE(AVX512_BF16)) { // // AMD Zen4+ (2023-) // sgemm = llamafile_sgemm_amd_zen4; // mixmul = llamafile_mixmul_amd_zen4; // iqk_mixmul = iqk_mul_mat_moe_zen4; // } else { // // Intel Xeon Skylake+ (2015-) // sgemm = llamafile_sgemm_amd_avx512f; // mixmul = llamafile_mixmul_amd_avx512f; // iqk_mixmul = iqk_mul_mat_moe; // } // } else if (X86_HAVE(AVXVNNI)) { // // Intel Alderlake (2021-) // sgemm = llamafile_sgemm_amd_avxvnni; // mixmul = llamafile_mixmul_amd_avxvnni; // iqk_mixmul = iqk_mul_mat_moe; // } else { // // Intel Haswell/Broadwell/Skylake (2013-2020) // // AMD Excavator (2015-2022) // sgemm = llamafile_sgemm_amd_avx2; // mixmul = llamafile_mixmul_amd_avx2; // if (X86_HAVE(F16C)) // iqk_mixmul = iqk_mul_mat_moe; // } // } else { // // AMD Piledriver (2011-2014) // sgemm = llamafile_sgemm_amd_fma; // mixmul = llamafile_mixmul_amd_fma; // if (X86_HAVE(F16C)) // iqk_mixmul = iqk_mul_mat_moe; // } // } else { // // Intel Sandybridge/Ivybridge (2010-2012) // // AMD Bulldozer (2011) // sgemm = llamafile_sgemm_amd_avx; // mixmul = llamafile_mixmul_amd_avx; // } // } else { // // AMD K8/Barcelona (2003-2010) // // Intel Core/Nehalem (2006-2009) // sgemm = llamafile_sgemm_unsupported; // mixmul = llamafile_mixmul_unsupported; // } #if defined(__AVX__) #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))) #if defined(__AVX2__) #if defined(__AVX512F__) #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) // AMD Zen4+ (2023-) sgemm = llamafile_sgemm_amd_zen4; mixmul = llamafile_mixmul_amd_zen4; iqk_mixmul = iqk_mul_mat_moe_zen4; #else // Intel Xeon Skylake+ (2015-) sgemm = llamafile_sgemm_amd_avx512f; mixmul = llamafile_mixmul_amd_avx512f; iqk_mixmul = iqk_mul_mat_moe; #endif #elif defined(__AVXVNNI__) // Intel Alderlake (2021-) sgemm = llamafile_sgemm_amd_avxvnni; mixmul = llamafile_mixmul_amd_avxvnni; iqk_mixmul = iqk_mul_mat_moe; #else // Intel Haswell/Broadwell/Skylake (2013-2020) // AMD Excavator (2015-2022) sgemm = llamafile_sgemm_amd_avx2; mixmul = llamafile_mixmul_amd_avx2; #if defined(__F16C__) iqk_mixmul = iqk_mul_mat_moe; #endif #endif #else // AMD Piledriver (2011-2014) sgemm = llamafile_sgemm_amd_fma; mixmul = llamafile_mixmul_amd_fma; #if defined(__F16C__) iqk_mixmul = iqk_mul_mat_moe; #endif #endif #else // Intel Sandybridge/Ivybridge (2010-2012) // AMD Bulldozer (2011) sgemm = llamafile_sgemm_amd_avx; mixmul = llamafile_mixmul_amd_avx; #endif #else // AMD K8/Barcelona (2003-2010) // Intel Core/Nehalem (2006-2009) sgemm = llamafile_sgemm_unsupported; mixmul = llamafile_mixmul_unsupported; #endif #elif defined(__aarch64__) // long hwcap = getauxval(AT_HWCAP); // if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1) // (hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1) // (hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1) // // e.g. Apple M1, Raspberry Pi 5 // sgemm = llamafile_sgemm_arm82; // mixmul = llamafile_mixmul_arm82; // iqk_mixmul = iqk_mul_mat_moe_arm82; // } else { // ARM64 baseline ISA sgemm = llamafile_sgemm_arm80; mixmul = llamafile_mixmul_arm80; // } #else sgemm = llamafile_sgemm_unsupported; mixmul = llamafile_mixmul_unsupported; #endif } } funcs; /** * Performs optimized matrix multiplication on CPU. * * This subroutine may compute C = Aᵀ * B with column major ordering. * Despite its name, this isn't a generalized implementation. Work is * only performed when a handwritten kernel is written and available. * Otherwise the caller should fall back to a general matmul routine. * * @param m is rows in `A` and `C` * @param n is cols in `B` and `C` * @param k is cols in `A` and rows in `B` * @param A is first input matrix (always transposed) * @param lda is row stride of `A` * @param B is second input matrix (never transposed) * @param ldb is row stride of `B` * @param C is input/output array of output matrices * @param ldc is row stride of `C` * @param ith is thread id (must be less than `nth`) * @param nth is number of threads (must be greater than zero) * @param task is GGML task type * @param Atype is GGML data type of `A` * @param Btype is GGML data type of `B` * @param Ctype is GGML data type of `C` * @param precision may be used to control the internal compute type * @return true if this function was able to service the matmul request */ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype, precision); } /** * Performs "mixture of experts" tensor multiplication on CPU. */ bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) { return funcs.mixmul(params, weights, thought, plan, result); } bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) { return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth); } ================================================ FILE: archive/third_party/llamafile/sgemm_x86.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "sgemm.h" // #include // #include // #include #include // #include #include // #include "llamafile.h" static const struct GemmFuncs { bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int); bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*); bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int); // typeof(llamafile_sgemm)* sgemm; // typeof(llamafile_mixmul)* mixmul; // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported; GemmFuncs() { #if defined(__x86_64__) || defined(_M_X64) // if (X86_HAVE(AVX)) { // if (X86_HAVE(FMA)) { // if (X86_HAVE(AVX2)) { // if (X86_HAVE(AVX512F)) { // if (X86_HAVE(AVX512VL) && // // X86_HAVE(AVX512BW) && // // X86_HAVE(AVX512DQ) && // // X86_HAVE(AVX512_VNNI) && // // X86_HAVE(AVX512_BF16)) { // // AMD Zen4+ (2023-) // sgemm = llamafile_sgemm_amd_zen4; // mixmul = llamafile_mixmul_amd_zen4; // iqk_mixmul = iqk_mul_mat_moe_zen4; // } else { // // Intel Xeon Skylake+ (2015-) // sgemm = llamafile_sgemm_amd_avx512f; // mixmul = llamafile_mixmul_amd_avx512f; // iqk_mixmul = iqk_mul_mat_moe; // } // } else if (X86_HAVE(AVXVNNI)) { // // Intel Alderlake (2021-) // sgemm = llamafile_sgemm_amd_avxvnni; // mixmul = llamafile_mixmul_amd_avxvnni; // iqk_mixmul = iqk_mul_mat_moe; // } else { // // Intel Haswell/Broadwell/Skylake (2013-2020) // // AMD Excavator (2015-2022) // sgemm = llamafile_sgemm_amd_avx2; // mixmul = llamafile_mixmul_amd_avx2; // if (X86_HAVE(F16C)) // iqk_mixmul = iqk_mul_mat_moe; // } // } else { // // AMD Piledriver (2011-2014) // sgemm = llamafile_sgemm_amd_fma; // mixmul = llamafile_mixmul_amd_fma; // if (X86_HAVE(F16C)) // iqk_mixmul = iqk_mul_mat_moe; // } // } else { // // Intel Sandybridge/Ivybridge (2010-2012) // // AMD Bulldozer (2011) // sgemm = llamafile_sgemm_amd_avx; // mixmul = llamafile_mixmul_amd_avx; // } // } else { // // AMD K8/Barcelona (2003-2010) // // Intel Core/Nehalem (2006-2009) // sgemm = llamafile_sgemm_unsupported; // mixmul = llamafile_mixmul_unsupported; // } #if defined(__AVX__) #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))) #if defined(__AVX2__) #if defined(__AVX512F__) #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) // AMD Zen4+ (2023-) sgemm = llamafile_sgemm_amd_zen4; mixmul = llamafile_mixmul_amd_zen4; iqk_mixmul = iqk_mul_mat_moe_zen4; #else // Intel Xeon Skylake+ (2015-) sgemm = llamafile_sgemm_amd_avx512f; mixmul = llamafile_mixmul_amd_avx512f; iqk_mixmul = iqk_mul_mat_moe; #endif #elif defined(__AVXVNNI__) // Intel Alderlake (2021-) sgemm = llamafile_sgemm_amd_avxvnni; mixmul = llamafile_mixmul_amd_avxvnni; iqk_mixmul = iqk_mul_mat_moe; #else // Intel Haswell/Broadwell/Skylake (2013-2020) // AMD Excavator (2015-2022) sgemm = llamafile_sgemm_amd_avx2; mixmul = llamafile_mixmul_amd_avx2; #if defined(__F16C__) iqk_mixmul = iqk_mul_mat_moe; #endif #endif #else // AMD Piledriver (2011-2014) sgemm = llamafile_sgemm_amd_fma; mixmul = llamafile_mixmul_amd_fma; #if defined(__F16C__) iqk_mixmul = iqk_mul_mat_moe; #endif #endif #else // Intel Sandybridge/Ivybridge (2010-2012) // AMD Bulldozer (2011) sgemm = llamafile_sgemm_amd_avx; mixmul = llamafile_mixmul_amd_avx; #endif #else // AMD K8/Barcelona (2003-2010) // Intel Core/Nehalem (2006-2009) sgemm = llamafile_sgemm_unsupported; mixmul = llamafile_mixmul_unsupported; #endif #elif defined(__aarch64__) long hwcap = getauxval(AT_HWCAP); if ((hwcap & HWCAP_FPHP) && // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1) (hwcap & HWCAP_ASIMDHP) && // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1) (hwcap & HWCAP_ASIMDDP)) { // dotprod isa (ID_AA64ISAR0_EL1.DP == 1) // e.g. Apple M1, Raspberry Pi 5 sgemm = llamafile_sgemm_arm82; mixmul = llamafile_mixmul_arm82; iqk_mixmul = iqk_mul_mat_moe_arm82; } else { // ARM64 baseline ISA sgemm = llamafile_sgemm_arm80; mixmul = llamafile_mixmul_arm80; } #else sgemm = llamafile_sgemm_unsupported; mixmul = llamafile_mixmul_unsupported; #endif } } funcs; /** * Performs optimized matrix multiplication on CPU. * * This subroutine may compute C = Aᵀ * B with column major ordering. * Despite its name, this isn't a generalized implementation. Work is * only performed when a handwritten kernel is written and available. * Otherwise the caller should fall back to a general matmul routine. * * @param m is rows in `A` and `C` * @param n is cols in `B` and `C` * @param k is cols in `A` and rows in `B` * @param A is first input matrix (always transposed) * @param lda is row stride of `A` * @param B is second input matrix (never transposed) * @param ldb is row stride of `B` * @param C is input/output array of output matrices * @param ldc is row stride of `C` * @param ith is thread id (must be less than `nth`) * @param nth is number of threads (must be greater than zero) * @param task is GGML task type * @param Atype is GGML data type of `A` * @param Btype is GGML data type of `B` * @param Ctype is GGML data type of `C` * @param precision may be used to control the internal compute type * @return true if this function was able to service the matmul request */ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype, precision); } /** * Performs "mixture of experts" tensor multiplication on CPU. */ bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) { return funcs.mixmul(params, weights, thought, plan, result); } bool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) { return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth); } ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu.h ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu.h // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound with rounding // errors, which grow along with k. Our goal's to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. #pragma once #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" // #include "log.h" #include "flags.h" #include "sgemm.h" // #include #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wignored-attributes" #define ROW_ALIGN 64 #define MATRIX_ALIGN 4096 #define MAX_ALIGN 4096 #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else #define NOINLINE __attribute__((__noinline__)) #endif #if defined(__ARM_NEON) || defined(__AVX512F__) #define VECTOR_REGISTERS 32 #else #define VECTOR_REGISTERS 16 #endif #if 0 #define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__) #else #define NOT_SUPPORTED false #endif #define WANT_QUANTIZATION false namespace { bool tinyBLAS_not_supported(const char* file, int line) { // tinylogf("%s:%d: tinyBLAS not supported\n", file, line); return false; } inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } inline float unhalf(ggml_bf16_t d) { return GGML_BF16_TO_FP32(d); } //////////////////////////////////////////////////////////////////////////////////////////////////// // MATRIX MEMORY INDEXING #define NCA 1 #define NCB 2 #define NCC 4 #define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A**)A)[j] + i : A + lda * (j) + i) //////////////////////////////////////////////////////////////////////////////////////////////////// // GGML TYPE TRAITS template struct ggml_type_trait; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F32; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_BF16; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F16; }; template <> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_Q8_0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } #endif // __AVX__ #if defined(__AVX512F__) inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } #endif // __AVX512F__ #if defined(__ARM_NEON) inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD /** * Computes a * b + c. */ template inline U madd(T a, T b, U c) { return add(mul(a, b), c); } /** * Computes a * b + c with error correction. * * @see W. Kahan, "Further remarks on reducing truncation errors," * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, * doi: 10.1145/363707.363723. */ template inline U madder(T a, T b, U c, U* e) { U y = sub(mul(a, b), *e); U t = add(c, y); *e = sub(sub(t, c), y); return t; } #ifdef __ARM_NEON inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e) { float32x4_t y = sub(vmulq_n_f32(a, b), *e); float32x4_t t = add(c, y); *e = sub(sub(t, c), y); return t; } #endif #if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 madd(__m256 a, __m256 b, __m256 c) { return _mm256_fmadd_ps(a, b, c); } #endif #if defined(__AVX512F__) template <> inline __m512 madd(__m512 a, __m512 b, __m512 c) { return _mm512_fmadd_ps(a, b, c); } #endif #endif #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { return vfmaq_f32(c, a, b); } #if 0 // todo: this specialization chops gcc 12.3 performance in half #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) && 0 template <> inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } #endif #endif #endif #if defined(__AVX512BF16__) template <> inline __m512 madd(__m512bh x, __m512bh y, __m512 z) { return _mm512_dpbf16_ps(z, x, y); } template <> inline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512* _) { return _mm512_dpbf16_ps(z, x, y); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM #if defined(__ARM_NEON) inline float hsum(float32x4_t x) { return vaddvq_f32(x); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { // todo: this works great on clang but it produces terrible code on gcc 12.3 return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m128 x) { #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); #else __m128 t; t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); x = _mm_add_ps(x, t); t = _mm_movehl_ps(t, x); x = _mm_add_ss(x, t); #endif return _mm_cvtss_f32(x); } #endif #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m256 x) { return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); } #endif // __AVX__ #if defined(__AVX512F__) inline float hsum(__m512 x) { return _mm512_reduce_add_ps(x); } #endif // __AVX512F__ //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED MEMORY LOADING template T load(const U*); template <> inline float load(const float* p) { return *p; } template <> inline float load(const ggml_fp16_t* p) { return unhalf(*p); } template <> inline float load(const ggml_bf16_t* p) { return unhalf(*p); } #if defined(__ARM_NEON) template <> inline float32x4_t load(const float* p) { return vld1q_f32(p); } template <> inline float32x4_t load(const ggml_bf16_t* p) { return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short*)p), 16)); } #if !defined(_MSC_VER) template <> inline float16x8_t load(const ggml_fp16_t* p) { return vld1q_f16((const float16_t*)p); } template <> inline float32x4_t load(const ggml_fp16_t* p) { return vcvt_f32_f16(vld1_f16((const float16_t*)p)); } #endif // _MSC_VER #endif // __ARM_NEON #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m128 load(const float* p) { return _mm_loadu_ps(p); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 load(const float* p) { return _mm256_loadu_ps(p); } #endif // __AVX__ #if defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 load(const ggml_bf16_t* p) { return _mm256_castsi256_ps( _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)p)), 16)); } #endif // __AVX2__ #if defined(__F16C__) template <> inline __m256 load(const ggml_fp16_t* p) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p)); } #endif // __F16C__ #if defined(__AVX512F__) template <> inline __m512 load(const float* p) { return _mm512_loadu_ps(p); } template <> inline __m512 load(const ggml_fp16_t* p) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p)); } template <> inline __m512 load(const ggml_bf16_t* p) { return _mm512_castsi512_ps( _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)p)), 16)); } #endif // __AVX512F__ #if defined(__AVX512BF16__) template <> inline __m512bh load(const ggml_bf16_t* p) { return (__m512bh)_mm512_loadu_ps((const float*)p); } template <> inline __m512bh load(const float* p) { return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p)); } #endif // __AVX512BF16__ //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT OUTPUT STREAMING inline void store(float* p, float f) { *p = f; } inline void store(ggml_fp16_t* p, float f) { *p = GGML_FP32_TO_FP16(f); } inline void store(ggml_bf16_t* p, float f) { *p = GGML_FP32_TO_BF16(f); } //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION template class tinyBLAS { public: tinyBLAS(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: NOINLINE void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; #if VECTOR_REGISTERS == 32 if (!FLAG_precise) { switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { case 0x55: mc = 5; nc = 5; gemm<5, 5, false>(m0, m, n0, n); break; case 0x54: case 0x53: case 0x52: case 0x45: case 0x44: case 0x43: case 0x42: case 0x35: case 0x34: case 0x33: case 0x32: case 0x25: case 0x24: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x51: case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x15: case 0x14: case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) { case 0x43: mc = 4; nc = 3; gemm<4, 3, true>(m0, m, n0, n); break; case 0x42: case 0x33: case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif #if VECTOR_REGISTERS == 16 if (!FLAG_precise) { switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) { case 0x43: mc = 4; nc = 3; gemm<4, 3, false>(m0, m, n0, n); break; case 0x42: case 0x33: case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x41: case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) { case 0x32: mc = 3; nc = 2; gemm<3, 2, true>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3, true>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; D Ce[RN][RM] = {}; for (long l = 0; l < k; l += KN) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) if (PRECISE) Cv[j][i] = madder(load(INDEX(A, lda, ii + i, l)), // load(INDEX(B, ldb, jj + j, l)), // Cv[j][i], &Ce[j][i]); else Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), // load(INDEX(B, ldb, jj + j, l)), // Cv[j][i]); #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; ////////////////////////////////////////////////////////////////////////////////////////// // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) template class tinyBLAS_Q0_ARM { public: tinyBLAS_Q0_ARM(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: NOINLINE void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, false>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, true>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; float32x4_t Ce[RN][RM] = {}; for (int l = 0; l < k; ++l) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) { float32x4_t a = vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)), load_lo(INDEX(B, ldb, jj + j, l))), load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l)))); float b = unhalf(INDEX(A, lda, ii + i, l)->d) * unhalf(INDEX(B, ldb, jj + j, l)->d); if (PRECISE) Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]); else Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b); } #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } inline int8x16_t load_lo(const block_q8_0* b) { return vld1q_s8(b->qs); } inline int8x16_t load_hi(const block_q8_0* b) { return vld1q_s8(b->qs + 16); } inline int8x16_t load_lo(const block_q4_0* b) { return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0* b) { return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) template class tinyBLAS_Q0_AVX2 { public: tinyBLAS_Q0_AVX2(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(long m, long n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: void mnpack(long m0, long m, long n0, long n) { long mc, nc, mp, np; #if VECTOR_REGISTERS == 32 if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, false>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) { case 0x33: mc = 3; nc = 3; gemm<3, 3, true>(m0, m, n0, n); break; case 0x32: case 0x23: case 0x22: mc = 2; nc = 2; gemm<2, 2, true>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x13: case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif #if VECTOR_REGISTERS == 16 if (!FLAG_precise) { switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) { case 0x32: mc = 3; nc = 2; gemm<3, 2, false>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3, false>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2, false>(m0, m, n0, n); break; case 0x31: case 0x21: mc = 2; nc = 1; gemm<2, 1, false>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, false>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, false>(m0, m, n0, n); break; default: return; } } else { switch ((MIN(m - m0, 2) << 4) | MIN(n - n0, 1)) { case 0x21: mc = 2; nc = 1; gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1, true>(m0, m, n0, n); break; default: return; } } #endif mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(long m0, long m, long n0, long n) { long ytiles = RM > 1 ? (m - m0) / RM : 1; long xtiles = RN > 1 ? (n - n0) / RN : 1; long tiles = xtiles * ytiles; long duty = (tiles + nth - 1) / nth; long start = duty * ith; long end = start + duty; if (end > tiles) end = tiles; for (long job = start; job < end; ++job) { long ii = m0 + job / xtiles * RM; long jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; __m256 Ce[RN][RM] = {}; for (long l = 0; l < k; ++l) #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) { __m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * unhalf(INDEX(B, ldb, jj + j, l)->d)); __m256 b = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), load(INDEX(A, lda, ii + i, l))), _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), load(INDEX(A, lda, ii + i, l)))); if (PRECISE) Cv[j][i] = madder(a, b, Cv[j][i], &Ce[j][i]); else Cv[j][i] = madd(a, b, Cv[j][i]); } #pragma GCC unroll 100 for (int j = 0; j < RN; ++j) #pragma GCC unroll 100 for (int i = 0; i < RM; ++i) store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } inline __m256i load(const block_q8_0* b) { return _mm256_loadu_si256((const __m256i*)b->qs); } inline __m256i load(const block_q4_0* b) { __m128i x = _mm_loadu_si128((const __m128i*)b->qs); return _mm256_sub_epi8(_mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)), _mm256_set1_epi8(8)); } inline __m256 updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif return _mm256_cvtepi32_ps(res); } const TA* const A; const TB* const B; TC* const C; const long k; const long lda; const long ldb; const long ldc; const int ith; const int nth; }; #endif // __AVX2__ } // namespace ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul.inc // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tinyblas_cpu.h" // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // MIXTURE OF EXPERTS TENSOR MULTIPLICATION // // // SHAPES // // - weights [cols, rows, experts] // - thought [cols, tasks, tokens] w/ tasks ≤ thinkers // - result [rows, thinkers, tokens] w/ thinkers ≤ experts // - plan [thinkers, tokens] w/ i32 < experts // // DEFINITION // // for thinker in range(thinkers): // for token in range(tokens): // for row in range(rows): // c = 0 // for col in range(cols): // expert = plan[token][thinker] // a = weights[expert][row][col] // b = thought[token][thinker % tasks][col] // c += a * b // result[token][thinker][row] = c // // REGULARITIES // // - tokens can be odd // - thinkers is usually 2 // - tasks is usually 1 or 2 // - cols should be a multiple of 64 // - rows should be a multiple of 64 // - experts is usually 8 but could be 60 // - tokens is always 1 for token generation // - tokens can be huge for prompt processing // // EXAMPLE // // mixtral 8x7b w/ 217 token prompt // // | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type // ========================================================================= // weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0 // thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32 // result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32 // plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32 // namespace { class MixMul { public: MixMul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) : params(params), weights(weights), thought(thought), plan(plan), result(result), rows(weights->ne[1]), cols(weights->ne[0]), experts(weights->ne[2]), thinkers(plan->ne[0]), tasks(thought->ne[1]), tokens(thought->ne[2]), ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN), wdata_((char*)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)), allocated_(0) { } bool allocate_shared_memory() { if (!(quantized_thought_ = allocate(MATRIX_ALIGN, tokens * tasks * ldq))) return false; if (!(rowptr_result_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) return false; if (!(rowptr_thought_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) return false; if (!(rowptr_count_ = allocate(sizeof(long), experts))) return false; return true; } size_t get_allocated_bytes() { return (wdata_ - (char*)params->wdata) + allocated_; } bool mixmul() { // invariants assert(tasks <= thinkers); assert(thinkers <= experts); assert(tokens == plan->ne[1]); assert(rows == result->ne[0]); assert(cols == thought->ne[0]); assert(tokens == result->ne[2]); assert(thinkers == result->ne[1]); // dimensionality assert(plan->ne[2] == 1); assert(plan->ne[3] == 1); assert(result->ne[3] == 1); assert(weights->ne[3] == 1); assert(thought->ne[3] == 1); // miscellaneous assert(params->nth > 0); assert(params->ith < params->nth); assert(plan->type == GGML_TYPE_I32); // check nb01 is convertible to lda if (weights->nb[1] % ggml_type_size(weights->type)) return false; // no support for column strides if (result->nb[0] != ggml_type_size(result->type)) return false; if (thought->nb[0] != ggml_type_size(thought->type)) return false; if (weights->nb[0] != ggml_type_size(weights->type)) return false; // supported output types switch (result->type) { case GGML_TYPE_F32: return mixmuler(); default: return false; } } private: template bool mixmuler() { switch (weights->type) { case GGML_TYPE_F32: if (thought->type != GGML_TYPE_F32) return false; #if defined(__AVX512F__) return mixmat<16, 1, tinyBLAS, float, float, TC>(); #elif defined(__AVX__) || defined(__AVX2__) return mixmat<8, 1, tinyBLAS, float, float, TC>(); #elif defined(__SSE__) return mixmat<4, 1, tinyBLAS, float, float, TC>(); #elif defined(__ARM_NEON) return mixmat<4, 1, tinyBLAS, float, float, TC>(); #else return false; #endif case GGML_TYPE_BF16: if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16) return false; #if defined(__AVX512BF16__) if (!FLAG_precise) { return mixmat< 32, 1, tinyBLAS, ggml_bf16_t, ggml_bf16_t, TC>(); } else { return mixmat<16, 1, tinyBLAS, ggml_bf16_t, ggml_bf16_t, TC>(); } #elif defined(__AVX512F__) return mixmat<16, 1, tinyBLAS, ggml_bf16_t, ggml_bf16_t, TC>(); #elif defined(__AVX2__) return mixmat<8, 1, tinyBLAS, ggml_bf16_t, ggml_bf16_t, TC>(); #elif defined(__ARM_NEON) && !defined(_MSC_VER) return mixmat< 4, 1, tinyBLAS, ggml_bf16_t, ggml_bf16_t, TC>(); #else return false; #endif case GGML_TYPE_F16: if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16) return false; #if defined(__AVX512F__) return mixmat<16, 1, tinyBLAS, ggml_fp16_t, ggml_fp16_t, TC>(); #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) // if (X86_CHECK(F16C)) { return mixmat<8, 1, tinyBLAS, ggml_fp16_t, ggml_fp16_t, TC>(); // } else { // return false; // } #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) if (result->op_params[0] == GGML_PREC_F32) { return mixmat< 4, 1, tinyBLAS, ggml_fp16_t, ggml_fp16_t, TC>(); } else { return mixmat< 8, 1, tinyBLAS, ggml_fp16_t, ggml_fp16_t, TC>(); } #elif defined(__ARM_NEON) && !defined(_MSC_VER) return mixmat< 4, 1, tinyBLAS, ggml_fp16_t, ggml_fp16_t, TC>(); #else return false; #endif case GGML_TYPE_Q4_0: if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) return mixmat<32, 32, tinyBLAS_Q0_AVX2, block_q4_0, block_q8_0, TC>(); #elif defined(__ARM_FEATURE_DOTPROD) return mixmat<32, 32, tinyBLAS_Q0_ARM, block_q4_0, block_q8_0, TC>(); #else return false; #endif case GGML_TYPE_Q8_0: if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) return mixmat<32, 32, tinyBLAS_Q0_AVX2, block_q8_0, block_q8_0, TC>(); #elif defined(__ARM_FEATURE_DOTPROD) return mixmat<32, 32, tinyBLAS_Q0_ARM, block_q8_0, block_q8_0, TC>(); #else return false; #endif default: return false; } } template bool mixmat() { if (cols % KN) return false; switch (params->type) { case GGML_TASK_TYPE_INIT: if (thought->type != ggml_type_trait::id) quantize_thought(ggml_type_trait::id); build_row_pointers(ggml_type_trait::id); return true; case GGML_TASK_TYPE_COMPUTE: assert(!(cols % BS)); assert(!(weights->nb[1] % sizeof(TA))); for (int expert = 0; expert < experts; ++expert) { BLAS tb{cols / BS, (const TA*)((const char*)weights->data + expert * weights->nb[2]), (long)(weights->nb[1] / sizeof(TA)), (const TB*)(rowptr_thought_ + expert * tokens * thinkers), 0, (TC*)(rowptr_result_ + expert * tokens * thinkers), 0, params->ith, params->nth}; tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE); } return true; default: return true; } } void build_row_pointers(ggml_type vec_dot_type) { for (int expert = params->ith; expert < experts; expert += params->nth) { long count = 0; for (long token = 0; token < tokens; ++token) for (int thinker = 0; thinker < thinkers; ++thinker) if (expert == *(const int32_t*)((const char*)plan->data + token * plan->nb[1] + thinker * plan->nb[0])) { long row = count++; long idx = expert * thinkers * tokens + row; rowptr_result_[idx] = (uintptr_t)((char*)result->data + token * result->nb[2] + thinker * result->nb[1]); if (thought->type == vec_dot_type) rowptr_thought_[idx] = (uintptr_t)((char*)thought->data + token * thought->nb[2] + thinker % tasks * thought->nb[1]); else rowptr_thought_[idx] = (uintptr_t)((char*)quantized_thought_ + token * tasks * ldq + thinker % tasks * ldq); } rowptr_count_[expert] = count; } } void quantize_thought(ggml_type vec_dot_type) { long chore = 0; for (long token = 0; token < tokens; ++token) for (int task = 0; task < tasks; ++task) if (chore++ % params->nth == params->ith) quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq, (const float*)((const char*)thought->data + token * thought->nb[2] + task * thought->nb[1]), vec_dot_type); } void quantize_row(void* dst, const float* src, ggml_type type) { assert((long)ggml_row_size(type, cols) <= ldq); switch (type) { case GGML_TYPE_F16: ggml_fp32_to_fp16_row(src, (ggml_fp16_t*)dst, cols); break; case GGML_TYPE_BF16: ggml_fp32_to_bf16_row(src, (ggml_bf16_t*)dst, cols); break; case GGML_TYPE_Q8_0: quantize_row_q8_0((const float*)src, (block_q8_0*)dst, cols); break; default: GGML_UNREACHABLE(); } } template T* allocate(size_t align, size_t elems) { T* res = nullptr; size_t need = sizeof(T) * elems; size_t base = allocated_; base += align - 1; base &= -align; size_t toto = base + need; if (toto >= allocated_ && toto <= params->wsize) { res = (T*)(wdata_ + base); allocated_ = toto; } return res; } const ggml_compute_params* const params; const ggml_tensor* const weights; const ggml_tensor* const thought; const ggml_tensor* const plan; ggml_tensor* const result; const long rows; const long cols; const int experts; const int thinkers; const int tasks; const long tokens; const long ldq; // variables char* const wdata_; size_t allocated_; // shared memory long* rowptr_count_ /*[experts]*/; char* quantized_thought_ /*[tokens][tasks][cols][2]*/; uintptr_t* rowptr_result_ /*[experts][tokens*thinkers]*/; uintptr_t* rowptr_thought_ /*[experts][tokens*thinkers]*/; }; } // namespace /** * Performs "mixture of experts" tensor multiplication on CPU. */ bool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) { MixMul mm{params, weights, thought, plan, result}; return mm.allocate_shared_memory() && mm.mixmul(); } ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_avx #include "tinyblas_cpu_mixmul.inc" /** * Returns number of shared memory bytes llamafile_mixmul() needs. */ size_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) { ggml_compute_params params{}; params.wsize = 0x7ffff000; params.wdata = (void*)0x1000; MixMul mm{¶ms, weights, thought, plan, 0}; if (mm.allocate_shared_memory()) return mm.get_allocated_bytes(); else return 0; } #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_avx2 #include "tinyblas_cpu_mixmul.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_avx512f #include "tinyblas_cpu_mixmul.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_avxvnni #include "tinyblas_cpu_mixmul.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_fma #include "tinyblas_cpu_mixmul.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_mixmul llamafile_mixmul_amd_zen4 #include "tinyblas_cpu_mixmul.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_arm80.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm80.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #ifdef __aarch64__ #define llamafile_mixmul llamafile_mixmul_arm80 #include "tinyblas_cpu_mixmul.inc" /** * Returns number of shared memory bytes llamafile_mixmul() needs. */ size_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) { ggml_compute_params params{}; params.wsize = 0x7ffff000; params.wdata = (void*)0x1000; MixMul mm{¶ms, weights, thought, plan, 0}; if (mm.allocate_shared_memory()) return mm.get_allocated_bytes(); else return 0; } #endif // __aarch64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_mixmul_arm82.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm82.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #ifdef __aarch64__ #define llamafile_mixmul llamafile_mixmul_arm82 #include "tinyblas_cpu_mixmul.inc" #endif // __aarch64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound with rounding // errors, which grow along with k. Our goal's to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. #if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU // use ARM version #include "tinyblas_cpu_sgemm_arm.inc" #else // use x86 version #include "tinyblas_cpu_sgemm_x86.inc" #endif ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_avx #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_avx2 #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_avx512f #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_avxvnni #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_fma #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #if defined(__x86_64__) || defined(_M_X64) #define llamafile_sgemm llamafile_sgemm_amd_zen4 #define iqk_mul_mat iqk_mul_mat_zen4 #include "tinyblas_cpu_sgemm.inc" #endif // __x86_64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_arm.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tinyblas_cpu.h" #include #include #include // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound with rounding // errors, which grow along with k. Our goal's to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. namespace { template void SgemmHelperN1Neon2(long m, long n, long k, const float16_t* A, long lda, const float16_t* B, long ldb, TC* C, long ldc, int ith, int nth) { // A m * k B n * k c n * m const long NVL = 8; long kk = k / (NVL * 4); kk = kk * (NVL * 4); long length = (m / nth) + (ith < (m % nth) ? 1 : 0); long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth)); long endRow = startRow + length; for (long i = startRow; i < endRow; i ++) { const float16_t* tA = A + i * lda; float32x4_t c0 = vdupq_n_f32(0); float32x4_t c1 = vdupq_n_f32(0); float32x4_t c2 = vdupq_n_f32(0); float32x4_t c3 = vdupq_n_f32(0); float32x4_t c4 = vdupq_n_f32(0); float32x4_t c5 = vdupq_n_f32(0); float32x4_t c6 = vdupq_n_f32(0); float32x4_t c7 = vdupq_n_f32(0); for (long j = 0; j < kk; j += NVL * 4) { __builtin_prefetch(tA + 192, 0, 0); float16x8_t a0 = vld1q_f16(tA + j); float16x8_t b0 = vld1q_f16(B + j); c0 = vfmlalq_low_f16(c0, a0, b0); c1 = vfmlalq_high_f16(c1, a0, b0); float16x8_t a1 = vld1q_f16(tA + j + NVL); float16x8_t b1 = vld1q_f16(B + j + NVL); c2 = vfmlalq_low_f16(c2, a1, b1); c3 = vfmlalq_high_f16(c3, a1, b1); float16x8_t a2 = vld1q_f16(tA + j + NVL * 2); float16x8_t b2 = vld1q_f16(B + j + NVL * 2); c4 = vfmlalq_low_f16(c4, a2, b2); c5 = vfmlalq_high_f16(c5, a2, b2); float16x8_t a3 = vld1q_f16(tA + j + NVL * 3); float16x8_t b3 = vld1q_f16(B + j + NVL * 3); c6 = vfmlalq_low_f16(c6, a3, b3); c7 = vfmlalq_high_f16(c7, a3, b3); } if (k - kk >= NVL * 2) { float16x8_t a0 = vld1q_f16(tA + kk); float16x8_t b0 = vld1q_f16(B + kk); c0 = vfmlalq_low_f16(c0, a0, b0); c1 = vfmlalq_high_f16(c1, a0, b0); float16x8_t a1 = vld1q_f16(tA + kk + NVL); float16x8_t b1 = vld1q_f16(B + kk + NVL); c2 = vfmlalq_low_f16(c2, a1, b1); c3 = vfmlalq_high_f16(c3, a1, b1); kk += NVL * 2; } if (k - kk >= NVL) { float16x8_t a = vld1q_f16(tA + kk); float16x8_t b = vld1q_f16(B + kk); c0 = vfmlalq_low_f16(c0, a, b); c1 = vfmlalq_high_f16(c1, a, b); kk += NVL; } TC sum = 0.0f; for (long j = kk; j < k; j ++) { sum += (float32_t)tA[j] * (float32_t)B[j]; } c0 = vaddq_f32(c0, c1); c2 = vaddq_f32(c2, c3); c4 = vaddq_f32(c4, c5); c6 = vaddq_f32(c6, c7); c0 = vaddq_f32(c0, c2); c4 = vaddq_f32(c4, c6); sum += vaddvq_f32(c0) + vaddvq_f32(c4); C[i] = sum; } return; } template void SgemmHelperN1(long m, long n, long k, const ggml_fp16_t* A_, long lda, const ggml_fp16_t* B_, long ldb, TC* C, long ldc, int ith, int nth) { // A m * k B n * k c n * m float16_t *A = (float16_t*)A_; float16_t *B = (float16_t*)B_; long rowsPerThread = m / nth; long startRow = ith * rowsPerThread; long endRow = (ith == nth - 1) ? m : startRow + rowsPerThread; for (long i = startRow; i < endRow; i ++) { TC sum = 0.0f; for (long j = 0; j < k; j ++) { sum += (float32_t)A[i * lda + j] * (float32_t)B[j]; } C[i] = sum; } return; } template bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { // std::cout << "tinyBLAS tinyBLAS NOT_SUPPORTED FP16 55, n: " << n << ", m: " << m << ", k: " << k << ", FLAG_precise: " << FLAG_precise << "\n"< tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) if (k % 4) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_BF16: { #if defined(__AVX512BF16__) if (k % 32) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_BF16) return NOT_SUPPORTED; if (!FLAG_precise) { tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } else { tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } #elif defined(__AVX512F__) if (k % 16) return NOT_SUPPORTED; tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX2__) if (k % 8) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (k % 4) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_F16: { #if defined(__AVX512F__) if (k % 16) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) // if (X86_CHECK(F16C)) { if (k % 8) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; // } else { // return NOT_SUPPORTED; // } #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) if (n < 2 && !FLAG_precise) { // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec? if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) { SgemmHelperN1Neon2(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth); // SgemmHelperN1(m, n, k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth); return true; } return NOT_SUPPORTED; } if (precision == GGML_PREC_F32) { if (k % 4) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } else { if (k % 8) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (n < 2 && !FLAG_precise) { // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec? // printf("tinyBLAS tinyBLAS NOT_SUPPORTED FP16 225, m: %ld, n: %ld, k: %ld\n", m, n, k); if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) { SgemmHelperN1Neon2(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth); // SgemmHelperN1(m, n, k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth); return true; } std::cout << "tinyBLAS tinyBLAS NOT_SUPPORTED FP16 231, n: " << n << ", m: " << m << ", k: " << m << ", FLAG_precise: " << FLAG_precise << "\n"< tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else // std::cout << "tinyBLAS tinyBLAS NOT_SUPPORTED FP16" < tb{ k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{ k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_Q4_0: { if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_Q8_0) return NOT_SUPPORTED; #if defined(__AVX2__) || defined(__AVX512F__) tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{ k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{ k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } default: return NOT_SUPPORTED; } (void)m; (void)n; (void)k; (void)A; (void)lda; (void)B; (void)ldb; (void)C; (void)ldc; (void)ith; (void)nth; (void)Atype; (void)Btype; (void)precision; } } // namespace /** * Performs optimized matrix multiplication on CPU. * * This subroutine may compute C = Aᵀ * B with column major ordering. * Despite its name, this isn't a generalized implementation. Work is * only performed when a handwritten kernel is written and available. * Otherwise the caller should fall back to a general matmul routine. * * For example, for single-threaded single-precision GEMM you can say * * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1, * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, * GGML_PREC_DEFAULT); * * @param m is rows in `A` and `C` * @param n is cols in `B` and `C` * @param k is cols in `A` and rows in `B` * @param A is first input matrix (always transposed) * @param lda is row stride of `A` * @param B is second input matrix (never transposed) * @param ldb is row stride of `B` * @param C is input/output array of output matrices * @param ldc is row stride of `C` * @param ith is thread id (must be less than `nth`) * @param nth is number of threads (must be greater than zero) * @param Atype is GGML data type of `A` * @param Btype is GGML data type of `B` * @param Ctype is GGML data type of `C` * @param precision may be used to control the internal compute type * @return true if this function was able to service the matmul request */ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { assert(m >= 0); assert(n >= 0); assert(k >= 0); assert(lda >= k); assert(ldb >= k); assert(ldc >= m); assert(nth > 0); assert(ith < nth); #if QK_K == 256 #if defined(__x86_64__) || defined(_M_X64) #if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))) // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) { if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32){ if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) { return true; } } if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) { // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32); assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32)); if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) { return true; } } // } #endif #elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) { if (iqk_mul_mat(m, n, k * QK_K, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) { return true; } } if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) { // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32); assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32)); if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) { return true; } } #endif #endif switch (Ctype) { case GGML_TYPE_F32: return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype, Btype, Ctype, precision); default: return NOT_SUPPORTED; } } ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_arm80.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm80.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #ifdef __aarch64__ #define llamafile_sgemm llamafile_sgemm_arm80 #include "tinyblas_cpu_sgemm.inc" #endif // __aarch64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_arm82.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm82.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. #ifdef __aarch64__ #define llamafile_sgemm llamafile_sgemm_arm82 #define iqk_mul_mat iqk_mul_mat_arm82 #include "tinyblas_cpu_sgemm.inc" #endif // __aarch64__ ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_sgemm_x86.inc ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "tinyblas_cpu.h" // // // ██████╗ ██╗ █████╗ ██████╗ // ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ // ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ // ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ // ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ // ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound with rounding // errors, which grow along with k. Our goal's to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. namespace { template bool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { switch (Atype) { case GGML_TYPE_F32: { if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; #if defined(__AVX512F__) if (k % 16) return NOT_SUPPORTED; tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) if (k % 4) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{ k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_BF16: { #if defined(__AVX512BF16__) if (k % 32) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_BF16) return NOT_SUPPORTED; if (!FLAG_precise) { tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } else { tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } #elif defined(__AVX512F__) if (k % 16) return NOT_SUPPORTED; tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX2__) if (k % 8) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (k % 4) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{ k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_F16: { #if defined(__AVX512F__) if (k % 16) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) // if (X86_CHECK(F16C)) { if (k % 8) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32 && n < 2) { tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; // } else { // return NOT_SUPPORTED; // } #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) if (n < 2 && !FLAG_precise) // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec? return NOT_SUPPORTED; if (precision == GGML_PREC_F32) { if (k % 4) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } else { if (k % 8) return NOT_SUPPORTED; if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) return NOT_SUPPORTED; tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; } #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (n < 2 && !FLAG_precise) // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec? return NOT_SUPPORTED; if (k % 4) return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) return NOT_SUPPORTED; tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{ k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_Q8_0: { if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_Q8_0) return NOT_SUPPORTED; #if defined(__AVX2__) || defined(__AVX512F__) tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{ k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{ k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } case GGML_TYPE_Q4_0: { if (Btype == GGML_TYPE_F32) return WANT_QUANTIZATION; if (Btype != GGML_TYPE_Q8_0) return NOT_SUPPORTED; #if defined(__AVX2__) || defined(__AVX512F__) tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{ k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{ k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else return NOT_SUPPORTED; #endif } default: return NOT_SUPPORTED; } (void)m; (void)n; (void)k; (void)A; (void)lda; (void)B; (void)ldb; (void)C; (void)ldc; (void)ith; (void)nth; (void)Atype; (void)Btype; (void)precision; } } // namespace /** * Performs optimized matrix multiplication on CPU. * * This subroutine may compute C = Aᵀ * B with column major ordering. * Despite its name, this isn't a generalized implementation. Work is * only performed when a handwritten kernel is written and available. * Otherwise the caller should fall back to a general matmul routine. * * For example, for single-threaded single-precision GEMM you can say * * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1, * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, * GGML_PREC_DEFAULT); * * @param m is rows in `A` and `C` * @param n is cols in `B` and `C` * @param k is cols in `A` and rows in `B` * @param A is first input matrix (always transposed) * @param lda is row stride of `A` * @param B is second input matrix (never transposed) * @param ldb is row stride of `B` * @param C is input/output array of output matrices * @param ldc is row stride of `C` * @param ith is thread id (must be less than `nth`) * @param nth is number of threads (must be greater than zero) * @param Atype is GGML data type of `A` * @param Btype is GGML data type of `B` * @param Ctype is GGML data type of `C` * @param precision may be used to control the internal compute type * @return true if this function was able to service the matmul request */ bool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { assert(m >= 0); assert(n >= 0); assert(k >= 0); assert(lda >= k); assert(ldb >= k); assert(ldc >= m); assert(nth > 0); assert(ith < nth); #if QK_K == 256 #if defined(__x86_64__) || defined(_M_X64) #if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))) /* moonll more Btype accept }*/ if (Ctype == GGML_TYPE_F32){ if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) { return true; } } #endif #elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) { if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) { return true; } } if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) { // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32); assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32)); if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) { return true; } } #endif #endif switch (Ctype) { case GGML_TYPE_F32: return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype, Btype, Ctype, precision); default: return NOT_SUPPORTED; } } ================================================ FILE: archive/third_party/llamafile/tinyblas_cpu_unsupported.cpp ================================================ // Adapted from // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_unsupported.cpp // Copyrigth 2024 Mozilla Foundation. // Copyright(c) 2024 by KVCache.AI, All Rights Reserved. // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- // vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi // // Copyright 2024 Mozilla Foundation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "sgemm.h" bool llamafile_sgemm_unsupported(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) { return false; } bool llamafile_mixmul_unsupported(const struct ggml_compute_params* params, const struct ggml_tensor* weights, const struct ggml_tensor* thought, const struct ggml_tensor* plan, struct ggml_tensor* result) { return false; } bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int) { return false; } ================================================ FILE: archive/third_party/nlohmann/json.hpp ================================================ // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT /****************************************************************************\ * Note on documentation: The source files contain links to the online * * documentation of the public API at https://json.nlohmann.me. This URL * * contains the most recent documentation and should also be applicable to * * previous versions; documentation for deprecated functions is not * * removed, but marked deprecated. See "Generate documentation" section in * * file docs/README.md. * \****************************************************************************/ #ifndef INCLUDE_NLOHMANN_JSON_HPP_ #define INCLUDE_NLOHMANN_JSON_HPP_ #include // all_of, find, for_each #include // nullptr_t, ptrdiff_t, size_t #include // hash, less #include // initializer_list #ifndef JSON_NO_IO #include // istream, ostream #endif // JSON_NO_IO #include // random_access_iterator_tag #include // unique_ptr #include // string, stoi, to_string #include // declval, forward, move, pair, swap #include // vector // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT // This file contains all macro definitions affecting or depending on the ABI #ifndef JSON_SKIP_LIBRARY_VERSION_CHECK #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3 #warning "Already included a different version of the library!" #endif #endif #endif #define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) #define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) #define NLOHMANN_JSON_VERSION_PATCH 3 // NOLINT(modernize-macro-to-enum) #ifndef JSON_DIAGNOSTICS #define JSON_DIAGNOSTICS 0 #endif #ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 #endif #if JSON_DIAGNOSTICS #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag #else #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS #endif #if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp #else #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON #endif #ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 #endif // Construct the namespace ABI tags component #define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b #define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \ NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) #define NLOHMANN_JSON_ABI_TAGS \ NLOHMANN_JSON_ABI_TAGS_CONCAT( \ NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON) // Construct the namespace version component #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \ _v ## major ## _ ## minor ## _ ## patch #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) #if NLOHMANN_JSON_NAMESPACE_NO_VERSION #define NLOHMANN_JSON_NAMESPACE_VERSION #else #define NLOHMANN_JSON_NAMESPACE_VERSION \ NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \ NLOHMANN_JSON_VERSION_MINOR, \ NLOHMANN_JSON_VERSION_PATCH) #endif // Combine namespace components #define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b #define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \ NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) #ifndef NLOHMANN_JSON_NAMESPACE #define NLOHMANN_JSON_NAMESPACE \ nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \ NLOHMANN_JSON_ABI_TAGS, \ NLOHMANN_JSON_NAMESPACE_VERSION) #endif #ifndef NLOHMANN_JSON_NAMESPACE_BEGIN #define NLOHMANN_JSON_NAMESPACE_BEGIN \ namespace nlohmann \ { \ inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \ NLOHMANN_JSON_ABI_TAGS, \ NLOHMANN_JSON_NAMESPACE_VERSION) \ { #endif #ifndef NLOHMANN_JSON_NAMESPACE_END #define NLOHMANN_JSON_NAMESPACE_END \ } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ } // namespace nlohmann #endif // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // transform #include // array #include // forward_list #include // inserter, front_inserter, end #include // map #include // string #include // tuple, make_tuple #include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible #include // unordered_map #include // pair, declval #include // valarray // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // nullptr_t #include // exception #if JSON_DIAGNOSTICS #include // accumulate #endif #include // runtime_error #include // to_string #include // vector // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // array #include // size_t #include // uint8_t #include // string // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // declval, pair // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT // #include NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { template struct make_void { using type = void; }; template using void_t = typename make_void::type; } // namespace detail NLOHMANN_JSON_NAMESPACE_END NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { // https://en.cppreference.com/w/cpp/experimental/is_detected struct nonesuch { nonesuch() = delete; ~nonesuch() = delete; nonesuch(nonesuch const&) = delete; nonesuch(nonesuch const&&) = delete; void operator=(nonesuch const&) = delete; void operator=(nonesuch&&) = delete; }; template class Op, class... Args> struct detector { using value_t = std::false_type; using type = Default; }; template class Op, class... Args> struct detector>, Op, Args...> { using value_t = std::true_type; using type = Op; }; template class Op, class... Args> using is_detected = typename detector::value_t; template class Op, class... Args> struct is_detected_lazy : is_detected { }; template class Op, class... Args> using detected_t = typename detector::type; template class Op, class... Args> using detected_or = detector; template class Op, class... Args> using detected_or_t = typename detected_or::type; template class Op, class... Args> using is_detected_exact = std::is_same>; template class Op, class... Args> using is_detected_convertible = std::is_convertible, To>; } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-FileCopyrightText: 2016-2021 Evan Nemerson // SPDX-License-Identifier: MIT /* Hedley - https://nemequ.github.io/hedley * Created by Evan Nemerson */ #if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15) #if defined(JSON_HEDLEY_VERSION) #undef JSON_HEDLEY_VERSION #endif #define JSON_HEDLEY_VERSION 15 #if defined(JSON_HEDLEY_STRINGIFY_EX) #undef JSON_HEDLEY_STRINGIFY_EX #endif #define JSON_HEDLEY_STRINGIFY_EX(x) #x #if defined(JSON_HEDLEY_STRINGIFY) #undef JSON_HEDLEY_STRINGIFY #endif #define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) #if defined(JSON_HEDLEY_CONCAT_EX) #undef JSON_HEDLEY_CONCAT_EX #endif #define JSON_HEDLEY_CONCAT_EX(a,b) a##b #if defined(JSON_HEDLEY_CONCAT) #undef JSON_HEDLEY_CONCAT #endif #define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) #if defined(JSON_HEDLEY_CONCAT3_EX) #undef JSON_HEDLEY_CONCAT3_EX #endif #define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c #if defined(JSON_HEDLEY_CONCAT3) #undef JSON_HEDLEY_CONCAT3 #endif #define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) #if defined(JSON_HEDLEY_VERSION_ENCODE) #undef JSON_HEDLEY_VERSION_ENCODE #endif #define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) #if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) #undef JSON_HEDLEY_VERSION_DECODE_MAJOR #endif #define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) #if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) #undef JSON_HEDLEY_VERSION_DECODE_MINOR #endif #define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) #if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) #undef JSON_HEDLEY_VERSION_DECODE_REVISION #endif #define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) #if defined(JSON_HEDLEY_GNUC_VERSION) #undef JSON_HEDLEY_GNUC_VERSION #endif #if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) #elif defined(__GNUC__) #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) #endif #if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) #undef JSON_HEDLEY_GNUC_VERSION_CHECK #endif #if defined(JSON_HEDLEY_GNUC_VERSION) #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_MSVC_VERSION) #undef JSON_HEDLEY_MSVC_VERSION #endif #if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL) #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) #elif defined(_MSC_FULL_VER) && !defined(__ICL) #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) #elif defined(_MSC_VER) && !defined(__ICL) #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) #endif #if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) #undef JSON_HEDLEY_MSVC_VERSION_CHECK #endif #if !defined(JSON_HEDLEY_MSVC_VERSION) #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) #elif defined(_MSC_VER) && (_MSC_VER >= 1400) #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) #elif defined(_MSC_VER) && (_MSC_VER >= 1200) #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) #else #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) #endif #if defined(JSON_HEDLEY_INTEL_VERSION) #undef JSON_HEDLEY_INTEL_VERSION #endif #if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL) #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) #elif defined(__INTEL_COMPILER) && !defined(__ICL) #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) #endif #if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) #undef JSON_HEDLEY_INTEL_VERSION_CHECK #endif #if defined(JSON_HEDLEY_INTEL_VERSION) #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_INTEL_CL_VERSION) #undef JSON_HEDLEY_INTEL_CL_VERSION #endif #if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL) #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0) #endif #if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK) #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK #endif #if defined(JSON_HEDLEY_INTEL_CL_VERSION) #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_PGI_VERSION) #undef JSON_HEDLEY_PGI_VERSION #endif #if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) #endif #if defined(JSON_HEDLEY_PGI_VERSION_CHECK) #undef JSON_HEDLEY_PGI_VERSION_CHECK #endif #if defined(JSON_HEDLEY_PGI_VERSION) #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_SUNPRO_VERSION) #undef JSON_HEDLEY_SUNPRO_VERSION #endif #if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) #elif defined(__SUNPRO_C) #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) #elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) #elif defined(__SUNPRO_CC) #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) #endif #if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK #endif #if defined(JSON_HEDLEY_SUNPRO_VERSION) #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) #undef JSON_HEDLEY_EMSCRIPTEN_VERSION #endif #if defined(__EMSCRIPTEN__) #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) #endif #if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK #endif #if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_ARM_VERSION) #undef JSON_HEDLEY_ARM_VERSION #endif #if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) #elif defined(__CC_ARM) && defined(__ARMCC_VERSION) #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) #endif #if defined(JSON_HEDLEY_ARM_VERSION_CHECK) #undef JSON_HEDLEY_ARM_VERSION_CHECK #endif #if defined(JSON_HEDLEY_ARM_VERSION) #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_IBM_VERSION) #undef JSON_HEDLEY_IBM_VERSION #endif #if defined(__ibmxl__) #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) #elif defined(__xlC__) && defined(__xlC_ver__) #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) #elif defined(__xlC__) #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) #endif #if defined(JSON_HEDLEY_IBM_VERSION_CHECK) #undef JSON_HEDLEY_IBM_VERSION_CHECK #endif #if defined(JSON_HEDLEY_IBM_VERSION) #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_VERSION) #undef JSON_HEDLEY_TI_VERSION #endif #if \ defined(__TI_COMPILER_VERSION__) && \ ( \ defined(__TMS470__) || defined(__TI_ARM__) || \ defined(__MSP430__) || \ defined(__TMS320C2000__) \ ) #if (__TI_COMPILER_VERSION__ >= 16000000) #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #endif #if defined(JSON_HEDLEY_TI_VERSION_CHECK) #undef JSON_HEDLEY_TI_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_VERSION) #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_CL2000_VERSION) #undef JSON_HEDLEY_TI_CL2000_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_CL2000_VERSION) #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_CL430_VERSION) #undef JSON_HEDLEY_TI_CL430_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_CL430_VERSION) #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_ARMCL_VERSION) #undef JSON_HEDLEY_TI_ARMCL_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_ARMCL_VERSION) #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_CL6X_VERSION) #undef JSON_HEDLEY_TI_CL6X_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_CL6X_VERSION) #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_CL7X_VERSION) #undef JSON_HEDLEY_TI_CL7X_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_CL7X_VERSION) #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TI_CLPRU_VERSION) #undef JSON_HEDLEY_TI_CLPRU_VERSION #endif #if defined(__TI_COMPILER_VERSION__) && defined(__PRU__) #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) #endif #if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TI_CLPRU_VERSION) #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_CRAY_VERSION) #undef JSON_HEDLEY_CRAY_VERSION #endif #if defined(_CRAYC) #if defined(_RELEASE_PATCHLEVEL) #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) #else #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) #endif #endif #if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) #undef JSON_HEDLEY_CRAY_VERSION_CHECK #endif #if defined(JSON_HEDLEY_CRAY_VERSION) #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_IAR_VERSION) #undef JSON_HEDLEY_IAR_VERSION #endif #if defined(__IAR_SYSTEMS_ICC__) #if __VER__ > 1000 #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) #else #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0) #endif #endif #if defined(JSON_HEDLEY_IAR_VERSION_CHECK) #undef JSON_HEDLEY_IAR_VERSION_CHECK #endif #if defined(JSON_HEDLEY_IAR_VERSION) #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_TINYC_VERSION) #undef JSON_HEDLEY_TINYC_VERSION #endif #if defined(__TINYC__) #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) #endif #if defined(JSON_HEDLEY_TINYC_VERSION_CHECK) #undef JSON_HEDLEY_TINYC_VERSION_CHECK #endif #if defined(JSON_HEDLEY_TINYC_VERSION) #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_DMC_VERSION) #undef JSON_HEDLEY_DMC_VERSION #endif #if defined(__DMC__) #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) #endif #if defined(JSON_HEDLEY_DMC_VERSION_CHECK) #undef JSON_HEDLEY_DMC_VERSION_CHECK #endif #if defined(JSON_HEDLEY_DMC_VERSION) #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_COMPCERT_VERSION) #undef JSON_HEDLEY_COMPCERT_VERSION #endif #if defined(__COMPCERT_VERSION__) #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) #endif #if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK #endif #if defined(JSON_HEDLEY_COMPCERT_VERSION) #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_PELLES_VERSION) #undef JSON_HEDLEY_PELLES_VERSION #endif #if defined(__POCC__) #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) #endif #if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) #undef JSON_HEDLEY_PELLES_VERSION_CHECK #endif #if defined(JSON_HEDLEY_PELLES_VERSION) #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_MCST_LCC_VERSION) #undef JSON_HEDLEY_MCST_LCC_VERSION #endif #if defined(__LCC__) && defined(__LCC_MINOR__) #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__) #endif #if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK) #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK #endif #if defined(JSON_HEDLEY_MCST_LCC_VERSION) #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_GCC_VERSION) #undef JSON_HEDLEY_GCC_VERSION #endif #if \ defined(JSON_HEDLEY_GNUC_VERSION) && \ !defined(__clang__) && \ !defined(JSON_HEDLEY_INTEL_VERSION) && \ !defined(JSON_HEDLEY_PGI_VERSION) && \ !defined(JSON_HEDLEY_ARM_VERSION) && \ !defined(JSON_HEDLEY_CRAY_VERSION) && \ !defined(JSON_HEDLEY_TI_VERSION) && \ !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ !defined(__COMPCERT__) && \ !defined(JSON_HEDLEY_MCST_LCC_VERSION) #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION #endif #if defined(JSON_HEDLEY_GCC_VERSION_CHECK) #undef JSON_HEDLEY_GCC_VERSION_CHECK #endif #if defined(JSON_HEDLEY_GCC_VERSION) #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) #else #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0) #endif #if defined(JSON_HEDLEY_HAS_ATTRIBUTE) #undef JSON_HEDLEY_HAS_ATTRIBUTE #endif #if \ defined(__has_attribute) && \ ( \ (!defined(JSON_HEDLEY_IAR_VERSION) || JSON_HEDLEY_IAR_VERSION_CHECK(8,5,9)) \ ) # define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute) #else # define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE) #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE #endif #if defined(__has_attribute) #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) #else #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE) #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE #endif #if defined(__has_attribute) #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) #else #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE) #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE #endif #if \ defined(__has_cpp_attribute) && \ defined(__cplusplus) && \ (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute) #else #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0) #endif #if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS) #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS #endif #if !defined(__cplusplus) || !defined(__has_cpp_attribute) #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) #elif \ !defined(JSON_HEDLEY_PGI_VERSION) && \ !defined(JSON_HEDLEY_IAR_VERSION) && \ (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \ (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0)) #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute) #else #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE) #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE #endif #if defined(__has_cpp_attribute) && defined(__cplusplus) #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) #else #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE) #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE #endif #if defined(__has_cpp_attribute) && defined(__cplusplus) #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) #else #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_BUILTIN) #undef JSON_HEDLEY_HAS_BUILTIN #endif #if defined(__has_builtin) #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin) #else #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN) #undef JSON_HEDLEY_GNUC_HAS_BUILTIN #endif #if defined(__has_builtin) #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) #else #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_BUILTIN) #undef JSON_HEDLEY_GCC_HAS_BUILTIN #endif #if defined(__has_builtin) #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) #else #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_FEATURE) #undef JSON_HEDLEY_HAS_FEATURE #endif #if defined(__has_feature) #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature) #else #define JSON_HEDLEY_HAS_FEATURE(feature) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_FEATURE) #undef JSON_HEDLEY_GNUC_HAS_FEATURE #endif #if defined(__has_feature) #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) #else #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_FEATURE) #undef JSON_HEDLEY_GCC_HAS_FEATURE #endif #if defined(__has_feature) #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) #else #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_EXTENSION) #undef JSON_HEDLEY_HAS_EXTENSION #endif #if defined(__has_extension) #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension) #else #define JSON_HEDLEY_HAS_EXTENSION(extension) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION) #undef JSON_HEDLEY_GNUC_HAS_EXTENSION #endif #if defined(__has_extension) #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) #else #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_EXTENSION) #undef JSON_HEDLEY_GCC_HAS_EXTENSION #endif #if defined(__has_extension) #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) #else #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE) #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE #endif #if defined(__has_declspec_attribute) #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute) #else #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE) #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE #endif #if defined(__has_declspec_attribute) #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) #else #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE) #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE #endif #if defined(__has_declspec_attribute) #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) #else #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_HAS_WARNING) #undef JSON_HEDLEY_HAS_WARNING #endif #if defined(__has_warning) #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning) #else #define JSON_HEDLEY_HAS_WARNING(warning) (0) #endif #if defined(JSON_HEDLEY_GNUC_HAS_WARNING) #undef JSON_HEDLEY_GNUC_HAS_WARNING #endif #if defined(__has_warning) #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) #else #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_GCC_HAS_WARNING) #undef JSON_HEDLEY_GCC_HAS_WARNING #endif #if defined(__has_warning) #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) #else #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if \ (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ defined(__clang__) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \ JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) #define JSON_HEDLEY_PRAGMA(value) __pragma(value) #else #define JSON_HEDLEY_PRAGMA(value) #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH) #undef JSON_HEDLEY_DIAGNOSTIC_PUSH #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_POP) #undef JSON_HEDLEY_DIAGNOSTIC_POP #endif #if defined(__clang__) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) #elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") #elif \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") #elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") #else #define JSON_HEDLEY_DIAGNOSTIC_PUSH #define JSON_HEDLEY_DIAGNOSTIC_POP #endif /* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ #endif #if defined(__cplusplus) # if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat") # if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") # if JSON_HEDLEY_HAS_WARNING("-Wc++1z-extensions") # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ _Pragma("clang diagnostic ignored \"-Wc++1z-extensions\"") \ xpr \ JSON_HEDLEY_DIAGNOSTIC_POP # else # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ xpr \ JSON_HEDLEY_DIAGNOSTIC_POP # endif # else # define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ xpr \ JSON_HEDLEY_DIAGNOSTIC_POP # endif # endif #endif #if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x #endif #if defined(JSON_HEDLEY_CONST_CAST) #undef JSON_HEDLEY_CONST_CAST #endif #if defined(__cplusplus) # define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast(expr)) #elif \ JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) # define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ ((T) (expr)); \ JSON_HEDLEY_DIAGNOSTIC_POP \ })) #else # define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) #endif #if defined(JSON_HEDLEY_REINTERPRET_CAST) #undef JSON_HEDLEY_REINTERPRET_CAST #endif #if defined(__cplusplus) #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast(expr)) #else #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr)) #endif #if defined(JSON_HEDLEY_STATIC_CAST) #undef JSON_HEDLEY_STATIC_CAST #endif #if defined(__cplusplus) #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast(expr)) #else #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr)) #endif #if defined(JSON_HEDLEY_CPP_CAST) #undef JSON_HEDLEY_CPP_CAST #endif #if defined(__cplusplus) # if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") # define JSON_HEDLEY_CPP_CAST(T, expr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ ((T) (expr)) \ JSON_HEDLEY_DIAGNOSTIC_POP # elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) # define JSON_HEDLEY_CPP_CAST(T, expr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("diag_suppress=Pe137") \ JSON_HEDLEY_DIAGNOSTIC_POP # else # define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) # endif #else # define JSON_HEDLEY_CPP_CAST(T, expr) (expr) #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #endif #if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786)) #elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445") #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") #elif \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") #elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161)) #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") #elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) #elif \ JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") #elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161") #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") #elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") #elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") #elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292)) #elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) #elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098") #elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") #elif \ JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #endif #if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") #elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") #elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL #endif #if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION) #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #endif #if JSON_HEDLEY_HAS_WARNING("-Wunused-function") #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"") #elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"") #elif JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505)) #elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142") #else #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION #endif #if defined(JSON_HEDLEY_DEPRECATED) #undef JSON_HEDLEY_DEPRECATED #endif #if defined(JSON_HEDLEY_DEPRECATED_FOR) #undef JSON_HEDLEY_DEPRECATED_FOR #endif #if \ JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since)) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement)) #elif \ (JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since))) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement))) #elif defined(__cplusplus) && (__cplusplus >= 201402L) #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]]) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]]) #elif \ JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__)) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__)) #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated) #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated") #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma("deprecated") #else #define JSON_HEDLEY_DEPRECATED(since) #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) #endif #if defined(JSON_HEDLEY_UNAVAILABLE) #undef JSON_HEDLEY_UNAVAILABLE #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since))) #else #define JSON_HEDLEY_UNAVAILABLE(available_since) #endif #if defined(JSON_HEDLEY_WARN_UNUSED_RESULT) #undef JSON_HEDLEY_WARN_UNUSED_RESULT #endif #if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG) #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__)) #elif (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L) #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]]) #elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) #elif defined(_Check_return_) /* SAL */ #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_ #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_ #else #define JSON_HEDLEY_WARN_UNUSED_RESULT #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) #endif #if defined(JSON_HEDLEY_SENTINEL) #undef JSON_HEDLEY_SENTINEL #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position))) #else #define JSON_HEDLEY_SENTINEL(position) #endif #if defined(JSON_HEDLEY_NO_RETURN) #undef JSON_HEDLEY_NO_RETURN #endif #if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_NO_RETURN __noreturn #elif \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) #elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L #define JSON_HEDLEY_NO_RETURN _Noreturn #elif defined(__cplusplus) && (__cplusplus >= 201103L) #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]]) #elif \ JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return") #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) #elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;") #elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) #define JSON_HEDLEY_NO_RETURN __attribute((noreturn)) #elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) #else #define JSON_HEDLEY_NO_RETURN #endif #if defined(JSON_HEDLEY_NO_ESCAPE) #undef JSON_HEDLEY_NO_ESCAPE #endif #if JSON_HEDLEY_HAS_ATTRIBUTE(noescape) #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__)) #else #define JSON_HEDLEY_NO_ESCAPE #endif #if defined(JSON_HEDLEY_UNREACHABLE) #undef JSON_HEDLEY_UNREACHABLE #endif #if defined(JSON_HEDLEY_UNREACHABLE_RETURN) #undef JSON_HEDLEY_UNREACHABLE_RETURN #endif #if defined(JSON_HEDLEY_ASSUME) #undef JSON_HEDLEY_ASSUME #endif #if \ JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_ASSUME(expr) __assume(expr) #elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume) #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr) #elif \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) #if defined(__cplusplus) #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr) #else #define JSON_HEDLEY_ASSUME(expr) _nassert(expr) #endif #endif #if \ (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) || \ JSON_HEDLEY_CRAY_VERSION_CHECK(10,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable() #elif defined(JSON_HEDLEY_ASSUME) #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) #endif #if !defined(JSON_HEDLEY_ASSUME) #if defined(JSON_HEDLEY_UNREACHABLE) #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1))) #else #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr) #endif #endif #if defined(JSON_HEDLEY_UNREACHABLE) #if \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value)) #else #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE() #endif #else #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value) #endif #if !defined(JSON_HEDLEY_UNREACHABLE) #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) #endif JSON_HEDLEY_DIAGNOSTIC_PUSH #if JSON_HEDLEY_HAS_WARNING("-Wpedantic") #pragma clang diagnostic ignored "-Wpedantic" #endif #if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) #pragma clang diagnostic ignored "-Wc++98-compat-pedantic" #endif #if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) #if defined(__clang__) #pragma clang diagnostic ignored "-Wvariadic-macros" #elif defined(JSON_HEDLEY_GCC_VERSION) #pragma GCC diagnostic ignored "-Wvariadic-macros" #endif #endif #if defined(JSON_HEDLEY_NON_NULL) #undef JSON_HEDLEY_NON_NULL #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__))) #else #define JSON_HEDLEY_NON_NULL(...) #endif JSON_HEDLEY_DIAGNOSTIC_POP #if defined(JSON_HEDLEY_PRINTF_FORMAT) #undef JSON_HEDLEY_PRINTF_FORMAT #endif #if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO) #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check))) #elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO) #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check))) #elif \ JSON_HEDLEY_HAS_ATTRIBUTE(format) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check))) #elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0) #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check)) #else #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) #endif #if defined(JSON_HEDLEY_CONSTEXPR) #undef JSON_HEDLEY_CONSTEXPR #endif #if defined(__cplusplus) #if __cplusplus >= 201103L #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr) #endif #endif #if !defined(JSON_HEDLEY_CONSTEXPR) #define JSON_HEDLEY_CONSTEXPR #endif #if defined(JSON_HEDLEY_PREDICT) #undef JSON_HEDLEY_PREDICT #endif #if defined(JSON_HEDLEY_LIKELY) #undef JSON_HEDLEY_LIKELY #endif #if defined(JSON_HEDLEY_UNLIKELY) #undef JSON_HEDLEY_UNLIKELY #endif #if defined(JSON_HEDLEY_UNPREDICTABLE) #undef JSON_HEDLEY_UNPREDICTABLE #endif #if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable) #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr)) #endif #if \ (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) && !defined(JSON_HEDLEY_PGI_VERSION)) || \ JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) # define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability( (expr), (value), (probability)) # define JSON_HEDLEY_PREDICT_TRUE(expr, probability) __builtin_expect_with_probability(!!(expr), 1 , (probability)) # define JSON_HEDLEY_PREDICT_FALSE(expr, probability) __builtin_expect_with_probability(!!(expr), 0 , (probability)) # define JSON_HEDLEY_LIKELY(expr) __builtin_expect (!!(expr), 1 ) # define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect (!!(expr), 0 ) #elif \ (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \ JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) # define JSON_HEDLEY_PREDICT(expr, expected, probability) \ (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))) # define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \ (__extension__ ({ \ double hedley_probability_ = (probability); \ ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \ })) # define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \ (__extension__ ({ \ double hedley_probability_ = (probability); \ ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 1) : !!(expr))); \ })) # define JSON_HEDLEY_LIKELY(expr) __builtin_expect(!!(expr), 1) # define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) #else # define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)) # define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr)) # define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr)) # define JSON_HEDLEY_LIKELY(expr) (!!(expr)) # define JSON_HEDLEY_UNLIKELY(expr) (!!(expr)) #endif #if !defined(JSON_HEDLEY_UNPREDICTABLE) #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5) #endif #if defined(JSON_HEDLEY_MALLOC) #undef JSON_HEDLEY_MALLOC #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_MALLOC __attribute__((__malloc__)) #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory") #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_MALLOC __declspec(restrict) #else #define JSON_HEDLEY_MALLOC #endif #if defined(JSON_HEDLEY_PURE) #undef JSON_HEDLEY_PURE #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \ JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) # define JSON_HEDLEY_PURE __attribute__((__pure__)) #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) # define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data") #elif defined(__cplusplus) && \ ( \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \ ) # define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;") #else # define JSON_HEDLEY_PURE #endif #if defined(JSON_HEDLEY_CONST) #undef JSON_HEDLEY_CONST #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(const) || \ JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_CONST __attribute__((__const__)) #elif \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) #define JSON_HEDLEY_CONST _Pragma("no_side_effect") #else #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE #endif #if defined(JSON_HEDLEY_RESTRICT) #undef JSON_HEDLEY_RESTRICT #endif #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus) #define JSON_HEDLEY_RESTRICT restrict #elif \ JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ defined(__clang__) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_RESTRICT __restrict #elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus) #define JSON_HEDLEY_RESTRICT _Restrict #else #define JSON_HEDLEY_RESTRICT #endif #if defined(JSON_HEDLEY_INLINE) #undef JSON_HEDLEY_INLINE #endif #if \ (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ (defined(__cplusplus) && (__cplusplus >= 199711L)) #define JSON_HEDLEY_INLINE inline #elif \ defined(JSON_HEDLEY_GCC_VERSION) || \ JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0) #define JSON_HEDLEY_INLINE __inline__ #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_INLINE __inline #else #define JSON_HEDLEY_INLINE #endif #if defined(JSON_HEDLEY_ALWAYS_INLINE) #undef JSON_HEDLEY_ALWAYS_INLINE #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) # define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) # define JSON_HEDLEY_ALWAYS_INLINE __forceinline #elif defined(__cplusplus) && \ ( \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \ ) # define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;") #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) # define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced") #else # define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE #endif #if defined(JSON_HEDLEY_NEVER_INLINE) #undef JSON_HEDLEY_NEVER_INLINE #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__)) #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) #elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0) #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline") #elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;") #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never") #elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline)) #elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) #else #define JSON_HEDLEY_NEVER_INLINE #endif #if defined(JSON_HEDLEY_PRIVATE) #undef JSON_HEDLEY_PRIVATE #endif #if defined(JSON_HEDLEY_PUBLIC) #undef JSON_HEDLEY_PUBLIC #endif #if defined(JSON_HEDLEY_IMPORT) #undef JSON_HEDLEY_IMPORT #endif #if defined(_WIN32) || defined(__CYGWIN__) # define JSON_HEDLEY_PRIVATE # define JSON_HEDLEY_PUBLIC __declspec(dllexport) # define JSON_HEDLEY_IMPORT __declspec(dllimport) #else # if \ JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ ( \ defined(__TI_EABI__) && \ ( \ (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \ ) \ ) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) # define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden"))) # define JSON_HEDLEY_PUBLIC __attribute__((__visibility__("default"))) # else # define JSON_HEDLEY_PRIVATE # define JSON_HEDLEY_PUBLIC # endif # define JSON_HEDLEY_IMPORT extern #endif #if defined(JSON_HEDLEY_NO_THROW) #undef JSON_HEDLEY_NO_THROW #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__)) #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) #define JSON_HEDLEY_NO_THROW __declspec(nothrow) #else #define JSON_HEDLEY_NO_THROW #endif #if defined(JSON_HEDLEY_FALL_THROUGH) #undef JSON_HEDLEY_FALL_THROUGH #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \ JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__)) #elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough) #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]]) #elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough) #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]]) #elif defined(__fallthrough) /* SAL */ #define JSON_HEDLEY_FALL_THROUGH __fallthrough #else #define JSON_HEDLEY_FALL_THROUGH #endif #if defined(JSON_HEDLEY_RETURNS_NON_NULL) #undef JSON_HEDLEY_RETURNS_NON_NULL #endif #if \ JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__)) #elif defined(_Ret_notnull_) /* SAL */ #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_ #else #define JSON_HEDLEY_RETURNS_NON_NULL #endif #if defined(JSON_HEDLEY_ARRAY_PARAM) #undef JSON_HEDLEY_ARRAY_PARAM #endif #if \ defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \ !defined(__STDC_NO_VLA__) && \ !defined(__cplusplus) && \ !defined(JSON_HEDLEY_PGI_VERSION) && \ !defined(JSON_HEDLEY_TINYC_VERSION) #define JSON_HEDLEY_ARRAY_PARAM(name) (name) #else #define JSON_HEDLEY_ARRAY_PARAM(name) #endif #if defined(JSON_HEDLEY_IS_CONSTANT) #undef JSON_HEDLEY_IS_CONSTANT #endif #if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR) #undef JSON_HEDLEY_REQUIRE_CONSTEXPR #endif /* JSON_HEDLEY_IS_CONSTEXPR_ is for HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ #if defined(JSON_HEDLEY_IS_CONSTEXPR_) #undef JSON_HEDLEY_IS_CONSTEXPR_ #endif #if \ JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \ JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \ JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr) #endif #if !defined(__cplusplus) # if \ JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \ JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24) #if defined(__INTPTR_TYPE__) #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*) #else #include #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*) #endif # elif \ ( \ defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \ !defined(JSON_HEDLEY_SUNPRO_VERSION) && \ !defined(JSON_HEDLEY_PGI_VERSION) && \ !defined(JSON_HEDLEY_IAR_VERSION)) || \ (JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \ JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0) #if defined(__INTPTR_TYPE__) #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0) #else #include #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0) #endif # elif \ defined(JSON_HEDLEY_GCC_VERSION) || \ defined(JSON_HEDLEY_INTEL_VERSION) || \ defined(JSON_HEDLEY_TINYC_VERSION) || \ defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \ JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \ defined(JSON_HEDLEY_TI_CL2000_VERSION) || \ defined(JSON_HEDLEY_TI_CL6X_VERSION) || \ defined(JSON_HEDLEY_TI_CL7X_VERSION) || \ defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \ defined(__clang__) # define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \ sizeof(void) != \ sizeof(*( \ 1 ? \ ((void*) ((expr) * 0L) ) : \ ((struct { char v[sizeof(void) * 2]; } *) 1) \ ) \ ) \ ) # endif #endif #if defined(JSON_HEDLEY_IS_CONSTEXPR_) #if !defined(JSON_HEDLEY_IS_CONSTANT) #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr) #endif #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? (expr) : (-1)) #else #if !defined(JSON_HEDLEY_IS_CONSTANT) #define JSON_HEDLEY_IS_CONSTANT(expr) (0) #endif #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr) #endif #if defined(JSON_HEDLEY_BEGIN_C_DECLS) #undef JSON_HEDLEY_BEGIN_C_DECLS #endif #if defined(JSON_HEDLEY_END_C_DECLS) #undef JSON_HEDLEY_END_C_DECLS #endif #if defined(JSON_HEDLEY_C_DECL) #undef JSON_HEDLEY_C_DECL #endif #if defined(__cplusplus) #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" { #define JSON_HEDLEY_END_C_DECLS } #define JSON_HEDLEY_C_DECL extern "C" #else #define JSON_HEDLEY_BEGIN_C_DECLS #define JSON_HEDLEY_END_C_DECLS #define JSON_HEDLEY_C_DECL #endif #if defined(JSON_HEDLEY_STATIC_ASSERT) #undef JSON_HEDLEY_STATIC_ASSERT #endif #if \ !defined(__cplusplus) && ( \ (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \ (JSON_HEDLEY_HAS_FEATURE(c_static_assert) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ defined(_Static_assert) \ ) # define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message) #elif \ (defined(__cplusplus) && (__cplusplus >= 201103L)) || \ JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) # define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message)) #else # define JSON_HEDLEY_STATIC_ASSERT(expr, message) #endif #if defined(JSON_HEDLEY_NULL) #undef JSON_HEDLEY_NULL #endif #if defined(__cplusplus) #if __cplusplus >= 201103L #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr) #elif defined(NULL) #define JSON_HEDLEY_NULL NULL #else #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0) #endif #elif defined(NULL) #define JSON_HEDLEY_NULL NULL #else #define JSON_HEDLEY_NULL ((void*) 0) #endif #if defined(JSON_HEDLEY_MESSAGE) #undef JSON_HEDLEY_MESSAGE #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") # define JSON_HEDLEY_MESSAGE(msg) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ JSON_HEDLEY_PRAGMA(message msg) \ JSON_HEDLEY_DIAGNOSTIC_POP #elif \ JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) # define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg) #elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) # define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg) #elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) # define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) #elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0) # define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) #else # define JSON_HEDLEY_MESSAGE(msg) #endif #if defined(JSON_HEDLEY_WARNING) #undef JSON_HEDLEY_WARNING #endif #if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") # define JSON_HEDLEY_WARNING(msg) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ JSON_HEDLEY_PRAGMA(clang warning msg) \ JSON_HEDLEY_DIAGNOSTIC_POP #elif \ JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \ JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) # define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg) #elif \ JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) # define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg)) #else # define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg) #endif #if defined(JSON_HEDLEY_REQUIRE) #undef JSON_HEDLEY_REQUIRE #endif #if defined(JSON_HEDLEY_REQUIRE_MSG) #undef JSON_HEDLEY_REQUIRE_MSG #endif #if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) # if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") # define JSON_HEDLEY_REQUIRE(expr) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ __attribute__((diagnose_if(!(expr), #expr, "error"))) \ JSON_HEDLEY_DIAGNOSTIC_POP # define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ __attribute__((diagnose_if(!(expr), msg, "error"))) \ JSON_HEDLEY_DIAGNOSTIC_POP # else # define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) # define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) # endif #else # define JSON_HEDLEY_REQUIRE(expr) # define JSON_HEDLEY_REQUIRE_MSG(expr,msg) #endif #if defined(JSON_HEDLEY_FLAGS) #undef JSON_HEDLEY_FLAGS #endif #if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) && (!defined(__cplusplus) || JSON_HEDLEY_HAS_WARNING("-Wbitfield-enum-conversion")) #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__)) #else #define JSON_HEDLEY_FLAGS #endif #if defined(JSON_HEDLEY_FLAGS_CAST) #undef JSON_HEDLEY_FLAGS_CAST #endif #if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) # define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ JSON_HEDLEY_DIAGNOSTIC_PUSH \ _Pragma("warning(disable:188)") \ ((T) (expr)); \ JSON_HEDLEY_DIAGNOSTIC_POP \ })) #else # define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) #endif #if defined(JSON_HEDLEY_EMPTY_BASES) #undef JSON_HEDLEY_EMPTY_BASES #endif #if \ (JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0)) || \ JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases) #else #define JSON_HEDLEY_EMPTY_BASES #endif /* Remaining macros are deprecated. */ #if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK) #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK #endif #if defined(__clang__) #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0) #else #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) #endif #if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE) #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE #endif #define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) #if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE) #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE #endif #define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) #if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN) #undef JSON_HEDLEY_CLANG_HAS_BUILTIN #endif #define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin) #if defined(JSON_HEDLEY_CLANG_HAS_FEATURE) #undef JSON_HEDLEY_CLANG_HAS_FEATURE #endif #define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature) #if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION) #undef JSON_HEDLEY_CLANG_HAS_EXTENSION #endif #define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension) #if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE) #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE #endif #define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) #if defined(JSON_HEDLEY_CLANG_HAS_WARNING) #undef JSON_HEDLEY_CLANG_HAS_WARNING #endif #define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning) #endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */ // This file contains all internal macro definitions (except those affecting ABI) // You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them // #include // exclude unsupported compilers #if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK) #if defined(__clang__) #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" #endif #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER)) #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800 #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" #endif #endif #endif // C++ language standard detection // if the user manually specified the used c++ version this is skipped #if !defined(JSON_HAS_CPP_20) && !defined(JSON_HAS_CPP_17) && !defined(JSON_HAS_CPP_14) && !defined(JSON_HAS_CPP_11) #if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) #define JSON_HAS_CPP_20 #define JSON_HAS_CPP_17 #define JSON_HAS_CPP_14 #elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 #define JSON_HAS_CPP_17 #define JSON_HAS_CPP_14 #elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1) #define JSON_HAS_CPP_14 #endif // the cpp 11 flag is always specified because it is the minimal required version #define JSON_HAS_CPP_11 #endif #ifdef __has_include #if __has_include() #include #endif #endif #if !defined(JSON_HAS_FILESYSTEM) && !defined(JSON_HAS_EXPERIMENTAL_FILESYSTEM) #ifdef JSON_HAS_CPP_17 #if defined(__cpp_lib_filesystem) #define JSON_HAS_FILESYSTEM 1 #elif defined(__cpp_lib_experimental_filesystem) #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 #elif !defined(__has_include) #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 #elif __has_include() #define JSON_HAS_FILESYSTEM 1 #elif __has_include() #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 #endif // std::filesystem does not work on MinGW GCC 8: https://sourceforge.net/p/mingw-w64/bugs/737/ #if defined(__MINGW32__) && defined(__GNUC__) && __GNUC__ == 8 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif // no filesystem support before GCC 8: https://en.cppreference.com/w/cpp/compiler_support #if defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 8 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif // no filesystem support before Clang 7: https://en.cppreference.com/w/cpp/compiler_support #if defined(__clang_major__) && __clang_major__ < 7 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif // no filesystem support before MSVC 19.14: https://en.cppreference.com/w/cpp/compiler_support #if defined(_MSC_VER) && _MSC_VER < 1914 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif // no filesystem support before iOS 13 #if defined(__IPHONE_OS_VERSION_MIN_REQUIRED) && __IPHONE_OS_VERSION_MIN_REQUIRED < 130000 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif // no filesystem support before macOS Catalina #if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 #undef JSON_HAS_FILESYSTEM #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #endif #endif #endif #ifndef JSON_HAS_EXPERIMENTAL_FILESYSTEM #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 0 #endif #ifndef JSON_HAS_FILESYSTEM #define JSON_HAS_FILESYSTEM 0 #endif #ifndef JSON_HAS_THREE_WAY_COMPARISON #if defined(__cpp_impl_three_way_comparison) && __cpp_impl_three_way_comparison >= 201907L \ && defined(__cpp_lib_three_way_comparison) && __cpp_lib_three_way_comparison >= 201907L #define JSON_HAS_THREE_WAY_COMPARISON 1 #else #define JSON_HAS_THREE_WAY_COMPARISON 0 #endif #endif #ifndef JSON_HAS_RANGES // ranges header shipping in GCC 11.1.0 (released 2021-04-27) has syntax error #if defined(__GLIBCXX__) && __GLIBCXX__ == 20210427 #define JSON_HAS_RANGES 0 #elif defined(__cpp_lib_ranges) #define JSON_HAS_RANGES 1 #else #define JSON_HAS_RANGES 0 #endif #endif #ifndef JSON_HAS_STATIC_RTTI #if !defined(_HAS_STATIC_RTTI) || _HAS_STATIC_RTTI != 0 #define JSON_HAS_STATIC_RTTI 1 #else #define JSON_HAS_STATIC_RTTI 0 #endif #endif #ifdef JSON_HAS_CPP_17 #define JSON_INLINE_VARIABLE inline #else #define JSON_INLINE_VARIABLE #endif #if JSON_HEDLEY_HAS_ATTRIBUTE(no_unique_address) #define JSON_NO_UNIQUE_ADDRESS [[no_unique_address]] #else #define JSON_NO_UNIQUE_ADDRESS #endif // disable documentation warnings on clang #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wdocumentation" #pragma clang diagnostic ignored "-Wdocumentation-unknown-command" #endif // allow disabling exceptions #if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION) #define JSON_THROW(exception) throw exception #define JSON_TRY try #define JSON_CATCH(exception) catch(exception) #define JSON_INTERNAL_CATCH(exception) catch(exception) #else #include #define JSON_THROW(exception) std::abort() #define JSON_TRY if(true) #define JSON_CATCH(exception) if(false) #define JSON_INTERNAL_CATCH(exception) if(false) #endif // override exception macros #if defined(JSON_THROW_USER) #undef JSON_THROW #define JSON_THROW JSON_THROW_USER #endif #if defined(JSON_TRY_USER) #undef JSON_TRY #define JSON_TRY JSON_TRY_USER #endif #if defined(JSON_CATCH_USER) #undef JSON_CATCH #define JSON_CATCH JSON_CATCH_USER #undef JSON_INTERNAL_CATCH #define JSON_INTERNAL_CATCH JSON_CATCH_USER #endif #if defined(JSON_INTERNAL_CATCH_USER) #undef JSON_INTERNAL_CATCH #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER #endif // allow overriding assert #if !defined(JSON_ASSERT) #include // assert #define JSON_ASSERT(x) assert(x) #endif // allow to access some private functions (needed by the test suite) #if defined(JSON_TESTS_PRIVATE) #define JSON_PRIVATE_UNLESS_TESTED public #else #define JSON_PRIVATE_UNLESS_TESTED private #endif /*! @brief macro to briefly define a mapping between an enum and JSON @def NLOHMANN_JSON_SERIALIZE_ENUM @since version 3.4.0 */ #define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ template \ inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ { \ static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ static const std::pair m[] = __VA_ARGS__; \ auto it = std::find_if(std::begin(m), std::end(m), \ [e](const std::pair& ej_pair) -> bool \ { \ return ej_pair.first == e; \ }); \ j = ((it != std::end(m)) ? it : std::begin(m))->second; \ } \ template \ inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ { \ static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ static const std::pair m[] = __VA_ARGS__; \ auto it = std::find_if(std::begin(m), std::end(m), \ [&j](const std::pair& ej_pair) -> bool \ { \ return ej_pair.second == j; \ }); \ e = ((it != std::end(m)) ? it : std::begin(m))->first; \ } // Ugly macros to avoid uglier copy-paste when specializing basic_json. They // may be removed in the future once the class is split. #define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ template class ObjectType, \ template class ArrayType, \ class StringType, class BooleanType, class NumberIntegerType, \ class NumberUnsignedType, class NumberFloatType, \ template class AllocatorType, \ template class JSONSerializer, \ class BinaryType, \ class CustomBaseClass> #define NLOHMANN_BASIC_JSON_TPL \ basic_json // Macros to simplify conversion from/to types #define NLOHMANN_JSON_EXPAND( x ) x #define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME #define NLOHMANN_JSON_PASTE(...) NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ NLOHMANN_JSON_PASTE64, \ NLOHMANN_JSON_PASTE63, \ NLOHMANN_JSON_PASTE62, \ NLOHMANN_JSON_PASTE61, \ NLOHMANN_JSON_PASTE60, \ NLOHMANN_JSON_PASTE59, \ NLOHMANN_JSON_PASTE58, \ NLOHMANN_JSON_PASTE57, \ NLOHMANN_JSON_PASTE56, \ NLOHMANN_JSON_PASTE55, \ NLOHMANN_JSON_PASTE54, \ NLOHMANN_JSON_PASTE53, \ NLOHMANN_JSON_PASTE52, \ NLOHMANN_JSON_PASTE51, \ NLOHMANN_JSON_PASTE50, \ NLOHMANN_JSON_PASTE49, \ NLOHMANN_JSON_PASTE48, \ NLOHMANN_JSON_PASTE47, \ NLOHMANN_JSON_PASTE46, \ NLOHMANN_JSON_PASTE45, \ NLOHMANN_JSON_PASTE44, \ NLOHMANN_JSON_PASTE43, \ NLOHMANN_JSON_PASTE42, \ NLOHMANN_JSON_PASTE41, \ NLOHMANN_JSON_PASTE40, \ NLOHMANN_JSON_PASTE39, \ NLOHMANN_JSON_PASTE38, \ NLOHMANN_JSON_PASTE37, \ NLOHMANN_JSON_PASTE36, \ NLOHMANN_JSON_PASTE35, \ NLOHMANN_JSON_PASTE34, \ NLOHMANN_JSON_PASTE33, \ NLOHMANN_JSON_PASTE32, \ NLOHMANN_JSON_PASTE31, \ NLOHMANN_JSON_PASTE30, \ NLOHMANN_JSON_PASTE29, \ NLOHMANN_JSON_PASTE28, \ NLOHMANN_JSON_PASTE27, \ NLOHMANN_JSON_PASTE26, \ NLOHMANN_JSON_PASTE25, \ NLOHMANN_JSON_PASTE24, \ NLOHMANN_JSON_PASTE23, \ NLOHMANN_JSON_PASTE22, \ NLOHMANN_JSON_PASTE21, \ NLOHMANN_JSON_PASTE20, \ NLOHMANN_JSON_PASTE19, \ NLOHMANN_JSON_PASTE18, \ NLOHMANN_JSON_PASTE17, \ NLOHMANN_JSON_PASTE16, \ NLOHMANN_JSON_PASTE15, \ NLOHMANN_JSON_PASTE14, \ NLOHMANN_JSON_PASTE13, \ NLOHMANN_JSON_PASTE12, \ NLOHMANN_JSON_PASTE11, \ NLOHMANN_JSON_PASTE10, \ NLOHMANN_JSON_PASTE9, \ NLOHMANN_JSON_PASTE8, \ NLOHMANN_JSON_PASTE7, \ NLOHMANN_JSON_PASTE6, \ NLOHMANN_JSON_PASTE5, \ NLOHMANN_JSON_PASTE4, \ NLOHMANN_JSON_PASTE3, \ NLOHMANN_JSON_PASTE2, \ NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) #define NLOHMANN_JSON_PASTE2(func, v1) func(v1) #define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) #define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) #define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) #define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) #define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) #define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) #define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) #define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) #define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) #define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) #define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) #define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) #define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) #define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) #define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) #define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) #define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) #define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) #define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) #define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) #define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) #define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) #define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) #define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) #define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) #define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) #define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) #define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) #define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) #define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) #define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) #define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) #define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) #define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) #define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) #define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) #define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) #define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) #define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) #define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) #define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) #define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) #define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) #define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) #define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) #define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) #define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) #define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) #define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) #define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) #define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) #define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) #define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) #define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) #define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) #define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) #define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) #define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) #define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) #define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) #define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) #define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) #define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; #define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); #define NLOHMANN_JSON_FROM_WITH_DEFAULT(v1) nlohmann_json_t.v1 = nlohmann_json_j.value(#v1, nlohmann_json_default_obj.v1); /*! @brief macro @def NLOHMANN_DEFINE_TYPE_INTRUSIVE @since version 3.9.0 */ #define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) \ friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } #define NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Type, ...) \ friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } #define NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } /*! @brief macro @def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE @since version 3.9.0 */ #define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } #define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } #define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } /*! @brief macro @def NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE @since version 3.11.x */ #define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE(Type, BaseType, ...) \ friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } #define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...) \ friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } /*! @brief macro @def NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE @since version 3.11.x */ #define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE(Type, BaseType, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } #define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } // inspired from https://stackoverflow.com/a/26745591 // allows to call any std function as if (e.g. with begin): // using std::begin; begin(x); // // it allows using the detected idiom to retrieve the return type // of such an expression #define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name) \ namespace detail { \ using std::std_name; \ \ template \ using result_of_##std_name = decltype(std_name(std::declval()...)); \ } \ \ namespace detail2 { \ struct std_name##_tag \ { \ }; \ \ template \ std_name##_tag std_name(T&&...); \ \ template \ using result_of_##std_name = decltype(std_name(std::declval()...)); \ \ template \ struct would_call_std_##std_name \ { \ static constexpr auto const value = ::nlohmann::detail:: \ is_detected_exact::value; \ }; \ } /* namespace detail2 */ \ \ template \ struct would_call_std_##std_name : detail2::would_call_std_##std_name \ { \ } #ifndef JSON_USE_IMPLICIT_CONVERSIONS #define JSON_USE_IMPLICIT_CONVERSIONS 1 #endif #if JSON_USE_IMPLICIT_CONVERSIONS #define JSON_EXPLICIT #else #define JSON_EXPLICIT explicit #endif #ifndef JSON_DISABLE_ENUM_SERIALIZATION #define JSON_DISABLE_ENUM_SERIALIZATION 0 #endif #ifndef JSON_USE_GLOBAL_UDLS #define JSON_USE_GLOBAL_UDLS 1 #endif #if JSON_HAS_THREE_WAY_COMPARISON #include // partial_ordering #endif NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { /////////////////////////// // JSON type enumeration // /////////////////////////// /*! @brief the JSON type enumeration This enumeration collects the different JSON types. It is internally used to distinguish the stored values, and the functions @ref basic_json::is_null(), @ref basic_json::is_object(), @ref basic_json::is_array(), @ref basic_json::is_string(), @ref basic_json::is_boolean(), @ref basic_json::is_number() (with @ref basic_json::is_number_integer(), @ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), @ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and @ref basic_json::is_structured() rely on it. @note There are three enumeration entries (number_integer, number_unsigned, and number_float), because the library distinguishes these three types for numbers: @ref basic_json::number_unsigned_t is used for unsigned integers, @ref basic_json::number_integer_t is used for signed integers, and @ref basic_json::number_float_t is used for floating-point numbers or to approximate integers which do not fit in the limits of their respective type. @sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON value with the default value for a given type @since version 1.0.0 */ enum class value_t : std::uint8_t { null, ///< null value object, ///< object (unordered set of name/value pairs) array, ///< array (ordered collection of values) string, ///< string value boolean, ///< boolean value number_integer, ///< number value (signed integer) number_unsigned, ///< number value (unsigned integer) number_float, ///< number value (floating-point) binary, ///< binary array (ordered collection of bytes) discarded ///< discarded by the parser callback function }; /*! @brief comparison operator for JSON types Returns an ordering that is similar to Python: - order: null < boolean < number < object < array < string < binary - furthermore, each type is not smaller than itself - discarded values are not comparable - binary is represented as a b"" string in python and directly comparable to a string; however, making a binary array directly comparable with a string would be surprising behavior in a JSON file. @since version 1.0.0 */ #if JSON_HAS_THREE_WAY_COMPARISON inline std::partial_ordering operator<=>(const value_t lhs, const value_t rhs) noexcept // *NOPAD* #else inline bool operator<(const value_t lhs, const value_t rhs) noexcept #endif { static constexpr std::array order = {{ 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, 6 /* binary */ } }; const auto l_index = static_cast(lhs); const auto r_index = static_cast(rhs); #if JSON_HAS_THREE_WAY_COMPARISON if (l_index < order.size() && r_index < order.size()) { return order[l_index] <=> order[r_index]; // *NOPAD* } return std::partial_ordering::unordered; #else return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; #endif } // GCC selects the built-in operator< over an operator rewritten from // a user-defined spaceship operator // Clang, MSVC, and ICC select the rewritten candidate // (see GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105200) #if JSON_HAS_THREE_WAY_COMPARISON && defined(__GNUC__) inline bool operator<(const value_t lhs, const value_t rhs) noexcept { return std::is_lt(lhs <=> rhs); // *NOPAD* } #endif } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT // #include NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { /*! @brief replace all occurrences of a substring by another string @param[in,out] s the string to manipulate; changed so that all occurrences of @a f are replaced with @a t @param[in] f the substring to replace with @a t @param[in] t the string to replace @a f @pre The search string @a f must not be empty. **This precondition is enforced with an assertion.** @since version 2.0.0 */ template inline void replace_substring(StringType& s, const StringType& f, const StringType& t) { JSON_ASSERT(!f.empty()); for (auto pos = s.find(f); // find first occurrence of f pos != StringType::npos; // make sure f was found s.replace(pos, f.size(), t), // replace with t, and pos = s.find(f, pos + t.size())) // find next occurrence of f {} } /*! * @brief string escaping as described in RFC 6901 (Sect. 4) * @param[in] s string to escape * @return escaped string * * Note the order of escaping "~" to "~0" and "/" to "~1" is important. */ template inline StringType escape(StringType s) { replace_substring(s, StringType{"~"}, StringType{"~0"}); replace_substring(s, StringType{"/"}, StringType{"~1"}); return s; } /*! * @brief string unescaping as described in RFC 6901 (Sect. 4) * @param[in] s string to unescape * @return unescaped string * * Note the order of escaping "~1" to "/" and "~0" to "~" is important. */ template static void unescape(StringType& s) { replace_substring(s, StringType{"~1"}, StringType{"/"}); replace_substring(s, StringType{"~0"}, StringType{"~"}); } } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // size_t // #include NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { /// struct to capture the start position of the current token struct position_t { /// the total number of characters read std::size_t chars_read_total = 0; /// the number of characters read in the current line std::size_t chars_read_current_line = 0; /// the number of lines read std::size_t lines_read = 0; /// conversion to size_t to preserve SAX interface constexpr operator size_t() const { return chars_read_total; } }; } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-FileCopyrightText: 2018 The Abseil Authors // SPDX-License-Identifier: MIT #include // array #include // size_t #include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type #include // index_sequence, make_index_sequence, index_sequence_for // #include NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { template using uncvref_t = typename std::remove_cv::type>::type; #ifdef JSON_HAS_CPP_14 // the following utilities are natively available in C++14 using std::enable_if_t; using std::index_sequence; using std::make_index_sequence; using std::index_sequence_for; #else // alias templates to reduce boilerplate template using enable_if_t = typename std::enable_if::type; // The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h // which is part of Google Abseil (https://github.com/abseil/abseil-cpp), licensed under the Apache License 2.0. //// START OF CODE FROM GOOGLE ABSEIL // integer_sequence // // Class template representing a compile-time integer sequence. An instantiation // of `integer_sequence` has a sequence of integers encoded in its // type through its template arguments (which is a common need when // working with C++11 variadic templates). `absl::integer_sequence` is designed // to be a drop-in replacement for C++14's `std::integer_sequence`. // // Example: // // template< class T, T... Ints > // void user_function(integer_sequence); // // int main() // { // // user_function's `T` will be deduced to `int` and `Ints...` // // will be deduced to `0, 1, 2, 3, 4`. // user_function(make_integer_sequence()); // } template struct integer_sequence { using value_type = T; static constexpr std::size_t size() noexcept { return sizeof...(Ints); } }; // index_sequence // // A helper template for an `integer_sequence` of `size_t`, // `absl::index_sequence` is designed to be a drop-in replacement for C++14's // `std::index_sequence`. template using index_sequence = integer_sequence; namespace utility_internal { template struct Extend; // Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency. template struct Extend, SeqSize, 0> { using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >; }; template struct Extend, SeqSize, 1> { using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >; }; // Recursion helper for 'make_integer_sequence'. // 'Gen::type' is an alias for 'integer_sequence'. template struct Gen { using type = typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type; }; template struct Gen { using type = integer_sequence; }; } // namespace utility_internal // Compile-time sequences of integers // make_integer_sequence // // This template alias is equivalent to // `integer_sequence`, and is designed to be a drop-in // replacement for C++14's `std::make_integer_sequence`. template using make_integer_sequence = typename utility_internal::Gen::type; // make_index_sequence // // This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`, // and is designed to be a drop-in replacement for C++14's // `std::make_index_sequence`. template using make_index_sequence = make_integer_sequence; // index_sequence_for // // Converts a typename pack into an index sequence of the same length, and // is designed to be a drop-in replacement for C++14's // `std::index_sequence_for()` template using index_sequence_for = make_index_sequence; //// END OF CODE FROM GOOGLE ABSEIL #endif // dispatch utility (taken from ranges-v3) template struct priority_tag : priority_tag < N - 1 > {}; template<> struct priority_tag<0> {}; // taken from ranges-v3 template struct static_const { static JSON_INLINE_VARIABLE constexpr T value{}; }; #ifndef JSON_HAS_CPP_17 template constexpr T static_const::value; #endif template inline constexpr std::array make_array(Args&& ... args) { return std::array {{static_cast(std::forward(args))...}}; } } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // numeric_limits #include // false_type, is_constructible, is_integral, is_same, true_type #include // declval #include // tuple #include // char_traits // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #include // random_access_iterator_tag // #include // #include // #include NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { template struct iterator_types {}; template struct iterator_types < It, void_t> { using difference_type = typename It::difference_type; using value_type = typename It::value_type; using pointer = typename It::pointer; using reference = typename It::reference; using iterator_category = typename It::iterator_category; }; // This is required as some compilers implement std::iterator_traits in a way that // doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. template struct iterator_traits { }; template struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> : iterator_types { }; template struct iterator_traits::value>> { using iterator_category = std::random_access_iterator_tag; using value_type = T; using difference_type = ptrdiff_t; using pointer = T*; using reference = T&; }; } // namespace detail NLOHMANN_JSON_NAMESPACE_END // #include // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT // #include NLOHMANN_JSON_NAMESPACE_BEGIN NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin); NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT // #include NLOHMANN_JSON_NAMESPACE_BEGIN NLOHMANN_CAN_CALL_STD_FUNC_IMPL(end); NLOHMANN_JSON_NAMESPACE_END // #include // #include // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ // | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ #define INCLUDE_NLOHMANN_JSON_FWD_HPP_ #include // int64_t, uint64_t #include // map #include // allocator #include // string #include // vector // #include /*! @brief namespace for Niels Lohmann @see https://github.com/nlohmann @since version 1.0.0 */ NLOHMANN_JSON_NAMESPACE_BEGIN /*! @brief default JSONSerializer template argument This serializer ignores the template arguments and uses ADL ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) for serialization. */ template struct adl_serializer; /// a class to store JSON values /// @sa https://json.nlohmann.me/api/basic_json/ template class ObjectType = std::map, template class ArrayType = std::vector, class StringType = std::string, class BooleanType = bool, class NumberIntegerType = std::int64_t, class NumberUnsignedType = std::uint64_t, class NumberFloatType = double, template class AllocatorType = std::allocator, template class JSONSerializer = adl_serializer, class BinaryType = std::vector, // cppcheck-suppress syntaxError class CustomBaseClass = void> class basic_json; /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document /// @sa https://json.nlohmann.me/api/json_pointer/ template class json_pointer; /*! @brief default specialization @sa https://json.nlohmann.me/api/json/ */ using json = basic_json<>; /// @brief a minimal map-like container that preserves insertion order /// @sa https://json.nlohmann.me/api/ordered_map/ template struct ordered_map; /// @brief specialization that maintains the insertion order of object keys /// @sa https://json.nlohmann.me/api/ordered_json/ using ordered_json = basic_json; NLOHMANN_JSON_NAMESPACE_END #endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ NLOHMANN_JSON_NAMESPACE_BEGIN /*! @brief detail namespace with internal helper functions This namespace collects functions that should not be exposed, implementations of some @ref basic_json methods, and meta-programming helpers. @since version 2.1.0 */ namespace detail { ///////////// // helpers // ///////////// // Note to maintainers: // // Every trait in this file expects a non CV-qualified type. // The only exceptions are in the 'aliases for detected' section // (i.e. those of the form: decltype(T::member_function(std::declval()))) // // In this case, T has to be properly CV-qualified to constraint the function arguments // (e.g. to_json(BasicJsonType&, const T&)) template struct is_basic_json : std::false_type {}; NLOHMANN_BASIC_JSON_TPL_DECLARATION struct is_basic_json : std::true_type {}; // used by exceptions create() member functions // true_type for pointer to possibly cv-qualified basic_json or std::nullptr_t // false_type otherwise template struct is_basic_json_context : std::integral_constant < bool, is_basic_json::type>::type>::value || std::is_same::value > {}; ////////////////////// // json_ref helpers // ////////////////////// template class json_ref; template struct is_json_ref : std::false_type {}; template struct is_json_ref> : std::true_type {}; ////////////////////////// // aliases for detected // ////////////////////////// template using mapped_type_t = typename T::mapped_type; template using key_type_t = typename T::key_type; template using value_type_t = typename T::value_type; template using difference_type_t = typename T::difference_type; template using pointer_t = typename T::pointer; template using reference_t = typename T::reference; template using iterator_category_t = typename T::iterator_category; template using to_json_function = decltype(T::to_json(std::declval()...)); template using from_json_function = decltype(T::from_json(std::declval()...)); template using get_template_function = decltype(std::declval().template get()); // trait checking if JSONSerializer::from_json(json const&, udt&) exists template struct has_from_json : std::false_type {}; // trait checking if j.get is valid // use this trait instead of std::is_constructible or std::is_convertible, // both rely on, or make use of implicit conversions, and thus fail when T // has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) template struct is_getable { static constexpr bool value = is_detected::value; }; template struct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> { using serializer = typename BasicJsonType::template json_serializer; static constexpr bool value = is_detected_exact::value; }; // This trait checks if JSONSerializer::from_json(json const&) exists // this overload is used for non-default-constructible user-defined-types template struct has_non_default_from_json : std::false_type {}; template struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> { using serializer = typename BasicJsonType::template json_serializer; static constexpr bool value = is_detected_exact::value; }; // This trait checks if BasicJsonType::json_serializer::to_json exists // Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. template struct has_to_json : std::false_type {}; template struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> { using serializer = typename BasicJsonType::template json_serializer; static constexpr bool value = is_detected_exact::value; }; template using detect_key_compare = typename T::key_compare; template struct has_key_compare : std::integral_constant::value> {}; // obtains the actual object key comparator template struct actual_object_comparator { using object_t = typename BasicJsonType::object_t; using object_comparator_t = typename BasicJsonType::default_object_comparator_t; using type = typename std::conditional < has_key_compare::value, typename object_t::key_compare, object_comparator_t>::type; }; template using actual_object_comparator_t = typename actual_object_comparator::type; ///////////////// // char_traits // ///////////////// // Primary template of char_traits calls std char_traits template struct char_traits : std::char_traits {}; // Explicitly define char traits for unsigned char since it is not standard template<> struct char_traits : std::char_traits { using char_type = unsigned char; using int_type = uint64_t; // Redefine to_int_type function static int_type to_int_type(char_type c) noexcept { return static_cast(c); } static char_type to_char_type(int_type i) noexcept { return static_cast(i); } static constexpr int_type eof() noexcept { return static_cast(EOF); } }; // Explicitly define char traits for signed char since it is not standard template<> struct char_traits : std::char_traits { using char_type = signed char; using int_type = uint64_t; // Redefine to_int_type function static int_type to_int_type(char_type c) noexcept { return static_cast(c); } static char_type to_char_type(int_type i) noexcept { return static_cast(i); } static constexpr int_type eof() noexcept { return static_cast(EOF); } }; /////////////////// // is_ functions // /////////////////// // https://en.cppreference.com/w/cpp/types/conjunction template struct conjunction : std::true_type { }; template struct conjunction : B { }; template struct conjunction : std::conditional(B::value), conjunction, B>::type {}; // https://en.cppreference.com/w/cpp/types/negation template struct negation : std::integral_constant < bool, !B::value > { }; // Reimplementation of is_constructible and is_default_constructible, due to them being broken for // std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367). // This causes compile errors in e.g. clang 3.5 or gcc 4.9. template struct is_default_constructible : std::is_default_constructible {}; template struct is_default_constructible> : conjunction, is_default_constructible> {}; template struct is_default_constructible> : conjunction, is_default_constructible> {}; template struct is_default_constructible> : conjunction...> {}; template struct is_default_constructible> : conjunction...> {}; template struct is_constructible : std::is_constructible {}; template struct is_constructible> : is_default_constructible> {}; template struct is_constructible> : is_default_constructible> {}; template struct is_constructible> : is_default_constructible> {}; template struct is_constructible> : is_default_constructible> {}; template struct is_iterator_traits : std::false_type {}; template struct is_iterator_traits> { private: using traits = iterator_traits; public: static constexpr auto value = is_detected::value && is_detected::value && is_detected::value && is_detected::value && is_detected::value; }; template struct is_range { private: using t_ref = typename std::add_lvalue_reference::type; using iterator = detected_t; using sentinel = detected_t; // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator // and https://en.cppreference.com/w/cpp/iterator/sentinel_for // but reimplementing these would be too much work, as a lot of other concepts are used underneath static constexpr auto is_iterator_begin = is_iterator_traits>::value; public: static constexpr bool value = !std::is_same::value && !std::is_same::value && is_iterator_begin; }; template using iterator_t = enable_if_t::value, result_of_begin())>>; template using range_value_t = value_type_t>>; // The following implementation of is_complete_type is taken from // https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/ // and is written by Xiang Fan who agreed to using it in this library. template struct is_complete_type : std::false_type {}; template struct is_complete_type : std::true_type {}; template struct is_compatible_object_type_impl : std::false_type {}; template struct is_compatible_object_type_impl < BasicJsonType, CompatibleObjectType, enable_if_t < is_detected::value&& is_detected::value >> { using object_t = typename BasicJsonType::object_t; // macOS's is_constructible does not play well with nonesuch... static constexpr bool value = is_constructible::value && is_constructible::value; }; template struct is_compatible_object_type : is_compatible_object_type_impl {}; template struct is_constructible_object_type_impl : std::false_type {}; template struct is_constructible_object_type_impl < BasicJsonType, ConstructibleObjectType, enable_if_t < is_detected::value&& is_detected::value >> { using object_t = typename BasicJsonType::object_t; static constexpr bool value = (is_default_constructible::value && (std::is_move_assignable::value || std::is_copy_assignable::value) && (is_constructible::value && std::is_same < typename object_t::mapped_type, typename ConstructibleObjectType::mapped_type >::value)) || (has_from_json::value || has_non_default_from_json < BasicJsonType, typename ConstructibleObjectType::mapped_type >::value); }; template struct is_constructible_object_type : is_constructible_object_type_impl {}; template struct is_compatible_string_type { static constexpr auto value = is_constructible::value; }; template struct is_constructible_string_type { // launder type through decltype() to fix compilation failure on ICPC #ifdef __INTEL_COMPILER using laundered_type = decltype(std::declval()); #else using laundered_type = ConstructibleStringType; #endif static constexpr auto value = conjunction < is_constructible, is_detected_exact>::value; }; template struct is_compatible_array_type_impl : std::false_type {}; template struct is_compatible_array_type_impl < BasicJsonType, CompatibleArrayType, enable_if_t < is_detected::value&& is_iterator_traits>>::value&& // special case for types like std::filesystem::path whose iterator's value_type are themselves // c.f. https://github.com/nlohmann/json/pull/3073 !std::is_same>::value >> { static constexpr bool value = is_constructible>::value; }; template struct is_compatible_array_type : is_compatible_array_type_impl {}; template struct is_constructible_array_type_impl : std::false_type {}; template struct is_constructible_array_type_impl < BasicJsonType, ConstructibleArrayType, enable_if_t::value >> : std::true_type {}; template struct is_constructible_array_type_impl < BasicJsonType, ConstructibleArrayType, enable_if_t < !std::is_same::value&& !is_compatible_string_type::value&& is_default_constructible::value&& (std::is_move_assignable::value || std::is_copy_assignable::value)&& is_detected::value&& is_iterator_traits>>::value&& is_detected::value&& // special case for types like std::filesystem::path whose iterator's value_type are themselves // c.f. https://github.com/nlohmann/json/pull/3073 !std::is_same>::value&& is_complete_type < detected_t>::value >> { using value_type = range_value_t; static constexpr bool value = std::is_same::value || has_from_json::value || has_non_default_from_json < BasicJsonType, value_type >::value; }; template struct is_constructible_array_type : is_constructible_array_type_impl {}; template struct is_compatible_integer_type_impl : std::false_type {}; template struct is_compatible_integer_type_impl < RealIntegerType, CompatibleNumberIntegerType, enable_if_t < std::is_integral::value&& std::is_integral::value&& !std::is_same::value >> { // is there an assert somewhere on overflows? using RealLimits = std::numeric_limits; using CompatibleLimits = std::numeric_limits; static constexpr auto value = is_constructible::value && CompatibleLimits::is_integer && RealLimits::is_signed == CompatibleLimits::is_signed; }; template struct is_compatible_integer_type : is_compatible_integer_type_impl {}; template struct is_compatible_type_impl: std::false_type {}; template struct is_compatible_type_impl < BasicJsonType, CompatibleType, enable_if_t::value >> { static constexpr bool value = has_to_json::value; }; template struct is_compatible_type : is_compatible_type_impl {}; template struct is_constructible_tuple : std::false_type {}; template struct is_constructible_tuple> : conjunction...> {}; template struct is_json_iterator_of : std::false_type {}; template struct is_json_iterator_of : std::true_type {}; template struct is_json_iterator_of : std::true_type {}; // checks if a given type T is a template specialization of Primary template