Repository: andrewkchan/deepseek.cpp
Branch: main
Commit: 8db9e56af867
Files: 35
Total size: 1.6 MB

Directory structure:
deepseek.cpp/

├── .gitignore
├── LICENSE.md
├── Makefile
├── README.md
├── convert.py
├── pyproject.toml
├── quantizer.cpp
├── quantizer.py
├── setup.py
├── src/
│   ├── codec.cpp
│   ├── codec.h
│   ├── debug.cpp
│   ├── debug.h
│   ├── infer.cpp
│   ├── main.cpp
│   ├── model.cpp
│   ├── model.h
│   ├── profile.cpp
│   ├── profile.h
│   ├── quant.cpp
│   ├── quant.h
│   ├── sampler.cpp
│   ├── sampler.h
│   ├── test.cpp
│   ├── time_utils.cpp
│   ├── time_utils.h
│   ├── tokenizer.cpp
│   ├── tokenizer.h
│   ├── wikitest.cat.1chunk.v2-encoded.txt
│   └── wikitest.cat.1chunk.v3-encoded.txt
└── vendor/
    ├── fmt/
    │   ├── base.h
    │   ├── format-inl.h
    │   └── format.h
    ├── format.cc
    └── json.hpp

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
env/

# build intermediates
.vscode/
build/
__pycache__/
*.egg-info/
*.cpython-312-x86_64-linux-gnu.so

# profiling tools
*.sqlite
*.nsys-rep
*.ncu-rep
perf.data*
*.gputrace/

**/.DS_Store

================================================
FILE: LICENSE.md
================================================
MIT License

Copyright (c) 2025 Andrew Chan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

=================================

K-quants adapted from llama.cpp

MIT License

Copyright (c) 2023-2024 The ggml authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

=================================

nlohmann/json

MIT License 

Copyright (c) 2013-2025 Niels Lohmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

=================================

fmt

Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--- Optional exception to the license ---

As an exception, if, as a result of your compiling your source code, portions
of this Software are embedded into a machine-executable object form of such
source code, you may redistribute such embedded portions in such object form
without including the above copyright and permission notices.


================================================
FILE: Makefile
================================================
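# Common targets (documenting the rules below):
#   make         -> build/main, plus assembly listings in build/asm
#   make test    -> build/test
#   make profile -> build/main_profile, compiled with -pg for gprof
#   make clean   -> remove the build directory
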
MAKEFLAGS+=-r -j

UNAME=$(shell uname)

BUILD=build
ASM_DIR=$(BUILD)/asm

# compile .c, .cpp, .cu files
SOURCES=$(filter-out src/test.cpp,$(wildcard src/*.c))
SOURCES+=$(filter-out src/test.cpp,$(wildcard src/*.cc))
SOURCES+=$(filter-out src/test.cpp,$(wildcard src/*.cpp))
SOURCES+=$(filter-out src/test.cpp,$(wildcard src/*.cu))
SOURCES+=$(wildcard vendor/*.c)
SOURCES+=$(wildcard vendor/*.cc)
SOURCES+=$(wildcard vendor/*.cpp)
SOURCES+=$(wildcard vendor/*.cu)

# Define test sources separately
TEST_SOURCES=src/test.cpp
TEST_SOURCES+=$(filter-out src/main.cpp,$(SOURCES))

OBJECTS=$(SOURCES:%=$(BUILD)/%.o)
TEST_OBJECTS=$(TEST_SOURCES:%=$(BUILD)/%.o)
ASM_FILES=$(patsubst %.cpp,$(ASM_DIR)/%.s,$(filter %.cpp,$(SOURCES)))
TEST_ASM_FILES=$(patsubst %.cpp,$(ASM_DIR)/%.s,$(filter %.cpp,$(TEST_SOURCES)))

BINARY=$(BUILD)/main
TEST_BINARY=$(BUILD)/test
PROFILE_BINARY=$(BUILD)/main_profile

BASE_CFLAGS=-g -Wall -Wpointer-arith -Werror -O3 -ffast-math -Ivendor -std=c++20
BASE_LDFLAGS=-lm

BASE_CFLAGS+=-fopenmp -mf16c -mavx2 -mfma
BASE_LDFLAGS+=-fopenmp

PROFILE_CFLAGS=$(BASE_CFLAGS) -pg -fno-omit-frame-pointer
PROFILE_LDFLAGS=$(BASE_LDFLAGS) -pg

CFLAGS=$(BASE_CFLAGS)
LDFLAGS=$(BASE_LDFLAGS)

all: $(BINARY) asm

profile: CFLAGS=$(PROFILE_CFLAGS)
profile: LDFLAGS=$(PROFILE_LDFLAGS)
profile: $(PROFILE_BINARY)

test: $(TEST_BINARY) test-asm

# Target to build just assembly files
asm: $(ASM_FILES)

test-asm: $(TEST_ASM_FILES)

format:
	clang-format -i src/*

$(BINARY): $(OBJECTS)
	$(CXX) $^ $(LDFLAGS) -o $@

$(TEST_BINARY): $(TEST_OBJECTS)
	$(CXX) $^ $(LDFLAGS) -o $@

$(PROFILE_BINARY): $(OBJECTS)
	$(CXX) $^ $(PROFILE_LDFLAGS) -o $@

# Rule to generate assembly for cpp files
$(ASM_DIR)/%.s: %.cpp
	@mkdir -p $(dir $@)
	$(CXX) $< $(CFLAGS) -S -masm=intel -o $@

$(BUILD)/%.c.o: %.c
	@mkdir -p $(dir $@)
	$(CXX) $< $(CFLAGS) -c -MMD -MP -o $@

$(BUILD)/%.cpp.o: %.cpp
	@mkdir -p $(dir $@)
	$(CXX) $< $(CFLAGS) -c -MMD -MP -o $@

$(BUILD)/%.cc.o: %.cc
	@mkdir -p $(dir $@)
	$(CXX) $< $(CFLAGS) -c -MMD -MP -o $@

-include $(OBJECTS:.o=.d)
-include $(TEST_OBJECTS:.o=.d)

clean:
	rm -rf $(BUILD)

.PHONY: all clean format test asm test-asm profile

================================================
FILE: README.md
================================================
This is a CPU-only inference implementation of the DeepSeek family of large language models, written in C++ and based on [Yet Another Language Model](https://github.com/andrewkchan/yalm). 

## Why?

For fun and learning!

I was initially adding DeepSeek support to `yalm` but realized that the changes were large and complex enough that it might ruin the simplicity of that project. Maybe at some point I'll upstream the changes, but for now I've decided to fork them into a separate, smaller, leaner codebase. 

Since this program only supports DeepSeek, it's tiny compared to other inference engines (<2k LOC not including `fmt` and `json`, vs. >250k for llama.cpp and vllm) and is extra hackable. I'm currently using it as a testbed to study single-batch DeepSeek decoding performance on CPU.

## Model and hardware support

Quantizations other than FP32 require AVX2 and F16C support.

| Model      | Q2_K | Q3_K | Q4_K | F8E5M2 | F8E4M3 | FP16 | BF16 | FP32 |
| -----      | ---- | ---- | ------ | ------ | ---- | ---- | ---- | ---- |
| DeepSeek-V2-Lite | ✅ | ✅ | WIP | ✅ | WIP | ✅ | WIP | ✅ |
| DeepSeek-V2 | ✅ | ✅ | WIP | ✅ | WIP | ✅ | WIP | ✅ |
| DeepSeek-V2.5 | ✅ | ✅ | WIP | ✅ | WIP | ✅ | WIP | ✅ |
| DeepSeek-V3 | ✅ | ✅ | WIP | ✅ | WIP | - | - | - |
| DeepSeek-V3.1 (Terminus) | ✅ | ✅ | WIP | ✅ | WIP | - | - | - |
| DeepSeek-R1 | ✅ | ✅ | WIP | ✅ | WIP | - | - | - |

deepseek.cpp is missing important optimizations for production use (see notes below), but gets pretty close to llama.cpp in single-batch decode speed. Benchmarking DeepSeek-V3-Base with Q2_K quantization on an AWS r6a.12xlarge instance (AMD EPYC 7R13, 2x24 cores, 384GB DDR4 RAM):
- llama.cpp ([DeepSeek-V3-Q2_K_XS](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q2_K_XS) 207GB, tg128, best of 16/24/32/48 threads): 4.57 tok/s
- deepseek.cpp (Q2_K 207GB, MHA, `-n 128 -L` completion with 16 threads): 4.02 tok/s

A big part of this is that deepseek.cpp uses the llama.cpp vec_dot kernels for Q2_K, so I can't claim to have matched its performance purely through my own ingenuity. But it is surprising given the inference code is much simpler, opting for OpenMP over a [global threadpool with spinlock kernel barriers](https://justine.lol/matmul/#threads). I'm hoping that in addition to serving as a testbed for myself, this gives a good base for others to hack on.

# Instructions

deepseek.cpp requires a computer with a C++20-compatible compiler. You'll also need a directory containing LLM safetensors weights and configuration files in huggingface format, which you convert with `convert.py` into an output directory of `.dseek` files containing the converted weights. Follow the steps below to download DeepSeek-V2-Lite, build `deepseek.cpp`, and run it:

```
# install git LFS and build tools
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get -y install git-lfs python3-dev build-essential
# download DeepSeek-V2-Lite
git clone https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite
# clone this repository
git clone https://github.com/andrewkchan/deepseek.cpp.git

cd deepseek.cpp
pip install .
# build the main binary
make
python convert.py --quant fp16 v2-lite-f16 ../DeepSeek-V2-Lite/
./build/main v2-lite-f16 -i "What is a large language model?" -m c -t 1.0
```

## Usage

See the CLI help documentation below for `./build/main`:

```
Usage:   main <checkpoint_dir> [options]
Example: main model_weights_dir/ -i "Q: What is the meaning of life?"
Options:
  -h Display this help message
  -L Locks model weights to RAM, disabling swap. Requires sudo.
  -m [completion,passkey,perplexity,interactive] which mode to run in (default - completion)
  -T <int> sliding window context length (0 - max)

Perplexity mode options:
  Choose one:
    -i <string> input prompt
    -f <filepath> input file with prompt
    -w use wikitext as input
Completion mode options:
  -n <int>    number of steps to run for in completion mode, default 256. 0 = max_seq_len, -1 = infinite
  -t <float>  temperature (default - 1.0)
  -p <float>  p for top-p sampling (default - 0.95)
  Choose one:
    -i <string> input prompt
    -f <filepath> input file with prompt
Passkey mode options:
  -n <int>    number of junk lines to insert (default - 250)
  -l <int>    passkey position (-1 - random)
```

You will likely need to tune the number of OpenMP threads to achieve good performance. For example: 
```
OMP_NUM_THREADS=32 ./build/main <...args>
```

The default OpenMP thread count can result in severely degraded throughput, likely due to thread contention. I have found a good heuristic to be half the number of cores.

## Notes

- `--quant=f8e5m2` specifies model weight quantization using 128x128 blocks. MoE gates and layer norms are left in full precision. This should provide better accuracy than per-tensor quantization or the naive truncating quantization done by `yalm` (which results in nonsensical output for the DeepSeek family of models).
- `--quant=q2_k` and `--quant=q3_k` specify model weight quantization using the 2-bit and 3-bit llama.cpp [K-quantization schemes](https://github.com/ggml-org/llama.cpp/pull/1684), which use a two-level hierarchy of blocks and super-blocks to store scales/biases for ranges of weights.
- The models have a tendency to repeat themselves and get into infinite loops at lower temperatures. In my testing, a temperature of ~1.0 avoids this failure mode but also keeps the models reasonably grounded.
- Some newer, optional architectural features of DeepSeek V3 (e.g. the `noaux_tc` method of expert selection) have not yet been implemented, so model accuracy may be lower than that of the reference model.
- You will need ~650GB of memory to run DeepSeek V3 in F8E5M2, or 206GB for 2-bit Q2_K. For best performance, ensure there is enough physical RAM available and run as `sudo` with `-L` to force weights to stay in RAM. Otherwise, most operating systems will automatically supplement RAM with swap space (storing some memory on disk), at the cost of severely degraded token throughput. More aggressive quantization methods such as [1.58-bit](https://unsloth.ai/blog/deepseekr1-dynamic) are planned.
- Model quality is not stable because I've been using this repository as an experiment testbed. See https://github.com/andrewkchan/deepseek.cpp/pull/14 for the latest perplexity measurements on DeepSeek-V2-Lite, as well as instructions on how to run standard measurements yourself. Known issues impacting generation quality include the tokenizer (which is not a true BPE tokenizer) and the use of attention sinks rather than YaRN (https://github.com/andrewkchan/deepseek.cpp/pull/15).
- Only decoding (i.e. incremental, iterative generation or reading of one token at a time) has been implemented. Prefills (reading a batch of prompt tokens in a single pass) have not been implemented, nor have prefill-based optimizations for the decoding phase such as speculative decoding or multi-token prediction. Finally, the current multi-latent attention implementation is still slower than multi-head attention in surprising scenarios (https://github.com/andrewkchan/deepseek.cpp/pull/8) and appears to be under-utilizing memory bandwidth. I have limited time to implement these optimizations as this is a side project for me, but PRs are welcome!

================================================
FILE: convert.py
================================================
# Converts a model consisting of a huggingface config.json, tokenizer.json, and .safetensors weights into .dseek shard files,
# which:
# - Normalizes the config to a common format in the header
# - Combines any safetensors shards
# - Reads the token vocabulary into a simpler format
# - Performs quantization if specified

import argparse
import os
import json
import safetensors
from safetensors.torch import save_file
import torch

from quantizer import k_quantize

from typing import Tuple, List, Literal, Union
import dataclasses

SUPPORTED_ARCHITECTURES = [
  "DeepseekV2ForCausalLM",
  "DeepseekV3ForCausalLM",
]

@dataclasses.dataclass
class BlockQuant:
  name: Literal["fp32", "fp16", "f8e5m2"]
  block_size: Union[Tuple[int, int], None]
  dtype: torch.dtype

@dataclasses.dataclass
class KQuant:
  name: Literal["q2_k", "q3_k"]
  dtype: torch.dtype

Quant = Union[BlockQuant, KQuant]

SUPPORTED_QUANTS = {
  "fp32": BlockQuant(name="fp32", block_size=None, dtype=torch.float32),
  "fp16": BlockQuant(name="fp16", block_size=None, dtype=torch.float16),
  "f8e5m2": BlockQuant(name="f8e5m2", block_size=(128, 128), dtype=torch.float8_e5m2),
  "q2_k": KQuant(name="q2_k", dtype=torch.uint8),
  "q3_k": KQuant(name="q3_k", dtype=torch.uint8),
}

class Metadata:
  def __init__(self, config, tokenizer_config, quant, n_layers, use_mla, bsize):
    arch = config["architectures"][0]
    if arch not in SUPPORTED_ARCHITECTURES:
      raise Exception(f"Architecture {arch} is not supported, must be one of {SUPPORTED_ARCHITECTURES}")
    self.arch = arch
    self.use_mla = bool(use_mla)
    if quant not in SUPPORTED_QUANTS:
      raise Exception(f"Quantization {quant} is not supported, must be one of {SUPPORTED_QUANTS}")
    self.quant: Quant = SUPPORTED_QUANTS[quant]
    if isinstance(self.quant, BlockQuant):
      is_bsize_configurable = self.quant.block_size is not None
      if is_bsize_configurable and bsize is not None:
        self.quant.block_size = (bsize, bsize)
    if arch in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
      self.dim = config["hidden_size"]
      self.hidden_dim = config["intermediate_size"]
      self.n_layers = config["num_hidden_layers"]
      if n_layers is not None and self.n_layers > n_layers:
        self.n_layers = n_layers
      self.n_heads = config["num_attention_heads"]
      self.vocab_size = config["vocab_size"]
      self.max_seq_len = tokenizer_config["model_max_length"]
      self.bos_token_id = config["bos_token_id"]
      self.eos_token_id = config["eos_token_id"]
      self.rope_theta = config.get("rope_theta", 10000.0)
      self.norm_eps = config["rms_norm_eps"]
      self.norm_type = "rmsnorm"
      
      # quantization
      self.original_quantization_config = config.get("quantization_config", None)
      if self.original_quantization_config is not None:
        dequant_block_sizes = self.original_quantization_config["weight_block_size"]
        assert type(dequant_block_sizes) == list and len(dequant_block_sizes) == 2
        assert self.original_quantization_config["quant_method"] == "fp8"

      assert config.get("attention_bias", False) == False
      assert config.get("mlp_bias", False) == False

      assert config["hidden_act"] in ["gelu", "silu"]
      self.act_type = config["hidden_act"]
      self.first_k_dense_replace = config["first_k_dense_replace"]

      # multi-latent attention
      self.kv_lora_rank = config["kv_lora_rank"]
      self.q_lora_rank = config["q_lora_rank"] or 0
      if self.use_mla:
        # TODO: support MLA with q_lora_rank == 0 (DeepSeek V2 Lite)
        assert self.q_lora_rank > 0 and self.kv_lora_rank > 0
      self.qk_nope_head_dim = config["qk_nope_head_dim"]
      self.qk_rope_head_dim = config["qk_rope_head_dim"]
      self.v_head_dim = config["v_head_dim"]

      # mixture of experts
      self.n_shared_experts = config["n_shared_experts"]
      self.n_routed_experts = config["n_routed_experts"]
      self.n_active_routed = config["num_experts_per_tok"]
      self.moe_intermediate_size = config["moe_intermediate_size"]
      self.routed_scaling_factor = config["routed_scaling_factor"]
      self.n_group = config["n_group"]
      self.norm_topk_prob = config["norm_topk_prob"]
      self.scoring_func = config["scoring_func"]
      self.topk_group = config["topk_group"]
      self.topk_method = config["topk_method"]
      if self.topk_method == "noaux_tc":
        self.topk_method = "group_limited_greedy" # TODO: support for Deepseek v3
      
      # rope
      rope_scaling = config["rope_scaling"]
      assert rope_scaling["type"] == "yarn"
      self.rope_scaling_beta_fast = rope_scaling["beta_fast"]
      self.rope_scaling_beta_slow = rope_scaling["beta_slow"]
      self.rope_scaling_factor = rope_scaling["factor"]
      self.rope_scaling_mscale = rope_scaling["mscale"]
      self.rope_scaling_mscale_all_dim = rope_scaling["mscale_all_dim"]
      self.rope_scaling_original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
  
  def to_dict(self):
    result = {}
    result["arch"] = self.arch
    result["use_mla"] = str(int(self.use_mla))
    result["quant"] = self.quant.name
    if self.arch in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
      result["dim"] = str(self.dim)
      result["hidden_dim"] = str(self.hidden_dim)
      result["n_layers"] = str(self.n_layers)
      result["n_heads"] = str(self.n_heads)
      result["vocab_size"] = str(self.vocab_size)
      result["max_seq_len"] = str(self.max_seq_len)
      result["bos_token_id"] = str(self.bos_token_id)
      result["eos_token_id"] = str(self.eos_token_id)
      result["rope_theta"] = str(self.rope_theta)
      result["norm_eps"] = str(self.norm_eps)
      result["norm_type"] = str(self.norm_type)
      result["act_type"] = str(self.act_type)
      result["first_k_dense_replace"] = str(self.first_k_dense_replace)
      # quantization
      if isinstance(self.quant, BlockQuant) and self.quant.block_size is not None:
        result["quantization_block_size_0"] = str(self.quant.block_size[0])
        result["quantization_block_size_1"] = str(self.quant.block_size[1])
      # multi-latent attention
      result["kv_lora_rank"] = str(self.kv_lora_rank)
      result["q_lora_rank"] = str(self.q_lora_rank)
      result["qk_nope_head_dim"] = str(self.qk_nope_head_dim)
      result["qk_rope_head_dim"] = str(self.qk_rope_head_dim)
      result["v_head_dim"] = str(self.v_head_dim)
      # mixture of experts
      result["n_shared_experts"] = str(self.n_shared_experts)
      result["n_routed_experts"] = str(self.n_routed_experts)
      result["n_active_routed"] = str(self.n_active_routed)
      result["moe_intermediate_size"] = str(self.moe_intermediate_size)
      result["routed_scaling_factor"] = str(self.routed_scaling_factor)
      result["n_group"] = str(self.n_group)
      result["norm_topk_prob"] = str(self.norm_topk_prob)
      result["scoring_func"] = str(self.scoring_func)
      result["topk_group"] = str(self.topk_group)
      result["topk_method"] = str(self.topk_method)
      # rope scaling
      result["rope_scaling_beta_fast"] = str(self.rope_scaling_beta_fast)
      result["rope_scaling_beta_slow"] = str(self.rope_scaling_beta_slow)
      result["rope_scaling_factor"] = str(self.rope_scaling_factor)
      result["rope_scaling_mscale"] = str(self.rope_scaling_mscale)
      result["rope_scaling_mscale_all_dim"] = str(self.rope_scaling_mscale_all_dim)
      result["rope_scaling_original_max_position_embeddings"] = str(self.rope_scaling_original_max_position_embeddings)
    return result

# this is a horrible gpt-2 unicode byte encoder hack from https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
# this has poisoned all HF tokenizer configs that use ByteLevel decoder/preprocessor
# as a result we get crazy UTF-8-as-bytes-as-UTF8 in the tokenizer data that we need to convert back
def gpt2_bytes_to_unicode():
  bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
  cs = bs[:]
  n = 0
  for b in range(2**8):
    if b not in bs:
      bs.append(b)
      cs.append(2**8+n)
      n += 1
  cs = [chr(n) for n in cs]
  return dict(zip(bs, cs))
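
# For illustration: bytes outside the printable ranges get remapped upward, e.g.
# 0x20 (" ") becomes chr(256 + 32) == "\u0120" ("Ġ"), so decoding a typical
# ByteLevel vocab entry recovers the raw bytes:
#   enc = gpt2_bytes_to_unicode()
#   dec = {v: k for k, v in enc.items()}
#   assert bytes(dec[c] for c in "Ġhello") == b" hello"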

def load_tokens(tokenizer_path, vocab_size):
  tokens = [""] * vocab_size
  with open(tokenizer_path, "r") as f:
    tokenizer = json.load(f)
  use_gpt2_byte_preprocessing = not tokenizer["model"].get("byte_fallback", False)
  
  vocab = tokenizer["model"]["vocab"]
  assert len(vocab) <= vocab_size

  for t, i in vocab.items():
    tokens[i] = t
  
  for added in tokenizer["added_tokens"]:
    tokens[added["id"]] = added["content"]
  
  gpt2_decode = {v: k for k, v in gpt2_bytes_to_unicode().items()}
  # Preprocess tokens into UTF-8 encoding
  for i, t in enumerate(tokens):
    if use_gpt2_byte_preprocessing:
      b = bytes([gpt2_decode.get(c, 0) for c in t])
    else:
      t = t.replace('\u2581', ' ') # sentencepiece uses this character as whitespace
      b = t.encode('utf-8')
    b = b.replace(b"\0", b"\7") # replace null bytes with bell characters
    assert b.count(0) == 0 # no null bytes allowed
    tokens[i] = b
  
  return tokens

def per_tensor_quantize(tensor: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
  """Quantize a tensor using per-tensor static scaling factor.
  Args:
      tensor: The input tensor.
      dtype: The data type to quantize to.
  """
  finfo = torch.finfo(dtype)
  # Calculate the scale as dtype max divided by absmax.
  # Since .abs() creates a new tensor, we use aminmax to get
  # the min and max first and then calculate the absmax.
  if tensor.numel() == 0:
    # Deal with empty tensors (triggered by empty MoE experts)
    min_val, max_val = (
      torch.tensor(-16.0, dtype=tensor.dtype),
      torch.tensor(16.0, dtype=tensor.dtype),
    )
  else:
    min_val, max_val = tensor.aminmax()
  amax = torch.maximum(min_val.abs(), max_val.abs())
  scale = finfo.max / amax.clamp(min=1e-12)
  # scale and clamp the tensor to bring it to
  # the representative range of float8 data type
  # (as default cast is unsaturated)
  qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
  # Return both float8 data and the inverse scale (as float),
  # as both required as inputs to torch._scaled_mm
  qweight = qweight.to(dtype)
  scale = scale.float().reciprocal()
  return qweight, scale

def per_tensor_dequantize(qweight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
  assert scale.numel() == 1
  return qweight.to(torch.float32) * scale
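
# For illustration, quantize/dequantize are approximate inverses; f8e5m2 keeps
# only 2 mantissa bits, so expect on the order of 10% relative error:
#   w = torch.randn(64, 64)
#   qw, inv_scale = per_tensor_quantize(w, torch.float8_e5m2)
#   err = (per_tensor_dequantize(qw, inv_scale) - w).abs().max()
#   assert err <= 0.25 * w.abs().max()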

def blockwise_dequantize(qweight: torch.Tensor, scale: torch.Tensor, block_size: torch.Tensor) -> torch.Tensor:
  assert qweight.ndim == scale.ndim and scale.ndim == block_size.numel() and scale.ndim == 2
  assert torch.all((torch.tensor(list(qweight.shape)) / block_size).ceil() == torch.tensor(list(scale.shape)))
  out = torch.empty_like(qweight, dtype=torch.float32)
  for i in range(scale.shape[0]):
    for j in range(scale.shape[1]):
      block_size_i = block_size[0]
      block_size_j = block_size[1]
      qw_block = qweight[i*block_size_i:(i+1)*block_size_i, j*block_size_j:(j+1)*block_size_j]
      out[i*block_size_i:(i+1)*block_size_i, j*block_size_j:(j+1)*block_size_j] = per_tensor_dequantize(qw_block, scale[i, j])
  return out

def blockwise_quantize(weight: torch.Tensor, block_size: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
  assert weight.ndim == block_size.numel() and weight.ndim == 2
  scale_shape = torch.Size((torch.tensor(list(weight.shape)) / block_size).ceil().long())
  scale = torch.empty(scale_shape, dtype=torch.float32)
  out = torch.empty_like(weight, dtype=dtype)
  for i in range(scale.shape[0]):
    for j in range(scale.shape[1]):
      block_size_i = block_size[0]
      block_size_j = block_size[1]
      w_block = weight[i*block_size_i:(i+1)*block_size_i, j*block_size_j:(j+1)*block_size_j]
      qw_block, scale_block = per_tensor_quantize(w_block, dtype)
      out[i*block_size_i:(i+1)*block_size_i, j*block_size_j:(j+1)*block_size_j] = qw_block
      scale[i, j] = scale_block
  return out, scale
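
# For illustration, a blockwise round trip (the ceil-division sizing above means
# the block grid may include a ragged final block):
#   w = torch.randn(256, 384)
#   bs = torch.tensor([128, 128])
#   qw, scale = blockwise_quantize(w, bs, torch.float8_e5m2)
#   assert (blockwise_dequantize(qw, scale, bs) - w).abs().max() <= 0.25 * w.abs().max()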

def per_expert_blockwise_quantize(expert_weights: torch.Tensor, block_size: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
  assert expert_weights.ndim == 3
  num_experts = expert_weights.shape[0]
  output_weights = []
  scales = []
  for e in range(num_experts):
    weight, scale = blockwise_quantize(expert_weights[e], block_size, dtype)
    output_weights.append(weight)
    scales.append(scale)
  return torch.stack(output_weights), torch.stack(scales)

def per_expert_k_quantize(expert_weights: torch.Tensor, method: Literal["q2_k", "q3_k"]) -> torch.Tensor:
  assert expert_weights.ndim == 3
  num_experts = expert_weights.shape[0]
  output_weights = []
  for e in range(num_experts):
    output_weights.append(k_quantize(expert_weights[e], method))
  return torch.stack(output_weights)

def load_weights(model_files: List[str], metadata: Metadata, tie_word_embeddings: bool, n_layers: int):
  """
  Generator that yields shards of weights loaded from the model files in huggingface format.
  Each shard contains a dictionary of tensors, with weights normalized and cast to the specified dtype
  (except layer norm weights which are converted to float32).
  """
  weights = {}
  for model_path in model_files:
    ext = os.path.splitext(model_path)[1]
    if ext == ".safetensors":
      with safetensors.safe_open(model_path, framework="pt") as f:
        for k in f.keys():
          assert(k not in weights)
          weights[k] = f.get_tensor(k)
  dtype = metadata.quant.dtype

  # convert weights
  progress = 0
  dequant_block_size = None
  if metadata.original_quantization_config is not None:
    dequant_block_size = torch.tensor(metadata.original_quantization_config["weight_block_size"])
  tensors = {}

  def load_and_dequantize(weight_name: str, scale_name: str) -> torch.Tensor:
    t = weights[weight_name]
    if scale_name in weights:
      scale = weights[scale_name]
      t = blockwise_dequantize(t, scale, dequant_block_size)
    return t
  
  def quantize(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if dtype not in [torch.float32, torch.float16]:
      if isinstance(metadata.quant, KQuant):
        t = k_quantize(t.to(torch.float32), metadata.quant.name)
      elif metadata.quant.block_size is None:
        return per_tensor_quantize(t, dtype)
      else:
        quant_block_size = torch.tensor(metadata.quant.block_size)
        return blockwise_quantize(t, quant_block_size, dtype)
    return t.to(dtype), None

  def conv(weight_name: str, scale_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
    nonlocal progress
    progress += 1
    t = load_and_dequantize(weight_name, scale_name)
    print(f"\rConverting tensor {progress}: {t.shape}", end="", flush=True)
    return quantize(t)
  
  def conv_experts(weight_and_scale_names: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
    nonlocal progress
    progress += 1
    expert_weights = [weights[weight_name] for weight_name, _ in weight_and_scale_names]
    if weight_and_scale_names[0][1] in weights:
      for i in range(len(weight_and_scale_names)):
        scale = weights[weight_and_scale_names[i][1]]
        expert_weights[i] = blockwise_dequantize(expert_weights[i], scale, dequant_block_size)
    t = torch.stack(expert_weights)
    print(f"\rConverting tensor {progress}: {t.shape}", end="", flush=True)
    if dtype not in [torch.float32, torch.float16]:
      if isinstance(metadata.quant, KQuant):
        t = per_expert_k_quantize(t.to(torch.float32), metadata.quant.name)
      elif metadata.quant.block_size is None:
        return per_tensor_quantize(t, dtype)
      else:
        quant_block_size = torch.tensor(metadata.quant.block_size)
        return per_expert_blockwise_quantize(t, quant_block_size, dtype)
    return t.to(dtype), None
  
  def save_weight_and_scale(weight_name: str, scale_name: str, weight_and_scale: Tuple[torch.Tensor, torch.Tensor]):
    tensors[weight_name] = weight_and_scale[0]
    if weight_and_scale[1] is not None:
      tensors[scale_name] = weight_and_scale[1]

  save_weight_and_scale(
    "model.embed.weight", "model.embed.scale", 
    conv("model.embed_tokens.weight", "model.embed_tokens.weight_scale_inv")
  )

  for l in range(config["num_hidden_layers"]):
    if l % 8 == 0 and l > 0:
      yield tensors
      tensors = {}
    if n_layers is not None and l >= n_layers:
      break

    tensors[f"model.layers.{l}.attn.norm.weight"] = weights[f"model.layers.{l}.input_layernorm.weight"].float()
    tensors[f"model.layers.{l}.attn.kv_a_norm.weight"] = weights[f"model.layers.{l}.self_attn.kv_a_layernorm.weight"].float()

    if metadata.use_mla:
      assert metadata.q_lora_rank > 0
      head_dim = metadata.qk_nope_head_dim + metadata.qk_rope_head_dim
      save_weight_and_scale(
        f"model.layers.{l}.attn.wkv_a.weight", f"model.layers.{l}.attn.wkv_a.scale", 
        conv(f"model.layers.{l}.self_attn.kv_a_proj_with_mqa.weight", f"model.layers.{l}.self_attn.kv_a_proj_with_mqa.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.attn.wq_a.weight", f"model.layers.{l}.attn.wq_a.scale", 
        conv(f"model.layers.{l}.self_attn.q_a_proj.weight", f"model.layers.{l}.self_attn.q_a_proj.weight_scale_inv")
      )
      tensors[f"model.layers.{l}.attn.q_a_norm.weight"] = weights[f"model.layers.{l}.self_attn.q_a_layernorm.weight"].float()
      # (n_heads, head_dim-qk_rope_head_dim+v_head_dim, kv_lora_rank)
      kv_b_proj = load_and_dequantize(
        f"model.layers.{l}.self_attn.kv_b_proj.weight", f"model.layers.{l}.self_attn.kv_b_proj.weight_scale_inv"
      ).reshape(
        metadata.n_heads, -1, metadata.kv_lora_rank
      )
      # (n_heads, head_dim, q_lora_rank)
      q_b_proj = load_and_dequantize(
        f"model.layers.{l}.self_attn.q_b_proj.weight", f"model.layers.{l}.self_attn.q_b_proj.weight_scale_inv"
      ).reshape(
        metadata.n_heads, -1, metadata.q_lora_rank
      )
      # (n_heads, head_dim-qk_rope_head_dim, kv_lora_rank)
      k_nope_b_proj = kv_b_proj[:, :head_dim-metadata.qk_rope_head_dim]
      # (n_heads * v_head_dim, kv_lora_rank)
      v_b_proj = kv_b_proj[:, head_dim-metadata.qk_rope_head_dim:].reshape(
        metadata.n_heads * metadata.v_head_dim, metadata.kv_lora_rank
      )
      # (n_heads, head_dim-qk_rope_head_dim, q_lora_rank)
      q_nope_b_proj = q_b_proj[:, :head_dim-metadata.qk_rope_head_dim]
      # (n_heads, qk_rope_head_dim, q_lora_rank)
      q_rope_b_proj = q_b_proj[:, head_dim-metadata.qk_rope_head_dim:]
      # (n_heads, kv_lora_rank, q_lora_rank)
      c_proj = torch.bmm(k_nope_b_proj.transpose(1, 2), q_nope_b_proj)
      
      # NOTE: k_rope gets split from kv_a, so there is no k_rope_b_proj
      save_weight_and_scale(
        f"model.layers.{l}.attn.wq_rope_b.weight", f"model.layers.{l}.attn.wq_rope_b.scale", 
        quantize(q_rope_b_proj.reshape(-1, q_rope_b_proj.shape[-1]))
      )
      save_weight_and_scale(
        f"model.layers.{l}.attn.wc.weight", f"model.layers.{l}.attn.wc.scale", 
        quantize(c_proj.reshape(-1, c_proj.shape[-1]))
      )
      
      save_weight_and_scale(
        f"model.layers.{l}.attn.wv_b.weight", f"model.layers.{l}.attn.wv_b.scale",
        quantize(v_b_proj)
      )
      save_weight_and_scale(
        f"model.layers.{l}.attn.wo.weight", f"model.layers.{l}.attn.wo.scale", 
        conv(f"model.layers.{l}.self_attn.o_proj.weight", f"model.layers.{l}.self_attn.o_proj.weight_scale_inv")
      )
    else:
      save_weight_and_scale(
        f"model.layers.{l}.attn.wkv_a.weight", f"model.layers.{l}.attn.wkv_a.scale", 
        conv(f"model.layers.{l}.self_attn.kv_a_proj_with_mqa.weight", f"model.layers.{l}.self_attn.kv_a_proj_with_mqa.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.attn.wkv_b.weight", f"model.layers.{l}.attn.wkv_b.scale", 
        conv(f"model.layers.{l}.self_attn.kv_b_proj.weight", f"model.layers.{l}.self_attn.kv_b_proj.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.attn.wo.weight", f"model.layers.{l}.attn.wo.scale", 
        conv(f"model.layers.{l}.self_attn.o_proj.weight", f"model.layers.{l}.self_attn.o_proj.weight_scale_inv")
      )
      if metadata.q_lora_rank > 0:
        save_weight_and_scale(
          f"model.layers.{l}.attn.wq_a.weight", f"model.layers.{l}.attn.wq_a.scale", 
          conv(f"model.layers.{l}.self_attn.q_a_proj.weight", f"model.layers.{l}.self_attn.q_a_proj.weight_scale_inv")
        )
        save_weight_and_scale(
          f"model.layers.{l}.attn.wq_b.weight", f"model.layers.{l}.attn.wq_b.scale", 
          conv(f"model.layers.{l}.self_attn.q_b_proj.weight", f"model.layers.{l}.self_attn.q_b_proj.weight_scale_inv")
        )
        tensors[f"model.layers.{l}.attn.q_a_norm.weight"] = weights[f"model.layers.{l}.self_attn.q_a_layernorm.weight"].float()
      else:
        save_weight_and_scale(
          f"model.layers.{l}.attn.wq.weight", f"model.layers.{l}.attn.wq.scale", 
          conv(f"model.layers.{l}.self_attn.q_proj.weight", f"model.layers.{l}.self_attn.q_proj.weight_scale_inv")
        )

    tensors[f"model.layers.{l}.mlp.norm.weight"] = weights[f"model.layers.{l}.post_attention_layernorm.weight"].float()

    if l < metadata.first_k_dense_replace:
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w1.weight", f"model.layers.{l}.mlp.w1.scale", 
        conv(f"model.layers.{l}.mlp.gate_proj.weight", f"model.layers.{l}.mlp.gate_proj.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w2.weight", f"model.layers.{l}.mlp.w2.scale", 
        conv(f"model.layers.{l}.mlp.down_proj.weight", f"model.layers.{l}.mlp.down_proj.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w3.weight", f"model.layers.{l}.mlp.w3.scale", 
        conv(f"model.layers.{l}.mlp.up_proj.weight", f"model.layers.{l}.mlp.up_proj.weight_scale_inv")
      )
    else:
      tensors[f"model.layers.{l}.moegate.weight"] = weights[f"model.layers.{l}.mlp.gate.weight"].float()
      if metadata.arch == "DeepseekV3ForCausalLM":
        tensors[f"model.layers.{l}.moegate.bias"] = weights[f"model.layers.{l}.mlp.gate.e_score_correction_bias"].float()
      
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w1.weight", f"model.layers.{l}.mlp.w1.scale", 
        conv_experts([
          (f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight", f"model.layers.{l}.mlp.experts.{e}.gate_proj.weight_scale_inv") 
          for e in range(metadata.n_routed_experts)
        ])
      )
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w2.weight", f"model.layers.{l}.mlp.w2.scale", 
        conv_experts([
          (f"model.layers.{l}.mlp.experts.{e}.down_proj.weight", f"model.layers.{l}.mlp.experts.{e}.down_proj.weight_scale_inv") 
          for e in range(metadata.n_routed_experts)
        ])
      )
      save_weight_and_scale(
        f"model.layers.{l}.mlp.w3.weight", f"model.layers.{l}.mlp.w3.scale", 
        conv_experts([
          (f"model.layers.{l}.mlp.experts.{e}.up_proj.weight", f"model.layers.{l}.mlp.experts.{e}.up_proj.weight_scale_inv") 
          for e in range(metadata.n_routed_experts)
        ])
      )
      save_weight_and_scale(
        f"model.layers.{l}.shared_mlp.w1.weight", f"model.layers.{l}.shared_mlp.w1.scale", 
        conv(f"model.layers.{l}.mlp.shared_experts.gate_proj.weight", f"model.layers.{l}.mlp.shared_experts.gate_proj.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.shared_mlp.w2.weight", f"model.layers.{l}.shared_mlp.w2.scale", 
        conv(f"model.layers.{l}.mlp.shared_experts.down_proj.weight", f"model.layers.{l}.mlp.shared_experts.down_proj.weight_scale_inv")
      )
      save_weight_and_scale(
        f"model.layers.{l}.shared_mlp.w3.weight", f"model.layers.{l}.shared_mlp.w3.scale", 
        conv(f"model.layers.{l}.mlp.shared_experts.up_proj.weight", f"model.layers.{l}.mlp.shared_experts.up_proj.weight_scale_inv")
      )

  tensors["model.norm.weight"] = weights["model.norm.weight"].float()
  if tie_word_embeddings == False:
    save_weight_and_scale(
      "model.output.weight", "model.output.scale", 
      conv("lm_head.weight", "lm_head.weight_scale_inv")
    )
  else:
    # Model output classifier just uses the word embeddings matrix
    pass

  print()  # newline
  yield tensors

if __name__ == "__main__":
  argp = argparse.ArgumentParser()
  argp.add_argument("output_dir", type=str)
  argp.add_argument("input", type=str, nargs="?")
  argp.add_argument("--mla", action="store_true")
  argp.add_argument("--quant", type=str, default="fp16", choices=SUPPORTED_QUANTS)
  argp.add_argument("--bsize", type=int, default=None, help="block size for blockwise quantization")
  argp.add_argument("--n-layers", type=int, default=None, help="number of layers to convert (if None, convert all)")
  args = argp.parse_args()

  if os.path.exists(args.output_dir) and not os.path.isdir(args.output_dir):
    argp.error(f"output directory {args.output_dir} already exists and is not a directory")
  os.makedirs(args.output_dir, exist_ok=True)

  if args.input is not None:
    # Input is a directory with HuggingFace layout, e.g. files:
    #   config.json
    #   tokenizer.json
    #   *.safetensors
    args.config = os.path.join(args.input, "config.json")
    if not os.path.exists(args.config):
      argp.error(f"config.json not found in {args.input}")
    
    args.tokenizer = os.path.join(args.input, "tokenizer.json")
    if not os.path.exists(args.tokenizer):
      argp.error(f"tokenizer.json not found in {args.input}")
    
    args.tokenizer_config = os.path.join(args.input, "tokenizer_config.json")
    if not os.path.exists(args.tokenizer_config):
      argp.error(f"tokenizer_config.json not found in {args.input}")
    
    files = os.listdir(args.input)
    args.models = [os.path.join(args.input, fname) for fname in files if os.path.splitext(fname)[1] == ".safetensors"]
    if len(args.models) == 0:
      argp.error(f"no .safetensors files found in {args.input}")
  else:
    argp.error("argument input is required")

  with open(args.tokenizer_config, "r") as f:
    tokenizer_config = json.load(f)
  with open(args.config, "r") as f:
    config = json.load(f)
  metadata = Metadata(config, tokenizer_config, args.quant, args.n_layers, args.mla, args.bsize)

  tokens = load_tokens(args.tokenizer, metadata.vocab_size)
  
  # Process and save weight shards
  for shard_idx, shard in enumerate(load_weights(args.models, metadata, config.get("tie_word_embeddings", None), args.n_layers)):
    if shard_idx == 0:
      shard["tokenizer.tokens"] = torch.cat([torch.tensor([x for x in b] + [0], dtype=torch.uint8) for b in tokens])
      save_file(shard, os.path.join(args.output_dir, f"shard_{shard_idx:03d}.dseek"), metadata.to_dict())
    else:
      save_file(shard, os.path.join(args.output_dir, f"shard_{shard_idx:03d}.dseek"), {})
    print(f"\nSaved shard {shard_idx}", flush=True)

================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=42", "wheel", "torch>=2.0.0", "ninja", "numpy"]
build-backend = "setuptools.build_meta" 
[project]
name = "deepseek-cpp"
version = "0.1.0"
requires-python = ">=3.8"
dependencies = [
    "safetensors",
    "torch>=2.0.0",
    "ninja",
    "numpy"
]

================================================
FILE: quantizer.cpp
================================================
#include <torch/extension.h>
#include "quant.h"

torch::Tensor quantize_q2_k(torch::Tensor& input) {
  // Row-major quantization (equivalent to block size [1, 256]) 
  // of input tensor using Q2_K scheme.
  TORCH_CHECK(input.ndimension() == 2, "input must be 2D");
  TORCH_CHECK(input.size(1) % QK_K == 0, "ncols must be divisible by QK_K");
  TORCH_CHECK(input.dtype() == torch::kFloat32, "input must be float32");
  if (!input.is_contiguous()) {
    input = input.contiguous();
  }
  const int64_t nrows = input.size(0);
  const int64_t ncols = input.size(1);
  const int64_t blocks_per_row = ncols / QK_K;
  const int64_t block_size = sizeof(block_q2_K);
  
  auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto output = torch::empty({nrows, blocks_per_row * block_size}, options);
  
  const float* input_ptr = input.data_ptr<float>();
  uint8_t* output_ptr = output.data_ptr<uint8_t>();

  // Parallelize over rows
  #pragma omp parallel for
  for (int64_t row = 0; row < nrows; row++) {
    const float* row_input = input_ptr + row * ncols;
    block_q2_K* row_output = reinterpret_cast<block_q2_K*>(output_ptr + row * blocks_per_row * block_size);

    quantize_row_q2_K_ref(row_input, row_output, ncols);
  }
  
  return output;
}

torch::Tensor quantize_q3_k(torch::Tensor& input) {
  // Row-major quantization (equivalent to block size [1, 256]) 
  // of input tensor using Q3_K scheme.
  TORCH_CHECK(input.ndimension() == 2, "input must be 2D");
  TORCH_CHECK(input.size(1) % QK_K == 0, "ncols must be divisible by QK_K");
  TORCH_CHECK(input.dtype() == torch::kFloat32, "input must be float32");
  if (!input.is_contiguous()) {
    input = input.contiguous();
  }
  const int64_t nrows = input.size(0);
  const int64_t ncols = input.size(1);
  const int64_t blocks_per_row = ncols / QK_K;
  const int64_t block_size = sizeof(block_q3_K);
  
  auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto output = torch::empty({nrows, blocks_per_row * block_size}, options);
  
  const float* input_ptr = input.data_ptr<float>();
  uint8_t* output_ptr = output.data_ptr<uint8_t>();

  // Parallelize over rows
  #pragma omp parallel for
  for (int64_t row = 0; row < nrows; row++) {
    const float* row_input = input_ptr + row * ncols;
    block_q3_K* row_output = reinterpret_cast<block_q3_K*>(output_ptr + row * blocks_per_row * block_size);

    quantize_row_q3_K_ref(row_input, row_output, ncols);
  }
  
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("quantize_q2_k", &quantize_q2_k, "Quantize a tensor to Q2_K format");
  m.def("quantize_q3_k", &quantize_q3_k, "Quantize a tensor to Q3_K format");
} 

================================================
FILE: quantizer.py
================================================
import torch
import quantizer_cpp
from typing import Literal

def k_quantize(tensor: torch.Tensor, method: Literal["q2_k", "q3_k"]) -> torch.Tensor:
  """
  Quantize a 2D float32 tensor to Q2_K or Q3_K format.
  
  Args:
    tensor: Input tensor of shape (M, N) where N must be a multiple of 256
    method: K-quantization scheme, either "q2_k" or "q3_k"

  Returns:
    Quantized tensor of type uint8 and shape (M, sizeof(block) * N/256),
    where block is block_q2_K or block_q3_K depending on the method
  """ 
  if method == "q2_k":
    return quantizer_cpp.quantize_q2_k(tensor) 
  elif method == "q3_k":
    return quantizer_cpp.quantize_q3_k(tensor) 
  else:
    raise ValueError(f"Invalid method: {method}")
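
# Example usage (assumes the quantizer_cpp extension was built, e.g. via
# `pip install .` from the repository root):
#   w = torch.randn(4, 512, dtype=torch.float32)  # ncols must be a multiple of 256
#   q = k_quantize(w, "q2_k")
#   assert q.dtype == torch.uint8 and q.shape[0] == 4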


================================================
FILE: setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
from setuptools.dist import Distribution
import os

class BinaryDistribution(Distribution):
    def has_ext_modules(self):
        return True

setup(
  name="quantizer_cpp",
  ext_modules=[
    CppExtension(
      name="quantizer_cpp",
      sources=["quantizer.cpp", "src/quant.cpp"],
      include_dirs=[
        os.path.join(os.path.dirname(__file__), "src"), 
        os.path.join(os.path.dirname(__file__), "vendor")
      ],
      extra_compile_args=["-O3", "-march=native", "-std=c++20", "-fopenmp"],
      extra_link_args=["-fopenmp"],
    ),
  ],
  cmdclass={
    'build_ext': BuildExtension
  },
  python_requires='>=3.8',
  install_requires=[
    'torch>=2.0.0',
  ],
  setup_requires=[
    'torch>=2.0.0',
    'ninja',
    'numpy',
  ],
  distclass=BinaryDistribution,
)

================================================
FILE: src/codec.cpp
================================================
#include "codec.h"

#include "quant.h"

#include "fmt/format.h"

#include <fcntl.h>
#include <iostream>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

std::string quant_to_string(Quant quant) {
  switch (quant) {
    case Quant::F32: return "F32";
    case Quant::F16: return "F16";
    case Quant::F8E5M2: return "F8_E5M2";
    case Quant::Q2_K: return "Q2_K";
    case Quant::Q3_K: return "Q3_K";
  }
  __builtin_unreachable();
}

std::optional<Quant> string_to_quant(const std::string& quant_str) {
  if (quant_str == "F32") {
    return Quant::F32;
  } else if (quant_str == "F16") {
    return Quant::F16;
  } else if (quant_str == "F8_E5M2") {
    return Quant::F8E5M2;
  } else if (quant_str == "Q2_K") {
    return Quant::Q2_K;
  } else if (quant_str == "Q3_K") {
    return Quant::Q3_K;
  } else {
    return std::nullopt;
  }
}

double bits_per_weight(Quant quant, size_t blockwise_quant_size) {
  if (blockwise_quant_size > 0 && quant != Quant::F8E5M2) {
    std::cerr << "blockwise quantization should only be used with F8E5M2" << std::endl;
    assert(false);
  }
  switch (quant) {
    case Quant::F32: return 32;
    case Quant::F16: return 16;
    case Quant::F8E5M2:
      // Each 8-bit weight shares one fp32 scale with its block of
      // `blockwise_quant_size` elements, e.g. 128x128 blocks cost
      // 8 + 32/16384 ≈ 8.002 bits per weight.
      return blockwise_quant_size > 0 ? 8.0 + 32.0 / blockwise_quant_size : 8.0;
    case Quant::Q2_K: return 2.5625;
    case Quant::Q3_K: return 3.4375;
  }
  __builtin_unreachable();
}

CodecDType quant_to_codec_dtype(Quant quant) {
  switch (quant) {
    case Quant::F32: return CodecDType::F32;
    case Quant::F16: return CodecDType::F16;
    case Quant::F8E5M2: return CodecDType::F8E5M2;
    case Quant::Q2_K: return CodecDType::U8;
    case Quant::Q3_K: return CodecDType::U8;
  }
  __builtin_unreachable();
}

bool is_k_quant(Quant quant) {
  return quant == Quant::Q2_K || quant == Quant::Q3_K;
}

std::string codec_dtype_to_string(CodecDType dtype) {
  switch (dtype) {
    case CodecDType::F32: return "F32";
    case CodecDType::F16: return "F16";
    case CodecDType::BF16: return "BF16";
    case CodecDType::F8E5M2: return "F8_E5M2";
    case CodecDType::F8E4M3: return "F8_E4M3";
    case CodecDType::I32: return "I32";
    case CodecDType::I16: return "I16";
    case CodecDType::I8: return "I8";
    case CodecDType::U8: return "U8";
  }
  return "UNKNOWN";
}

std::optional<CodecDType> string_to_codec_dtype(const std::string& dtype_str) {
  if (dtype_str == "F32") {
    return CodecDType::F32;
  } else if (dtype_str == "F16") {
    return CodecDType::F16;
  } else if (dtype_str == "BF16") {
    return CodecDType::BF16;
  } else if (dtype_str == "F8_E5M2") {
    return CodecDType::F8E5M2;
  } else if (dtype_str == "F8_E4M3") {
    return CodecDType::F8E4M3;
  } else if (dtype_str == "I32") {
    return CodecDType::I32;
  } else if (dtype_str == "I16") {
    return CodecDType::I16;
  } else if (dtype_str == "I8") {
    return CodecDType::I8;
  } else if (dtype_str == "U8") {
    return CodecDType::U8;
  } else {
    return std::nullopt;
  }
}

size_t codec_dtype_size(CodecDType dtype) {
  switch (dtype) {
    case CodecDType::F32: return 4;
    case CodecDType::F16: return 2;
    case CodecDType::BF16: return 2;
    case CodecDType::F8E5M2: return 1;
    case CodecDType::F8E4M3: return 1;
    case CodecDType::I32: return 4;
    case CodecDType::I16: return 2;
    case CodecDType::I8: return 1;
    case CodecDType::U8: return 1;
  }
  return 0;
}

int Tensor::from_json(const std::string& name, const json& val, void* bytes_ptr, size_t bytes_size) {
  this->name = name;
  std::string dtype_str = val.value("dtype", ""); 
  if (auto dtype = string_to_codec_dtype(dtype_str)) {
    this->dtype = *dtype;
  } else {
    std::cerr << "bad dtype" << std::endl;
    return -1;
  }
  size_t dsize = codec_dtype_size(this->dtype);

  size_t numel = 1;
  if (val.at("shape").size() > 4) {
    std::cerr << "shape exceeds 4 dimensions" << std::endl;
  }
  for (size_t i = 0; i < val.at("shape").size() && i < 4; i++) {
    if (val.at("shape")[i].get<int>() != val.at("shape")[i]) {
      std::cerr << "bad shape" << std::endl;
      return -1;
    }
    shape[i] = val.at("shape")[i].get<int>();
    numel *= shape[i];
  }
  if (val.at("data_offsets").size() != 2) {
    return -1;
  }
  size_t offset_start = static_cast<size_t>(val.at("data_offsets")[0]);
  size_t offset_end = static_cast<size_t>(val.at("data_offsets")[1]);
  if (offset_end <= offset_start || offset_end > bytes_size) {
    std::cerr << "bad offsets" << std::endl;
    return -1;
  }
  this->data = (char*)bytes_ptr + offset_start;
  this->size = offset_end - offset_start;
  // validate the shape matches the size
  if (numel * dsize != this->size) {
    std::cerr << "bad size" << std::endl;
    return -1;
  }
  return 0;
}

QTensor QTensor::from_codec_tensor(const Tensor& tensor, Quant weight_quant, std::array<int, 4> shape, const int debug_line) {
  QTensor qtensor;
  CodecDType expected_dtype = quant_to_codec_dtype(weight_quant);
  std::array<int, 4> expected_shape = shape;
  if (is_k_quant(weight_quant)) {
    size_t numel = 1;
    for (int i = 0; i < 4; i++) {
      if (shape[i] > 0) {
        numel *= shape[i];
      }
    }
    size_t block_size = sizeof(block_q2_K);
    switch (weight_quant) {
      case Quant::Q2_K: {
        block_size = sizeof(block_q2_K);
        break;
      }
      case Quant::Q3_K: {
        block_size = sizeof(block_q3_K);
        break;
      }
      default: {}
    }
    size_t total_blocks = numel / QK_K;
    size_t total_bytes = total_blocks * block_size;
    if (tensor.dtype != expected_dtype || tensor.size != total_bytes) {
      std::cerr << "FATAL: tensor mismatch for " << tensor.name << std::endl;
      std::cerr 
        << fmt::format(
          "expected: dtype={}, size={}", 
          codec_dtype_to_string(expected_dtype), 
          total_bytes
        ) 
        << std::endl;
      std::cerr 
        << fmt::format(
          "got: dtype={}, size={}", 
          codec_dtype_to_string(tensor.dtype), 
          tensor.size
        ) << std::endl;
      assert(false);
    }
  } else if (tensor.dtype != expected_dtype || tensor.shape != expected_shape) {
    std::cerr << "FATAL: tensor mismatch for " << tensor.name << std::endl;
    std::cerr 
      << fmt::format(
        "expected: dtype={}, shape=[{},{},{},{}]", 
        codec_dtype_to_string(expected_dtype), 
        expected_shape[0], 
        expected_shape[1], 
        expected_shape[2], 
        expected_shape[3]
      ) 
      << std::endl;
    std::cerr 
      << fmt::format(
        "got: dtype={}, shape=[{},{},{},{}]", 
        codec_dtype_to_string(tensor.dtype), 
        tensor.shape[0], tensor.shape[1], tensor.shape[2], tensor.shape[3]
      ) 
      << std::endl;
    assert(false);
  }
  qtensor.quant = weight_quant;
  qtensor.shape = shape;
  qtensor.size = tensor.size;
  qtensor.data = tensor.data;
  return qtensor;
}

size_t QTensor::ndim() const {
  for (size_t i = 0; i < shape.size(); i++) {
    if (shape[i] == 0) {
      return i;
    }
  }
  return shape.size();
}

size_t QTensor::n_elements() const {
  size_t numel = 1;
  for (size_t i = 0; i < shape.size(); i++) {
    if (shape[i] > 0) {
      numel *= shape[i];
    }
  }
  return numel;
}

YALMData::YALMData(const std::string& dirname, bool lock_model_weights) {
  if (from_directory(dirname, lock_model_weights) != 0) {
    std::cerr << "failed to load YALMData from directory" << std::endl;
    assert(false);
  }
}

int YALMData::update_from_file(const std::string& filename, bool read_metadata, bool lock_model_weights) {
  std::cout << "loading data from file: " << filename << std::endl;
  int fd = open(filename.c_str(), O_RDONLY);
  if (fd == -1) {
    return -1;
  }

  struct stat st;
  if (fstat(fd, &st) != 0) {
    close(fd);
    return -1;
  }
  
  size_t size = st.st_size;
  int mmap_flags = MAP_PRIVATE;
  if (lock_model_weights) {
    // Eagerly load memory-mapped file into memory.
    // This ensures the mlock call later is locking memory already in RAM.
    mmap_flags |= MAP_POPULATE;
  }
  void* data = mmap(NULL, size, PROT_READ | PROT_WRITE, mmap_flags, fd, 0);
  if (data == MAP_FAILED) {
    close(fd);
    return -1;
  }
  if (lock_model_weights && mlock(data, size) != 0) {
    std::cerr << "Warning: mlock failed for model data. Performance may be suboptimal. Are you running as sudo?" << std::endl;
  }

#ifdef __linux__
  // increases readahead buffer size, resulting in faster cold loads
  posix_fadvise(fd, 0, size, POSIX_FADV_SEQUENTIAL);
#endif

  close(fd);

  // Parse the metadata JSON and the tensors
  if (size < sizeof(uint64_t)) {
    munmap(data, size);
    return -1;
  }

  uint64_t json_size = *(uint64_t*)data;
  if (json_size == 0 || json_size > size - sizeof(uint64_t)) {
    munmap(data, size);
    return -1;
  }

  char* json_ptr = (char*)data + sizeof(uint64_t);
  void* bytes_ptr = (char*)data + sizeof(uint64_t) + json_size;
  size_t bytes_size = size - sizeof(uint64_t) - json_size;

  std::string json_str(json_ptr, json_size);
  json header = json::parse(json_str);

  for (auto& [key, val] : header.items()) {
    if (key == "__metadata__" && read_metadata) {
      metadata = val;
    } else if (key != "__metadata__") {
      Tensor& tensor = tensors[key];
      if (tensor.from_json(key, val, bytes_ptr, bytes_size) != 0) {
        std::cerr << "failed to parse tensor " << key << std::endl;
        munmap(data, size);
        return -1;
      }
    }
  }

  return 0;
}
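// Illustrative layout of the files parsed above (safetensors-style framing):
//   [ uint64 json_size | JSON header of json_size bytes | raw tensor bytes ]
// Each header entry carries dtype/shape/data_offsets into the byte region,
// and an optional "__metadata__" entry holds model-level config.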

int YALMData::from_directory(const std::string& dirname, bool lock_model_weights) {
  std::vector<std::string> files;
  DIR* dir = opendir(dirname.c_str());
  if (dir == nullptr) {
    std::cout << "failed to open directory" << std::endl;
    return -1;
  }

  // Collect all files
  struct dirent* entry;
  while ((entry = readdir(dir)) != nullptr) {
    std::string filename = entry->d_name;
    // Skip . and .. directory entries
    if (filename != "." && filename != "..") {
      files.push_back(dirname + "/" + filename);
    }
  }
  closedir(dir);

  if (files.empty()) {
    std::cout << "no files found" << std::endl;
    return -1;
  }

  // Sort files to ensure consistent ordering
  std::sort(files.begin(), files.end());

  // Read first file with metadata
  if (update_from_file(files[0], true, lock_model_weights) != 0) {
    std::cout << "failed to read metadata" << std::endl;
    return -1;
  }

  std::cout << "read metadata " << metadata << std::endl;

  // Read remaining files without metadata
  for (size_t i = 1; i < files.size(); i++) {
    if (update_from_file(files[i], false, lock_model_weights) != 0) {
      std::cout << "failed to read file " << files[i] << std::endl;
      return -1;
    }
  }

  return 0;
}

================================================
FILE: src/codec.h
================================================
#pragma once

#include "json.hpp"

#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <unordered_map>
#include <dirent.h>
#include <algorithm>
#include <vector>
#include <optional>

#include "immintrin.h"
#include "f16cintrin.h"

using json = nlohmann::json;

typedef uint16_t f16_t;
typedef uint8_t f8e5m2_t;

#if defined(__AVX2__) && defined(__F16C__)
inline float half_to_float(f16_t x) {
  return _cvtsh_ss(x);
}
inline f16_t float_to_half(float x) {
  return _cvtss_sh(x, 0);
}
#else
inline float half_to_float(f16_t x) {
  assert(false && "float16 not supported on this platform");
  return 0.0f;
}
inline f16_t float_to_half(float x) {
  assert(false && "float16 not supported on this platform");
  return 0;
}
#endif

inline float float8e5m2_to_float(f8e5m2_t x) {
  f16_t val = 0;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  memcpy(&val, &x, sizeof(f8e5m2_t));
#else
  memcpy((char*)&val + sizeof(f8e5m2_t), &x, sizeof(f8e5m2_t));
#endif
  return half_to_float(val);
}
[[maybe_unused]] inline f8e5m2_t float_to_float8e5m2(float x) {
  f16_t val = float_to_half(x);
  f8e5m2_t out;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  memcpy(&out, (char*)&val, sizeof(f8e5m2_t)); // TODO: round instead of truncate?
#else
  memcpy(&out, (char*)&val + sizeof(f8e5m2_t), sizeof(f8e5m2_t)); // TODO: round instead of truncate?
#endif
  return out;
}
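// Illustrative note on the two helpers above (not part of the original file):
// f8e5m2 shares float16's sign and exponent layout, so it is exactly the high
// byte of the corresponding float16 and conversion is byte surgery rather than
// arithmetic. A minimal round-trip sketch, assuming an AVX2+F16C build:
//   f8e5m2_t b = float_to_float8e5m2(1.0f); // f16 0x3C00 -> e5m2 0x3C
//   float f = float8e5m2_to_float(b);       // back to 1.0f; values needing
//                                           // more than 2 mantissa bits truncate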

// Quant of tensors saved in the file.
// This corresponds to PyTorch tensor dtypes.
enum class CodecDType {
  F32,
  F16,
  BF16,
  F8E5M2,
  F8E4M3,
  I32,
  I16,
  I8,
  U8,
};
std::string codec_dtype_to_string(CodecDType dtype);
std::optional<CodecDType> string_to_codec_dtype(const std::string& dtype_str);
size_t codec_dtype_size(CodecDType dtype);

// Internal Quant.
// This corresponds to the in-memory representation of tensors in the model.
enum class Quant {
  F32,
  F16,
  F8E5M2,
  Q2_K, // 2-bit llama.cpp K-quants
  Q3_K, // 3-bit llama.cpp K-quants
};

std::string quant_to_string(Quant quant);
std::optional<Quant> string_to_quant(const std::string& quant_str);
double bits_per_weight(Quant quant, size_t blockwise_quant_size);
CodecDType quant_to_codec_dtype(Quant quant);
bool is_k_quant(Quant quant);

// Tensor data as read from the file, which serializes tensors
// in PyTorch format.
struct Tensor {
  std::string name;
  CodecDType dtype;
  std::array<int, 4> shape = {0, 0, 0, 0};
  void* data = nullptr; // not managed by Tensor
  size_t size; // size in bytes (number of elements * element size)

  // Returns 0 if successful, other if failed
  int from_json(const std::string& name, const json& j, void* bytes_ptr, size_t bytes_size);
};

// Tensor with quantization metadata.
struct QTensor {
  Quant quant = Quant::F32;
  std::array<int, 4> shape = {0, 0, 0, 0};
  void* data = nullptr; // not managed by QTensor
  size_t size = 0; // size in bytes

  QTensor() = default;
  QTensor(Quant quant, std::array<int, 4> shape, void* data, size_t size) : quant(quant), shape(shape), data(data), size(size) {}
  QTensor(const QTensor& other) = default;
  static QTensor from_codec_tensor(const Tensor& tensor, Quant weight_quant, std::array<int, 4> shape, const int debug_line);

  size_t ndim() const;
  size_t n_elements() const;
};

struct YALMData {
  json metadata;
  std::unordered_map<std::string, Tensor> tensors;

  YALMData(const std::string& dirname, bool lock_model_weights);

private:
  // Update YALMData with tensors from a file
  // If read_metadata is true, also update metadata from this file
  // Returns 0 if successful, other if failed
  int update_from_file(const std::string& filename, bool read_metadata, bool lock_model_weights);

  // Initialize YALMData from all files in a directory
  // Metadata is read from the first file (in sorted order)
  // Returns 0 if successful, other if failed
  int from_directory(const std::string& dirname, bool lock_model_weights);
};
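// Illustrative usage sketch (not part of the original header; the tensor name
// is hypothetical):
//   YALMData data("model_weights_dir/", /*lock_model_weights=*/false);
//   const Tensor& t = data.tensors.at("model.embed.weight");
//   QTensor q = QTensor::from_codec_tensor(t, Quant::F16, t.shape, __LINE__);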

================================================
FILE: src/debug.cpp
================================================
#include "debug.h"
#include "model.h"

template <typename T>
bool BinaryDumper::save(const std::string& filename, const T* data, size_t count) {
  std::ofstream file(filename, std::ios::binary);
  if (!file) return false;
  
  // Write count first
  file.write(reinterpret_cast<const char*>(&count), sizeof(count));
  // Write T data
  file.write(reinterpret_cast<const char*>(data), count * sizeof(T));
  
  return file.good();
}

template bool BinaryDumper::save<float>(const std::string&, const float*, size_t);
template bool BinaryDumper::save<f16_t>(const std::string&, const f16_t*, size_t);

template <typename T>
std::vector<T> BinaryDumper::load(const std::string& filename) {
  std::ifstream file(filename, std::ios::binary);
  if (!file) return {};
  
  // Read count
  size_t count;
  file.read(reinterpret_cast<char*>(&count), sizeof(count));
  
  // Read T data
  std::vector<T> data(count);
  file.read(reinterpret_cast<char*>(data.data()), count * sizeof(T));
  
  if (!file.good()) return {};
  return data;
}

template std::vector<float> BinaryDumper::load(const std::string&);
template std::vector<f16_t> BinaryDumper::load(const std::string&);

================================================
FILE: src/debug.h
================================================
#pragma once

#include <fstream>
#include <vector>
#include <cstdint>
#include <iostream>
#include <string>

struct BinaryDumper {
  // Save T array to binary file
  template <typename T>
  static bool save(const std::string& filename, const T* data, size_t count);
  
  // Load T array from binary file
  template <typename T>
  static std::vector<T> load(const std::string& filename);
};
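// Illustrative usage sketch (not part of the original header):
//   std::vector<float> xs = {1.0f, 2.0f, 3.0f};
//   BinaryDumper::save("xs.bin", xs.data(), xs.size());
//   auto ys = BinaryDumper::load<float>("xs.bin"); // ys == xs on success,
//                                                  // empty vector on failure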

================================================
FILE: src/infer.cpp
================================================
#include "model.h"

#include <assert.h>
#include <cfloat>
#include <math.h>

#include "quant.h"
#include "profile.h"

#if DEBUG_MODEL
#include "json.hpp"
#include <fstream>
#include "fmt/format.h"
static std::map<std::string, DebugTensor> _debug_map;
std::map<std::string, DebugTensor>& debug_map_cpu() {
  return _debug_map;
}
template <typename T>
static std::vector<T> copy_debug_tensor(T* x, size_t size) {
  std::vector<T> out(size);
  for (size_t i = 0; i < size; i++) {
    out[i] = x[i];
  }
  return out;
}
template <typename T>
static void save_debug_tensor(const std::string& name, T* x, size_t size) {
  _debug_map[name] = DebugTensor(copy_debug_tensor<T>(x, size));
}
void dump_debug_map(const std::string& filename) {
  std::ofstream out(filename);
  if (!out.is_open()) {
    fprintf(stderr, "Failed to open %s for writing\n", filename.c_str());
    return;
  }

  // Write Python imports
  out << "import torch\n\n";
  out << "debug_tensors = {\n";

  // Iterate through debug map and write each tensor
  bool first = true;
  for (const auto& pair : _debug_map) {
    if (!first) {
      out << ",\n";
    }
    first = false;

    const std::string& name = pair.first;
    const DebugTensor& tensor = pair.second;

    out << "    '" << name << "': torch.tensor([";

    // Write tensor values
    bool first_val = true;
    assert(tensor.data_type == DebugTensor::DataType::F32);
    for (const auto& val : tensor.data_f32) {
      if (!first_val) {
        out << ", ";
      }
      first_val = false;
      
      // Use scientific notation with high precision
      out << std::scientific << std::setprecision(8) << val;
    }
    
    out << "])";
  }
  
  out << "\n}\n";
  out.close();
}
void dump_debug_map_as_safetensors(const std::string& filename) {
  std::ofstream out(filename, std::ios::binary);
  if (!out.is_open()) {
    fprintf(stderr, "Failed to open %s for writing\n", filename.c_str());
    return;
  }

  json header;
  size_t offset = 0;
  for (auto& [key, val] : _debug_map) {
    size_t offset_end = offset;
    CodecDType dtype = val.data_type == DebugTensor::DataType::F32 ? CodecDType::F32 : CodecDType::F16;
    if (dtype == CodecDType::F32) {
      offset_end += val.data_f32.size() * sizeof(float);
      header[key] = {
        {"dtype", codec_dtype_to_string(dtype)},
        {"shape", {val.data_f32.size()}},
        {"data_offsets", {offset, offset_end}}
      };
    } else {
      offset_end += val.data_f16.size() * sizeof(f16_t);
      header[key] = {
        {"dtype", codec_dtype_to_string(dtype)},
        {"shape", {val.data_f16.size()}},
        {"data_offsets", {offset, offset_end}}
      };
    }
    offset = offset_end;
  }
  header["__metadata__"] = {{"debug", ""}};
  std::string header_str = header.dump();
  // 1. write uint64 (size of json header)
  uint64_t header_len = static_cast<uint64_t>(header_str.size());
  out.write(reinterpret_cast<const char*>(&header_len), sizeof(header_len));
  // 2. write json header
  out.write(header_str.c_str(), header_len);
  // 3. write tensor data
  for (auto& [key, val] : _debug_map) {
    if (val.data_type == DebugTensor::DataType::F32) {
      out.write(reinterpret_cast<const char*>(val.data_f32.data()), val.data_f32.size() * sizeof(float));
    } else {
      out.write(reinterpret_cast<const char*>(val.data_f16.data()), val.data_f16.size() * sizeof(f16_t));
    }
  }
  out.close();
}
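// Illustrative example of the output written above (tensor name and sizes are
// hypothetical):
//   header_len | {"__metadata__":{"debug":""},
//                 "x":{"dtype":"F32","shape":[4],"data_offsets":[0,16]}}
//   ...followed by each tensor's raw little-endian bytes in map order.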
#endif

static void _matmul(
  float* xout, float* x, float* w, int n, int d, 
  const int* block_size, float* scale,
  void* unused_aqb
) {
  // W (d,n) @ x (n,) -> xout (d,)
  (void)unused_aqb;
  static float one = 1.0f;
  int dummy_block_size[2] = {d, n};
  if (scale == nullptr) {
    scale = &one;
    block_size = dummy_block_size;
  }
  int scale_num_cols = (n + block_size[1] - 1) / block_size[1];
  for (int scale_i = 0; scale_i < cdiv(d, block_size[0]); scale_i++) {
    int ii;
#pragma omp parallel for private(ii)
    for (ii = 0; ii < block_size[0]; ii++) {
      int i = scale_i * block_size[0] + ii;
      if (i >= d) {
        continue;
      }
      float val = 0.0f;
      for (int scale_j = 0; scale_j < cdiv(n, block_size[1]); scale_j++) {
        float scale_val = scale[scale_i * scale_num_cols + scale_j];
        for (int jj = 0; jj < block_size[1]; jj++) {
          int j = scale_j * block_size[1] + jj;
          if (j >= n) {
            break;
          }
          val += (w[i * n + j] * x[j]) * scale_val;
        }
      }
      xout[i] = val;
    }
  }
}
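// Worked example of the blockwise-scale indexing above (illustrative): with
// d = 4, n = 8, and block_size = {2, 4}, `scale` is a 2x2 row-major grid and
// w[i][j] is multiplied by scale[(i / 2) * 2 + (j / 4)], i.e. one scale per
// 2x4 tile of the weight matrix. With scale == nullptr the whole matrix is a
// single tile with scale 1.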

// matmul supporting float16 weights via the F16C extension, which allows
// conversion into float32 values before calculations.
static void _matmul(
  float* xout, float* x, f16_t* w, int n, int d, 
  const int* block_size, float* scale,
  void* unused_aqb
) {
  (void)unused_aqb;
#if defined(__AVX2__) && defined(__F16C__)
  // W (d,n) @ x (n,) -> xout (d,)
  assert(n % 16 == 0);
  assert(scale == nullptr || block_size[1] % 16 == 0);
  static float one = 1.0f;
  int dummy_block_size[2] = {d, n};
  if (scale == nullptr) {
    scale = &one;
    block_size = dummy_block_size;
  }
  int scale_num_cols = (n + block_size[1] - 1) / block_size[1];
  for (int scale_i = 0; scale_i < cdiv(d, block_size[0]); scale_i++) {
    int ii;
#pragma omp parallel for private(ii)
    for (ii = 0; ii < block_size[0]; ii++) {
      int i = scale_i * block_size[0] + ii;
      if (i >= d) {
        continue;
      }
      // Vectorized dot product of w[i][:] and x[:] where w is a packed float16 array.
      __m256 sumlo = _mm256_setzero_ps();
      __m256 sumhi = _mm256_setzero_ps();
      for (int scale_j = 0; scale_j < cdiv(n, block_size[1]); scale_j++) {
        // Broadcast scale_val to all elements of a vector
        float scale_val = scale[scale_i * scale_num_cols + scale_j];
        __m256 scale_vec = _mm256_set1_ps(scale_val);
        for (int jj = 0; jj < block_size[1]; jj+=16) {
          int j = scale_j * block_size[1] + jj;
          if (j >= n) {
            break;
          }
          
          // Extract the next set of 16 float16 weights from `w` and store them
          // to two separate float32 vectors of width 8 (`wveclo_ps`, `wvechi_ps`)
          __m256i wvec = _mm256_loadu_si256((__m256i*)&w[i * n + j]);
          __m128i wveclo = _mm256_extractf128_si256(wvec, 0);
          __m128i wvechi = _mm256_extractf128_si256(wvec, 1);
          __m256 wveclo_ps = _mm256_cvtph_ps(wveclo);
          __m256 wvechi_ps = _mm256_cvtph_ps(wvechi);
          
          // Scale the weight vectors
          wveclo_ps = _mm256_mul_ps(wveclo_ps, scale_vec);
          wvechi_ps = _mm256_mul_ps(wvechi_ps, scale_vec);
          
          // Extract the next two float32 vectors of width 8 `xveclo`, `xvechi` from `x`
          __m256 xveclo = _mm256_loadu_ps(&x[j]);
          __m256 xvechi = _mm256_loadu_ps(&x[j + 8]);
          
          // Compute vectorized FMAs: sumlo += wveclo * xveclo, sumhi += wvechi * xvechi
          sumlo = _mm256_fmadd_ps(wveclo_ps, xveclo, sumlo);
          sumhi = _mm256_fmadd_ps(wvechi_ps, xvechi, sumhi);
        }
      }
      // Horizontally reduce width-8 float32 vectors sumlo, sumhi to a scalar.
      __m256 sum8 = _mm256_add_ps(sumlo, sumhi);              // sum8[0:8] = sumlo[0:8] + sumhi[0:8]
      __m128 sum4 = _mm_add_ps(                               // sum4[0:4] = sum8[0:4] + sum8[4:8]
        _mm256_extractf128_ps(sum8, 0), 
        _mm256_extractf128_ps(sum8, 1)
      );
      __m128 sum1 = _mm_dp_ps(sum4, _mm_set1_ps(1.0f), 0xf1); // sum1[0] = dot(sum4, [1,1,1,1])
      xout[i] = _mm_cvtss_f32(sum1);
    }
  }
#else
  assert(false && "float16 not supported on this platform");
#endif
}

// matmul supporting float8e5m2 weights via AVX2 and F16C extensions, which (1) 
// allows vectorized conversion from f8e5m2 to float16 and (2) conversion from 
// float16 to float32 values before calculations.
static void _matmul(
  float* xout, float* x, f8e5m2_t* w, int n, int d, 
  const int* block_size, float* scale,
  void* unused_aqb
) {
  (void)unused_aqb;
#if defined(__AVX2__) && defined(__F16C__)
  // W (d,n) @ x (n,) -> xout (d,)
  assert(n % 16 == 0);
  assert(scale == nullptr || block_size[1] % 16 == 0);
  static float one = 1.0f;
  int dummy_block_size[2] = {d, n};
  if (scale == nullptr) {
    scale = &one;
    block_size = dummy_block_size;
  }
  int scale_num_cols = (n + block_size[1] - 1) / block_size[1];
  for (int scale_i = 0; scale_i < cdiv(d, block_size[0]); scale_i++) {
    int ii;
#pragma omp parallel for private(ii)
    for (ii = 0; ii < block_size[0]; ii++) {
      int i = scale_i * block_size[0] + ii;
      if (i >= d) {
        continue;
      }
      // Vectorized dot product of w[i][:] and x[:] where w is a packed float8e5m2 array.
      __m256 sumlo = _mm256_setzero_ps();
      __m256 sumhi = _mm256_setzero_ps();
      for (int scale_j = 0; scale_j < cdiv(n, block_size[1]); scale_j++) {
        // Broadcast scale_val to all elements of a vector
        float scale_val = scale[scale_i * scale_num_cols + scale_j];
        __m256 scale_vec = _mm256_set1_ps(scale_val);
        for (int jj = 0; jj < block_size[1]; jj+=16) {
          int j = scale_j * block_size[1] + jj;
          if (j >= n) {
            break;
          }

          // Extract the next set of 16 float8e5m2 weights from `w` and store them
          // to two separate float32 vectors of width 8 (`wveclo_ps`, `wvechi_ps`)
          __m128i wvec = _mm_loadu_si128((__m128i*)&w[i * n + j]);
          // Take each half of `wvec` which consists of 8 float8e5m2 weights and
          // pad each 8-bit float8e5m2 value with 8 zeros in the mantissa (least significant bits),
          // converting to 8 float16 values.
          __m128i wveclo = _mm_unpacklo_epi8(_mm_setzero_si128(), wvec);
          __m128i wvechi = _mm_unpackhi_epi8(_mm_setzero_si128(), wvec);
          // Widen each 8xf16 vector to 8xf32.
          __m256 wveclo_ps = _mm256_cvtph_ps(wveclo);
          __m256 wvechi_ps = _mm256_cvtph_ps(wvechi);
          
          // Scale the weight vectors
          wveclo_ps = _mm256_mul_ps(wveclo_ps, scale_vec);
          wvechi_ps = _mm256_mul_ps(wvechi_ps, scale_vec);
          
          // Extract the next two float32 vectors of width 8 `xveclo`, `xvechi` from `x`
          __m256 xveclo = _mm256_loadu_ps(&x[j]);
          __m256 xvechi = _mm256_loadu_ps(&x[j + 8]);
          // Compute vectorized FMAs: sumlo += wveclo * xveclo, sumhi += wvechi * xvechi
          sumlo = _mm256_fmadd_ps(wveclo_ps, xveclo, sumlo);
          sumhi = _mm256_fmadd_ps(wvechi_ps, xvechi, sumhi);
        }
      }
      // Horizontally reduce width-8 float32 vectors sumlo, sumhi to a scalar.
      __m256 sum8 = _mm256_add_ps(sumlo, sumhi);              // sum8[0:8] = sumlo[0:8] + sumhi[0:8]
      __m128 sum4 = _mm_add_ps(                               // sum4[0:4] = sum8[0:4] + sum8[4:8]
        _mm256_extractf128_ps(sum8, 0), 
        _mm256_extractf128_ps(sum8, 1)
      );
      __m128 sum1 = _mm_dp_ps(sum4, _mm_set1_ps(1.0f), 0xf1); // sum1[0] = dot(sum4, [1,1,1,1])
      xout[i] = _mm_cvtss_f32(sum1);
    }
  }
#else
  assert(false && "float8e5m2 not supported on this platform");
#endif
}
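// Scalar equivalent of the SIMD widening trick above (illustrative sketch, not
// part of the original file): interleaving a zero byte below each f8e5m2 byte
// is a left shift by 8, which appends 8 zero mantissa bits and yields the
// float16 with the same sign and exponent.
//   f16_t widen_e5m2(f8e5m2_t b) { return (f16_t)((uint16_t)b << 8); }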

static void _matmul(
  float* xout, float* x, block_q2_K* w, int n, int d, 
  const int* unused_block_size, float* unused_scale,
  void* aqb
) {
  // W (d,n) @ x (n,) -> xout (d,)
  (void)unused_block_size;
  (void)unused_scale;
  size_t blocks_per_row = n / QK_K;
  block_q8_K* aqb_q8 = (block_q8_K*)aqb;
  int chunk_size = QK_K * 2;
  int num_chunks = cdiv(n, chunk_size);
  {
    PROFILE_BLOCK("quantize_acts");
    #pragma omp parallel for
      for (int i = 0; i < num_chunks; i++) {
        int start = i * chunk_size;
        int k = (i == num_chunks - 1) ? (n - start) : chunk_size;
        if (k > 0) {
          quantize_row_q8_K_ref(x + start, aqb_q8 + (start/QK_K), k);
        }
      }
  }
  {
    PROFILE_BLOCK("matmul_w2a8");
    int i;
  #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
      ggml_vec_dot_q2_K_q8_K(n, xout + i, w + i * blocks_per_row, aqb_q8);
    }
  }
}

static void _matmul(
  float* xout, float* x, block_q3_K* w, int n, int d, 
  const int* unused_block_size, float* unused_scale,
  void* aqb
) {
  // W (d,n) @ x (n,) -> xout (d,)
  (void)unused_block_size;
  (void)unused_scale;
  size_t blocks_per_row = n / QK_K;
  block_q8_K* aqb_q8 = (block_q8_K*)aqb;
  int chunk_size = QK_K * 2;
  int num_chunks = cdiv(n, chunk_size);
  {
    PROFILE_BLOCK("quantize_acts");
    #pragma omp parallel for
      for (int i = 0; i < num_chunks; i++) {
        int start = i * chunk_size;
        int k = (i == num_chunks - 1) ? (n - start) : chunk_size;
        if (k > 0) {
          quantize_row_q8_K_ref(x + start, aqb_q8 + (start/QK_K), k);
        }
      }
  }
  {
    PROFILE_BLOCK("matmul_w3a8");
    int i;
  #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
      ggml_vec_dot_q3_K_q8_K(n, xout + i, w + i * blocks_per_row, aqb_q8);
    }
  }
}
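// Illustrative note on the two kernels above: both are weight-quantized,
// activation-quantized ("W2A8" and "W3A8") matmuls. Activations are first
// quantized to 8-bit blocks of QK_K values (block_q8_K), then each output row
// is a blockwise integer dot product against the 2- or 3-bit weight row,
// rescaled by the per-block scales inside ggml_vec_dot_q{2,3}_K_q8_K.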

static void matmul(
  float* xout, float* x, const QTensor& w,
  const int* block_size, std::optional<QTensor> scale,
  void* aqb
) {
  // W (d,n) @ x (n,) -> xout (d,)
  int n = w.shape[1];
  int d = w.shape[0];
  float* scale_data = nullptr;
  if (scale) {
    assert(scale->quant == Quant::F32);
    scale_data = static_cast<float*>(scale->data);
  }
  switch (w.quant) {
    case Quant::F32: {
      _matmul(xout, x, static_cast<float*>(w.data), n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::F16: {
      _matmul(xout, x, static_cast<f16_t*>(w.data), n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::F8E5M2: {
      _matmul(xout, x, static_cast<f8e5m2_t*>(w.data), n, d, block_size, scale_data, aqb);  
      break;
    }
    case Quant::Q2_K: {
      _matmul(xout, x, static_cast<block_q2_K*>(w.data), n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::Q3_K: {
      _matmul(xout, x, static_cast<block_q3_K*>(w.data), n, d, block_size, scale_data, aqb);
      break;
    }
    default: assert(false);
  }
}

void matmul_unscaled(float* xout, float* x, const QTensor& w) {
  matmul(xout, x, w, nullptr, std::nullopt, nullptr);
}

static void matmul_expert(
  float* xout, float* x, 
  const QTensor& w_experts, int expert_index,
  const int* block_size, std::optional<QTensor> scale_experts,
  void* aqb
) {
  // W_experts (n_experts,d,n)
  // W (d,n) @ x (n,) -> xout (d,)
  int n = w_experts.shape[2];
  int d = w_experts.shape[1];
  size_t expert_size = n * d;
  float* scale_data = nullptr;
  if (scale_experts) {
    assert(scale_experts->quant == Quant::F32);
    int expert_scale_size = cdiv(d, block_size[0]) * cdiv(n, block_size[1]);
    size_t scale_offset = expert_index * expert_scale_size;
    scale_data = static_cast<float*>(scale_experts->data) + scale_offset;
  }
  size_t weight_offset = expert_index * expert_size;
  if (is_k_quant(w_experts.quant)) {
    // In K-quants, each element of the weight tensor is a block of QK_K elements
    weight_offset = weight_offset / QK_K;
  }
  switch (w_experts.quant) {
    case Quant::F32: {
      _matmul(xout, x, static_cast<float*>(w_experts.data) + weight_offset, n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::F16: {
      _matmul(xout, x, static_cast<f16_t*>(w_experts.data) + weight_offset, n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::F8E5M2: {
      _matmul(xout, x, static_cast<f8e5m2_t*>(w_experts.data) + weight_offset, n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::Q2_K: {
      _matmul(xout, x, static_cast<block_q2_K*>(w_experts.data) + weight_offset, n, d, block_size, scale_data, aqb);
      break;
    }
    case Quant::Q3_K: {
      _matmul(xout, x, static_cast<block_q3_K*>(w_experts.data) + weight_offset, n, d, block_size, scale_data, aqb);
      break;
    }
    default: assert(false);
  }
}

// Compute the softmax of an input vector `x` of length `size` and store it in `o`.
static void softmax(float* o, float* x, int size) {
  float score_max = -FLT_MAX;
  for (int i = 0; i < size; ++i) {
    if (x[i] > score_max) {
      score_max = x[i];
    }
  }
  float score_sum = 0.0f;
  for (int i = 0; i < size; ++i) {
    o[i] = expf(x[i] - score_max);
    score_sum += o[i];
  }
  for (int i = 0; i < size; ++i) {
    o[i] /= score_sum;
  }
}

inline float sigmoid(float x) {
  return 1.0f / (1.0f + expf(-x));
}

static void moe_gate(
  float* moe_weights,
  std::optional<QTensor> moegate_bias,
  int* active_experts,
  float* x,
  int n_routed_experts,
  int n_active_routed,
  bool norm_topk_prob,
  float routed_scaling_factor,
  ScoringFunc scoring_func,
  TopKMethod topk_method,
  int n_group,
  int topk_group
) {
  // Set moe_weights[:n_active_routed] to the weights of the top K experts.
  // Set active_experts[:n_active_routed] to the indices of the top K experts.
  if (scoring_func == ScoringFunc::SOFTMAX) {
    softmax(x, x, n_routed_experts);
  } else if (scoring_func == ScoringFunc::SIGMOID) {
    for (int i = 0; i < n_routed_experts; i++) {
      x[i] = sigmoid(x[i]);
    }
  }

  if (moegate_bias) {
    float* bias_data = static_cast<float*>(moegate_bias->data);
    for (int i = 0; i < n_routed_experts; ++i) {
      x[i] += bias_data[i];
    }
  }

  // top k
  float wsum = 0.0f;
  if (topk_method == TopKMethod::GREEDY) {
    assert(n_routed_experts <= 256);
    std::array<uint8_t, 32> mask{};
    for (int k = 0; k < n_active_routed; ++k) {
      int best = -1;
      for (int j = 0; j < n_routed_experts; ++j) {
        int mask_i = j / 8;
        int mask_r = j % 8;
        if ((mask[mask_i] & (1ull << mask_r)) == 0 && (best == -1 || x[j] > x[best])) {
          best = j;
        }
      }

      active_experts[k] = best;
      wsum += x[active_experts[k]];
      int best_mask_i = best / 8;
      int best_mask_r = best % 8;
      mask[best_mask_i] |= 1ull << best_mask_r;
    }
  } else if (topk_method == TopKMethod::GROUP_LIMITED_GREEDY) {
    int group_size = n_routed_experts / n_group;
    
    // First pass: select topk_group within each group
    std::array<uint8_t, 32> mask{};
    
    for (int g = 0; g < n_group; g++) {
      // Select topk_group items from this group
      for (int k = 0; k < topk_group; k++) {
        int best = -1;
        for (int j = g*group_size; j < (g+1)*group_size; j++) {
          int mask_i = j / 8;
          int mask_r = j % 8;
          if ((mask[mask_i] & (1u << mask_r)) == 0 && (best == -1 || x[j] > x[best])) {
            best = j;
          }
        }
        int best_mask_i = best / 8;
        int best_mask_r = best % 8;
        mask[best_mask_i] |= 1u << best_mask_r;
      }
    }
    // Flip mask so that now we only look at the topk_group items in each group
    for (int i = 0; i < 32; i++) {
      mask[i] = ~mask[i];
    }
    
    // Second pass: select top n_active_routed overall
    for (int k = 0; k < n_active_routed; ++k) {
      int best = -1;
      for (int j = 0; j < n_routed_experts; ++j) {
        int mask_i = j / 8;
        int mask_r = j % 8;
        if ((mask[mask_i] & (1ull << mask_r)) == 0 && (best == -1 || x[j] > x[best])) {
          best = j;
        }
      }

      active_experts[k] = best;
      wsum += x[active_experts[k]];
      int best_mask_i = best / 8;
      int best_mask_r = best % 8;
      mask[best_mask_i] |= 1ull << best_mask_r;
    }
  } else if (topk_method == TopKMethod::NOAUX_TC) {
    assert(false && "TODO: implement noaux_tc");
  }

  if (!norm_topk_prob) {
    wsum = 1.0;
  }
  for (int k = 0; k < n_active_routed; ++k) {
    moe_weights[k] = x[active_experts[k]] / wsum * routed_scaling_factor;
  }
}
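// Illustrative note on the selection masks above: with n_routed_experts
// capped at 256, a 32-byte bitmask tracks already-chosen experts; expert j
// maps to bit (j % 8) of byte (j / 8), so e.g. expert 10 occupies byte 1,
// bit 2. GROUP_LIMITED_GREEDY inverts the mask between passes so the second
// top-k pass only considers the per-group winners of the first.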

static void rmsnorm(float* o, float* x, float* weight, int size, float eps) {
  float rms = 0.0f;
  for (int i = 0; i < size; ++i) {
    rms += x[i] * x[i];
  }
  rms = sqrtf(rms / size + eps);
  float scale = 1.0f / rms;
  for (int i = 0; i < size; ++i) {
    o[i] = x[i] * scale * weight[i];
  }
}

[[maybe_unused]] static void layernorm(float* o, float* x, float* weight, float* bias, int size, float eps) {
  float mean = 0.0f;
  for (int i = 0; i < size; ++i) {
    mean += x[i];
  }
  mean /= size;
  float var = 0.0f;
  for (int i = 0; i < size; ++i) {
    var += (x[i] - mean) * (x[i] - mean);
  }
  var /= size;
  float scale = 1.0f / sqrtf(var + eps);
  if (bias) {
    for (int i = 0; i < size; ++i) {
      o[i] = (x[i] - mean) * scale * weight[i] + bias[i];
    }
  } else {
    for (int i = 0; i < size; ++i) {
      o[i] = (x[i] - mean) * scale * weight[i];
    }
  }
}

inline float gelu(float x) {
  return 0.5f * x * (1.0f + tanhf(0.797885f * (x + 0.044715f * x * x * x)));
}

inline float silu(float x) {
  return x / (1.0f + expf(-x));
}

inline float clip(float x, float v) {
  return x < -v ? -v : (x > v ? v : x);
}

static void rope(float* buf, float* vec, int d, int head_dim, int pos, float theta) {
  // For some reason, DeepSeek-V2 was trained with a RoPE output layout that is
  // transposed relative to the input, so we need a scratch buffer to hold
  // intermediate results.
  assert(d % 2 == 0);
  for (int i = 0; i < d; i += 2) {
    int j_head = i % head_dim;
    float freq = 1.0f / powf(theta, (float)j_head / (float)head_dim);
    float val = pos * freq;
    float fcr = cosf(val);
    float fci = sinf(val);

    float v0 = vec[i];
    float v1 = vec[i + 1];
    buf[i/2] = v0 * fcr - v1 * fci;
    buf[i/2 + d/2] = v0 * fci + v1 * fcr;
  }
  for (int i = 0; i < d; i++) {
    vec[i] = buf[i];
  }
}
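// Worked example of the transposed output layout above (illustrative): for
// d = 4, the input pairs (v0, v1) and (v2, v3) are rotated as usual, but the
// results are written as buf = [r0.re, r1.re, r0.im, r1.im] rather than the
// interleaved [r0.re, r0.im, r1.re, r1.im] that rope_v3 below produces: real
// parts fill the first half of the vector and imaginary parts the second.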

static void rope_v3(float* vec, int d, int head_dim, int pos, float theta) {
  int rotary_dim = head_dim;

  for (int i = 0; i < d; i += 2) {
    int j_head = i % head_dim;
    float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
    float val = pos * freq;
    float fcr = cosf(val);
    float fci = sinf(val);

    float v0 = vec[i];
    float v1 = vec[i + 1];
    vec[i] = v0 * fcr - v1 * fci;
    vec[i + 1] = v0 * fci + v1 * fcr;
  }
}

static void rope(float* buf, f16_t* vec, int d, int head_dim, int pos, float theta) {
  // For some reason, DeepSeek-V2 was trained with a RoPE output layout that is
  // transposed relative to the input, so we need a scratch buffer to hold
  // intermediate results.
  assert(d % 2 == 0);
  for (int i = 0; i < d; i += 2) {
    int j_head = i % head_dim;
    float freq = 1.0f / powf(theta, (float)j_head / (float)head_dim);
    float val = pos * freq;
    float fcr = cosf(val);
    float fci = sinf(val);

    float v0 = half_to_float(vec[i]);
    float v1 = half_to_float(vec[i + 1]);
    buf[i/2] = v0 * fcr - v1 * fci;
    buf[i/2 + d/2] = v0 * fci + v1 * fcr;
  }
  for (int i = 0; i < d; i++) {
    vec[i] = float_to_half(buf[i]);
  }
}

static void rope_v3(f16_t* vec, int d, int head_dim, int pos, float theta) {
  int rotary_dim = head_dim;

  for (int i = 0; i < d; i += 2) {
    int j_head = i % head_dim;
    float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
    float val = pos * freq;
    float fcr = cosf(val);
    float fci = sinf(val);

    float v0 = half_to_float(vec[i]);
    float v1 = half_to_float(vec[i + 1]);
    vec[i] = float_to_half(v0 * fcr - v1 * fci);
    vec[i + 1] = float_to_half(v0 * fci + v1 * fcr);
  }
}


// Compute the attention output at the current timestep for a single causal self-attention head.
void attn(
  float* xout,    // (v_head_dim,) - output vector for this head
  float* atth,    // (kv_len,) - scratch space to hold attention scores of the sequence
  const float* qh,      // (head_dim,) - query vector for this head
  const f16_t* kh,      // (kv_len, n_heads, head_dim) - buffer containing key vectors of the sequence for all KV heads
  const f16_t* vh,      // (kv_len, n_heads, v_head_dim) - buffer containing value vectors of the sequence for all KV heads
  int head_dim,   // size of the "key-space"
  int v_head_dim, // size of the "value-space"
  int n_heads, // number of attention heads
  int kv_len      // number of tokens of the sequence we will attend over
) {
  int k_stride = n_heads * head_dim; // stride per token in the key cache
  // calculate attention scores as dot products of q and k
  for (int t = 0; t < kv_len; ++t) {
    float score = 0.0f;
    for (int i = 0; i < head_dim; ++i) {
      score += qh[i] * half_to_float(kh[t * k_stride + i]);
    }
    score /= sqrtf(head_dim);
    atth[t] = score;
  }

  // softmax the scores to get attention weights over [0..kv_len)
  softmax(atth, atth, kv_len);

  int v_stride = n_heads * v_head_dim; // stride per token in the value cache
  // mix values with attention weights
  for (int i = 0; i < v_head_dim; ++i) {
    float vi = 0.0f;
    for (int t = 0; t < kv_len; ++t) {
      vi += atth[t] * half_to_float(vh[t * v_stride + i]);
    }
    xout[i] = vi;
  }
}

// Compute the attention output at the current timestep for a single causal self-attention head.
// MLA variant: uses the combined latent-KV cache and PE-KV cache.
void attn_mla(
  float* xout,    // (kv_lora_rank,) - output latent vector for this head
  float* atth,    // (kv_len,) - scratch space to hold attention scores of the sequence
  const float* qh_c,    // (kv_lora_rank,) - transformed latent query vector for this head
  const float* qh_rope, // (qk_rope_head_dim,) - PE-query vector for this head
  const f16_t* compressed_kv,      // (kv_len, kv_lora_rank) - buffer containing latent vectors of the sequence
  const f16_t* k_rope,  // (kv_len, qk_rope_head_dim) - buffer containing PE key-vectors of the sequence
  int head_dim, // used for softmax scale factor
  int kv_lora_rank, // size of the "latent-space"
  int qk_rope_head_dim, // size of the "PE-space"
  int kv_len   // number of tokens of the sequence we will attend over
) {
  int kv_stride = kv_lora_rank; // stride per token in the latent buffer
  int k_rope_stride = qk_rope_head_dim; // stride per token in the PE buffer
  // calculate attention scores as dot products of q and k
  for (int t = 0; t < kv_len; ++t) {
    float score = 0.0f;
    for (int i = 0; i < kv_lora_rank; ++i) {
      score += qh_c[i] * half_to_float(compressed_kv[t * kv_stride + i]);
    }
    for (int i = 0; i < qk_rope_head_dim; ++i) {
      score += qh_rope[i] * half_to_float(k_rope[t * k_rope_stride + i]);
    }
    score /= sqrtf(head_dim);
    atth[t] = score;
  }

  // softmax the scores to get attention weights over [0..kv_len)
  softmax(atth, atth, kv_len);

  // mix latents with attention weights
  for (int i = 0; i < kv_lora_rank; ++i) {
    float vi = 0.0f;
    for (int t = 0; t < kv_len; ++t) {
      vi += atth[t] * half_to_float(compressed_kv[t * kv_stride + i]);
    }
    xout[i] = vi;
  }
}
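// Illustrative summary of the math above (not part of the original file): per
// head and per cached position t,
//   score_t = (qh_c . c_t + qh_rope . k_rope_t) / sqrt(head_dim)
// where c_t is the cached latent and k_rope_t the cached PE key. The output
// sum_t softmax(score)_t * c_t stays in latent space and is lifted back into
// value space afterwards via wv_b (the weight-absorption formulation of MLA).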

// Compute forward pass for a single block and update the inference state accordingly.
// PRECONDITIONS: 
// - `s.x()` contains the input to the block. Output will also go here.
// - Block KV cache is hydrated.
template <typename T>
void Block::_block_cpu(
  InferenceState& s,  // inference state
  int pos,            // index of the current token in the sequence
  int kv_sink,        // number of sink tokens currently in the KV cache
  int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
  int kv_len          // number of tokens in the kv cache that we will attend over
) const {
  const Config& c = *_config;

  // Attention pre-norm
  switch (c.norm_type) {
    case LayerNormType::RMSNorm: {
      rmsnorm(s.xb(), s.x(), rms_att_weight(), c.dim, c.norm_eps);
      break;
    }
  }

  // Attention output into `hb`
  attention_impl(s, pos, kv_sink, kv_pos, kv_len);

  // Residual back into `x`
  for (int i = 0; i < c.dim; ++i) {
    s.x()[i] += s.hb()[i];
  }

  // FFN pre-norm
  switch (c.norm_type) {
    case LayerNormType::RMSNorm: {
      rmsnorm(s.xb(), s.x(), rms_ffn_weight(), c.dim, c.norm_eps);
      break;
    }
  }

  if (c.n_routed_experts > 0 && moegate() != std::nullopt) {
    PROFILE_BLOCK(ffn_moe);
    // Block is a sparse MoE FFN layer
    PROFILE(matmul_unscaled(s.moe_weights(), s.xb(), *moegate()));
    moe_gate(
      s.active_experts_weights(), moegate_bias(), s.active_experts(), s.moe_weights(),
      c.n_routed_experts, c.n_active_routed, c.norm_topk_prob, c.routed_scaling_factor,
      c.scoring_func, c.topk_method, c.n_group, c.topk_group
    );
    for (int k = 0; k < c.n_active_routed; ++k) {
      int expert_index = s.active_experts()[k];
      // mix self.w2(F.silu(self.w1(x)) * self.w3(x))
      // Note this is a feedforward with a GLU, not a simple MLP.
      PROFILE(matmul_expert(s.hb(), s.xb(), *w1(), expert_index, c.block_size.data(), _s1, s.aqb()));
      PROFILE(matmul_expert(s.hb2(), s.xb(), *w3(), expert_index, c.block_size.data(), _s3, s.aqb()));
      switch (c.act) {
        case ActivationType::GELU: {
          for (int i = 0; i < c.moe_intermediate_size; ++i) {
            s.hb()[i] = gelu(s.hb()[i]) * s.hb2()[i];
          }
          break;
        }
        case ActivationType::SILU: {
          for (int i = 0; i < c.moe_intermediate_size; ++i) {
            s.hb()[i] = silu(s.hb()[i]) * s.hb2()[i];
          }
          break;
        }
      }
      PROFILE(matmul_expert(s.xb2(), s.hb(), *w2(), expert_index, c.block_size.data(), _s2, s.aqb()));
      float expert_weight = s.active_experts_weights()[k];
      for (int i = 0; i < c.dim; ++i) {
        s.x()[i] += s.xb2()[i] * expert_weight;
      }
    }
    if (c.n_shared_experts > 0) {
      // mix self.w2(F.silu(self.w1(x)) * self.w3(x))
      // Note this is a feedforward with a GLU, not a simple MLP.
      PROFILE(matmul(s.hb(), s.xb(), *shared_w1(), c.block_size.data(), _shared_s1, s.aqb()));
      PROFILE(matmul(s.hb2(), s.xb(), *shared_w3(), c.block_size.data(), _shared_s3, s.aqb()));
      switch (c.act) {
        case ActivationType::GELU: {
          for (int i = 0; i < c.n_shared_experts * c.moe_intermediate_size; ++i) {
            s.hb()[i] = gelu(s.hb()[i]) * s.hb2()[i];
          }
          break;
        }
        case ActivationType::SILU: {
          for (int i = 0; i < c.n_shared_experts * c.moe_intermediate_size; ++i) {
            s.hb()[i] = silu(s.hb()[i]) * s.hb2()[i];
          }
          break;
        }
      }

      PROFILE(matmul(s.xb2(), s.hb(), *shared_w2(), c.block_size.data(), _shared_s2, s.aqb()));
      // residual connection back into x
      for (int i = 0; i < c.dim; ++i) {
        s.x()[i] += s.xb2()[i];
      }
    }
  } else {
    PROFILE_BLOCK(ffn_dense);
    // Block is a dense FFN layer
    // mix self.w2(F.silu(self.w1(x)) * self.w3(x))
    // Note this is a feedforward with a GLU, not a simple MLP.
    PROFILE(matmul(s.hb(), s.xb(), *w1(), c.block_size.data(), _s1, s.aqb()));
    PROFILE(matmul(s.hb2(), s.xb(), *w3(), c.block_size.data(), _s3, s.aqb()));
    switch (c.act) {
      case ActivationType::GELU: {
        for (int i = 0; i < c.hidden_dim; ++i) {
          s.hb()[i] = gelu(s.hb()[i]) * s.hb2()[i];
        }
        break;
      }
      case ActivationType::SILU: {
        for (int i = 0; i < c.hidden_dim; ++i) {
          s.hb()[i] = silu(s.hb()[i]) * s.hb2()[i];
        }
        break;
      }
    }
    PROFILE(matmul(s.xb2(), s.hb(), *w2(), c.block_size.data(), _s2, s.aqb()));
    // residual connection back into x
    for (int i = 0; i < c.dim; ++i) {
      s.x()[i] += s.xb2()[i];
    }
  }
}

template<typename T>
void BlockMHA::_attention_impl(
  InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
) const {
  PROFILE_BLOCK(attn_mha);
  const Config& c = *_config;

  // qkv matmuls for this position
  if (c.q_lora_rank > 0) {
    PROFILE(matmul(s.q_a(), s.xb(), *wq_a(), c.block_size.data(), _sq_a, s.aqb()));
    switch (c.norm_type) {
      case LayerNormType::RMSNorm: {
        rmsnorm(s.q_a(), s.q_a(), this->rms_q_a_weight(), c.q_lora_rank, c.norm_eps);
        break;
      }
    }
    PROFILE(matmul(s.q(), s.q_a(), *wq_b(), c.block_size.data(), _sq_b, s.aqb()));
  } else {
    PROFILE(matmul(s.q(), s.xb(), *wq(), c.block_size.data(), _sq, s.aqb()));
  }
  PROFILE(matmul(s.kv_a(), s.xb(), *wkv_a(), c.block_size.data(), _skv_a, s.aqb()));

  // Apply RoPE positional encoding
  int q_pe_offset = c.head_dim - c.qk_rope_head_dim;
  bool is_v3 = c.has_moegate_bias; // V3-style checkpoints (detected via the MoE gate bias) apply RoPE in place without the transposed layout
  for (int h = 0; h < c.n_heads; h++) {
    if (is_v3) {
      rope_v3(s.q(h) + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
    } else {
      rope(s.ropebuf(), s.q(h) + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
    }
  }
  int kv_pe_offset = c.kv_lora_rank;
  float* k_rope = s.kv_a() + kv_pe_offset;
  if (is_v3) {
    rope_v3(k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
  } else {
    rope(s.ropebuf(), k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
  }
  // rms norm to non-pe chunk of kv_a
  rmsnorm(s.kv_a(), s.kv_a(), this->rms_kv_a_weight(), c.kv_lora_rank, c.norm_eps);
  // un-compress the latent kv via multiplication with wkv_b
  int qk_nope_head_dim = c.head_dim - c.qk_rope_head_dim;
  PROFILE(matmul(s.kv_b(), s.kv_a(), *wkv_b(), c.block_size.data(), _skv_b, s.aqb()));
  // concatenate kv_b and k_rope in each head to build key heads
  for (int h = 0; h < c.n_heads; h++) {
    for (int i = 0; i < qk_nope_head_dim; i++) {
      s.k(h)[i] = s.kv_b(h)[i];
    }
    for (int i = 0; i < c.qk_rope_head_dim; i++) {
      s.k(h)[qk_nope_head_dim + i] = k_rope[i];
    }
  }
  // transfer value heads from kv_b
  for (int h = 0; h < c.n_heads; h++) {
    for (int i = 0; i < c.v_head_dim; i++) {
      s.v(h)[i] = s.kv_b(h)[qk_nope_head_dim + i];
    }
  }

  // update kv cache
  int key_dim = c.n_heads * c.head_dim;
  for (int i = 0; i < key_dim; ++i) {
    this->key_cache(kv_pos)[i] = float_to_half(s.k()[i]);
  }
  int value_dim = c.n_heads * c.v_head_dim;
  for (int i = 0; i < value_dim; ++i) {
    this->value_cache(kv_pos)[i] = float_to_half(s.v()[i]);
  }

  // Sink tokens remain untouched while the rest of the KV cache is incrementally 
  // replaced in ring order, but sink i must always be positioned (max_seq_len - i)
  // away from current timestep. Hence, each forward pass, rotate sink tokens 
  // forward by 1. See https://arxiv.org/abs/2309.17453 for more.
  for (int r = 0; r < kv_sink; r++) {
    f16_t* key = key_cache(r);
    // in-place update PE-chunk of each key head
    int q_pe_offset = c.head_dim - c.qk_rope_head_dim;
    for (int h = 0; h < c.n_heads; h++) {
      f16_t* kh = key + h * c.head_dim;
      if (is_v3) {
        rope_v3(kh + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
      } else {
        rope(s.ropebuf(), kh + q_pe_offset, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
      }
    }
  }

  {
    PROFILE_BLOCK(self_attn_mha_inner);
    f16_t* kb = this->key_cache();
    f16_t* vb = this->value_cache();
    int h;
  #pragma omp parallel for private(h)
    for (h = 0; h < c.n_heads; h++) {
      int k_head_offset = h * c.head_dim;
      int v_head_offset = h * c.v_head_dim;
      f16_t* kh = kb + k_head_offset; // Use pointer arithmetic for base address
      f16_t* vh = vb + v_head_offset; // Use pointer arithmetic for base address
      attn(
        s.xb2(h, c.v_head_dim), // Output per Q head
        s.att(h),              // Attention scores per Q head
        s.q(h),                // Query vector for this head
        kh,                    // Pointer to start of relevant K cache base
        vh,                    // Pointer to start of relevant V cache base
        c.head_dim,            // Dimension of K space
        c.v_head_dim,          // Dimension of V space
        c.n_heads,             // Number of attention heads (attn uses this for the KV cache stride)
        kv_len                 // Sequence length to attend over
      );
    }
  }

  // final matmul to get output of the attention, place result in s.hb() for residual connection
  PROFILE(matmul(s.hb(), s.xb2(), *wo(), c.block_size.data(), _so, s.aqb()));
}

template<typename T>
void BlockMLA::_attention_impl(
  InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
) const {
  PROFILE_BLOCK(attn_mla);
  const Config& c = *_config;
  assert(c.q_lora_rank > 0); // MLA requires q_lora_rank > 0

  // qkv down projections
  PROFILE(matmul(s.q_a(), s.xb(), *wq_a(), c.block_size.data(), _sq_a, s.aqb()));
  switch (c.norm_type) {
    case LayerNormType::RMSNorm: {
      rmsnorm(s.q_a(), s.q_a(), this->rms_q_a_weight(), c.q_lora_rank, c.norm_eps);
      break;
    }
  }
  PROFILE(matmul(s.kv_a(), s.xb(), *wkv_a(), c.block_size.data(), _skv_a, s.aqb()));
  // query transformations
  PROFILE(matmul(s.q_rope(), s.q_a(), *wq_rope_b(), c.block_size.data(), _sq_rope_b, s.aqb()));
  PROFILE(matmul(s.q_c(), s.q_a(), *wc(), c.block_size.data(), _sc, s.aqb()));

  // Apply RoPE positional encoding
  bool is_v3 = c.has_moegate_bias; // V3-style checkpoints (detected via the MoE gate bias) apply RoPE in place without the transposed layout
  for (int h = 0; h < c.n_heads; h++) {
    if (is_v3) {
      rope_v3(s.q_rope(h), c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
    } else {
      rope(s.ropebuf(), s.q_rope(h), c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
    }
  }
  int kv_pe_offset = c.kv_lora_rank;
  float* k_rope = s.kv_a() + kv_pe_offset;
  if (is_v3) {
    rope_v3(k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
  } else {
    rope(s.ropebuf(), k_rope, c.qk_rope_head_dim, c.qk_rope_head_dim, pos, c.rope_theta);
  }
  // rms norm to non-pe chunk of kv_a (compressed latent kv)
  rmsnorm(s.kv_a(), s.kv_a(), this->rms_kv_a_weight(), c.kv_lora_rank, c.norm_eps);

  // update kv cache
  for (int i = 0; i < c.kv_lora_rank; ++i) {
    this->kv_nope_cache(kv_pos)[i] = float_to_half(s.kv_a()[i]);
  }
  for (int i = 0; i < c.qk_rope_head_dim; ++i) {
    this->kv_rope_cache(kv_pos)[i] = float_to_half(k_rope[i]);
  }

  // Sink tokens remain untouched while the rest of the KV cache is incrementally 
  // replaced in ring order, but sink i must always be positioned (max_seq_len - i)
  // away from current timestep. Hence, each forward pass, rotate sink tokens 
  // forward by 1. See https://arxiv.org/abs/2309.17453 for more.
  for (int r = 0; r < kv_sink; r++) {
    f16_t* kv = this->kv_rope_cache(r);
    if (is_v3) {
      rope_v3(kv, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
    } else {
      rope(s.ropebuf(), kv, c.qk_rope_head_dim, c.qk_rope_head_dim, 1, c.rope_theta);
    }
  }

  {
    PROFILE_BLOCK(self_attn_mla_inner);
    int h;
  #pragma omp parallel for private(h)
    for (h = 0; h < c.n_heads; h++) {
      attn_mla(
        s.xb2(h, c.kv_lora_rank), // Output is per-head latent vector
        s.att(h),
        s.q_c(h),
        s.q_rope(h),
        this->kv_nope_cache(),
        this->kv_rope_cache(),
        c.head_dim,
        c.kv_lora_rank,
        c.qk_rope_head_dim,
        kv_len
      );
    }
  }

  // Uncompress latent kvs output by each attention head, storing result in `kv_b`.
  // We reuse kv_b buffer here for the uncompressed value outputs.
  for (int h = 0; h < c.n_heads; h++) {
    float* v_b_head = s.kv_b() + h * c.v_head_dim;
    PROFILE(matmul_expert(v_b_head, s.xb2(h, c.kv_lora_rank), *wv_b(), h, c.block_size.data(), _sv_b, s.aqb()));
  }

  // final matmul to get output of the attention, place result in s.hb() for residual connection
  PROFILE(matmul(s.hb(), s.kv_b(), *wo(), c.block_size.data(), _so, s.aqb()));
}

void mha_cpu(
  float* xout,  // (n_heads, head_dim)
  float* att,   // (n_heads, max_seq_len)
  f16_t* kb,    // (max_seq_len, n_heads, head_dim)
  f16_t* vb,    // (max_seq_len, n_heads, v_head_dim)
  float* q,     // (n_heads, head_dim)
  int head_dim, int v_head_dim, int kv_len, int max_seq_len, int n_heads
) {
  // Multihead attention. Iterate over all heads.
  int h;
#pragma omp parallel for private(h)
  for (h = 0; h < n_heads; h++) {
    int k_head_offset = h * head_dim;
    int v_head_offset = h * v_head_dim;
    f16_t* kh = kb + k_head_offset;
    f16_t* vh = vb + v_head_offset;
    attn(
      xout + head_dim * h, att + max_seq_len * h, q + head_dim * h, 
      kh, vh, head_dim, v_head_dim, n_heads, kv_len
    );
  }
}

void ffn_cpu(
  float* xout, float* x, 
  float* w1, float* w2, float* w3, 
  int hidden_dim, int dim,
  ActivationType act
) {
  float* hb = new float[hidden_dim];
  float* hb2 = new float[hidden_dim];
  // mix self.w2(F.silu(self.w1(x)) * self.w3(x))
  // Note this is a feedforward with a GLU, not a simple MLP.
  matmul_unscaled(hb, x, QTensor(Quant::F32, {dim, hidden_dim}, w1, dim*hidden_dim*sizeof(float)));
  matmul_unscaled(hb2, x, QTensor(Quant::F32, {dim, hidden_dim}, w3, dim*hidden_dim*sizeof(float)));
  switch (act) {
    case ActivationType::GELU: {
      for (int i = 0; i < hidden_dim; ++i) {
        hb[i] = gelu(hb[i]) * hb2[i];
      }
      break;
    }
    case ActivationType::SILU: {
      for (int i = 0; i < hidden_dim; ++i) {
        hb[i] = silu(hb[i]) * hb2[i];
      }
      break;
    }
  }

  matmul_unscaled(xout, hb, QTensor(Quant::F32, {hidden_dim, dim}, w2, hidden_dim*dim*sizeof(float)));
  
  delete[] hb;
  delete[] hb2;
}
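// Illustrative usage sketch for the helper above (not part of the original
// file): for a dense GLU layer with float32 weights,
//   ffn_cpu(out, x, w1, w2, w3, hidden_dim, dim, ActivationType::SILU);
// computes out = W2 @ (silu(W1 @ x) * (W3 @ x)).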

template void Block::_block_cpu<float>(InferenceState&, int, int, int, int) const;
template void Block::_block_cpu<f16_t>(InferenceState&, int, int, int, int) const;
template void Block::_block_cpu<f8e5m2_t>(InferenceState&, int, int, int, int) const;
template void Block::_block_cpu<block_q2_K>(InferenceState&, int, int, int, int) const;
template void Block::_block_cpu<block_q3_K>(InferenceState&, int, int, int, int) const;

template void BlockMHA::_attention_impl<float>(InferenceState&, int, int, int, int) const;
template void BlockMHA::_attention_impl<f16_t>(InferenceState&, int, int, int, int) const;
template void BlockMHA::_attention_impl<f8e5m2_t>(InferenceState&, int, int, int, int) const;
template void BlockMHA::_attention_impl<block_q2_K>(InferenceState&, int, int, int, int) const;
template void BlockMHA::_attention_impl<block_q3_K>(InferenceState&, int, int, int, int) const;

template void BlockMLA::_attention_impl<float>(InferenceState&, int, int, int, int) const;
template void BlockMLA::_attention_impl<f16_t>(InferenceState&, int, int, int, int) const;
template void BlockMLA::_attention_impl<f8e5m2_t>(InferenceState&, int, int, int, int) const;
template void BlockMLA::_attention_impl<block_q2_K>(InferenceState&, int, int, int, int) const;
template void BlockMLA::_attention_impl<block_q3_K>(InferenceState&, int, int, int, int) const;

void Model::_copy_embedding(InferenceState& s, int token) {
  const Config& c = *config;
  switch (c.weight_quant) {
    case Quant::F32: {
      float* emb = static_cast<float*>(token_embedding_table->data);
      for (int i = 0; i < c.dim; ++i) {
        s.x()[i] = emb[token * c.dim + i];
      }
      break;
    }
    case Quant::F16: {
      f16_t* emb = static_cast<f16_t*>(token_embedding_table->data);
      for (int i = 0; i < c.dim; i+=1) {
        s.x()[i] = half_to_float(emb[token * c.dim + i]);
      }
      break;
    }
    case Quant::F8E5M2: {
      f8e5m2_t* emb = static_cast<f8e5m2_t*>(token_embedding_table->data);
      float* emb_scale = static_cast<float*>(token_embedding_scale->data);
      int* block_size = config->block_size.data();
      int scale_num_cols = (c.dim + block_size[1] - 1) / block_size[1];
      for (int i = 0; i < c.dim; i+=1) {
        int scale_i = token / block_size[0];
        int scale_j = i / block_size[1];
        float scale = emb_scale[scale_i * scale_num_cols + scale_j];
        s.x()[i] = float8e5m2_to_float(emb[token * c.dim + i]) * scale;
      }
      break;
    }
    case Quant::Q2_K: {
      block_q2_K* emb = static_cast<block_q2_K*>(token_embedding_table->data);
      int blocks_per_row = c.dim / QK_K;
      dequantize_row_q2_K(emb + token * blocks_per_row, s.x(), c.dim);
      break;
    }
    case Quant::Q3_K: {
      block_q3_K* emb = static_cast<block_q3_K*>(token_embedding_table->data);
      int blocks_per_row = c.dim / QK_K;
      dequantize_row_q3_K(emb + token * blocks_per_row, s.x(), c.dim);
      break;
    }
    default: {
      assert(false && "unsupported weight quantization");
    }
  }
}

void Model::_forward_cpu(InferenceState& s, int token, int pos, InferenceMode mode) {
  const Config& c = *config;

  // copy the token embedding into `x`
  PROFILE(_copy_embedding(s, token));

  // When decoding past the context length, keep the first few tokens in the KV cache
  // untouched as "attention sinks" while replacing the rest in ring order.
  // See StreamingLLM (https://arxiv.org/pdf/2309.17453) for more.
  int original_max_position = c.rs_original_max_position_embeddings;
  int kv_sink = pos >= original_max_position ? KV_SINKS : 0;
  int kv_pos = kv_sink + (pos - kv_sink) % (original_max_position - kv_sink);
  int kv_len = pos >= original_max_position ? original_max_position : pos + 1;
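  // Illustrative example of the ring-buffer indexing above, assuming
  // KV_SINKS = 2: with original_max_position = 8 and pos = 11, this gives
  // kv_sink = 2, kv_pos = 2 + (11 - 2) % 6 = 5, and kv_len = 8; the first two
  // cache slots stay pinned as sinks while slots [2..8) are reused in ring order.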

  // forward all layers in order
  for (auto b : blocks) {
    b->block(s, pos, kv_sink, kv_pos, kv_len);
  }

  if (mode == InferenceMode::HYDRATE_KV_CACHE) {
    // only hydrate the KV cache and don't compute output logits
    return;
  }

  // final layer norm
  switch (c.norm_type) {
    case LayerNormType::RMSNorm: {
      rmsnorm(s.x(), s.x(), static_cast<float*>(rms_final_weight->data), c.dim, c.norm_eps);
      break;
    }
  }

  // classifier into logits
  {
    PROFILE_BLOCK(lm_head);
    switch (c.weight_quant) {
      case Quant::F32:
      case Quant::F16: {
        matmul_unscaled(s.logits(), s.x(), *wcls);
        break;
      }
      case Quant::F8E5M2:
      case Quant::Q2_K:
      case Quant::Q3_K: {
        matmul(s.logits(), s.x(), *wcls, c.block_size.data(), scls, s.aqb());
        break;
      }
      default: {
        assert(false && "unsupported weight quantization");
      }
    }
  }
}

================================================
FILE: src/main.cpp
================================================
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdio.h>

#include "fmt/format.h"

#include "codec.h"
#include "model.h"
#include "profile.h"
#include "sampler.h"
#include "time_utils.h"
#include "tokenizer.h"

void error_usage() {
  fprintf(stderr, "Usage:   main <checkpoint_dir> [options]\n");
  fprintf(stderr, "Example: main model_weights_dir/ -i \"Q: What is the meaning of life?\"\n");
  fprintf(stderr, "Options:\n");
  fprintf(stderr, "  -h Display this help message\n");
  fprintf(stderr, "  -L Locks model weights to RAM, disabling swap. Requires sudo.\n");
  fprintf(stderr, "  -m [completion,passkey,perplexity,interactive] which mode to run in (default - completion)\n");
  fprintf(stderr, "  -T <int> sliding window context length (0 - max)\n");
  fprintf(stderr, "\n");
  fprintf(stderr, "Perplexity mode options:\n");
  fprintf(stderr, "  Choose one:\n");
  fprintf(stderr, "    -i <string> input prompt\n");
  fprintf(stderr, "    -f <filepath> input file with prompt\n");
  fprintf(stderr, "    -w use wikitext as input\n");
  fprintf(stderr, "Completion mode options:\n");
  fprintf(stderr, "  -n <int>    number of steps to run for in completion mode, default 256. 0 = max_seq_len, -1 = infinite\n");
  fprintf(stderr, "  Choose one:\n");
  fprintf(stderr, "    -i <string> input prompt\n");
  fprintf(stderr, "    -t <float> temperature (default - 1.0)\n");
  fprintf(stderr, "    -p <float> p for top-p sampling (default - 0.95)\n");
  fprintf(stderr, "    -f <filepath> input file with prompt\n");
  fprintf(stderr, "Passkey mode options:\n");
  fprintf(stderr, "  -n <int>    number of junk lines to insert (default - 250)\n");
  fprintf(stderr, "  -l <int>    passkey position (-1 - random)\n");
  exit(1);
}

void help_usage_interactive() {
  fprintf(stderr, "Usage:   <mode> [options]\n");
  fprintf(stderr, "Example: c -i \"Q: What is the meaning of life?\"\n");
  fprintf(stderr, "Modes:\n");
  fprintf(stderr, "  h Display this help message\n");
  fprintf(stderr, "  c Completion - complete a single prompt \n");
  fprintf(stderr, "  p Perplexity - compute perplexity of a single prompt \n");
  fprintf(stderr, "  k Passkey - test passkey extraction \n");
  fprintf(stderr, "\n");
  fprintf(stderr, "Perplexity mode options:\n");
  fprintf(stderr, "  Choose one:\n");
  fprintf(stderr, "    -i <string> input prompt\n");
  fprintf(stderr, "    -f <filepath> input file with prompt\n");
  fprintf(stderr, "    -w use wikitext as input\n");
  fprintf(stderr, "Completion mode options:\n");
  fprintf(stderr, "  -n <int>    number of steps to run for in completion mode, default 256. 0 = max_seq_len, -1 = infinite\n");
  fprintf(stderr, "  Choose one:\n");
  fprintf(stderr, "    -i <string> input prompt\n");
  fprintf(stderr, "    -t <float> temperature (default - 1.0)\n");
  fprintf(stderr, "    -p <float> p for top-p sampling (default - 0.95)\n");
  fprintf(stderr, "    -f <filepath> input file with prompt\n");
  fprintf(stderr, "Passkey mode options:\n");
  fprintf(stderr, "  -n <int>    number of junk lines to insert (default - 250)\n");
  fprintf(stderr, "  -l <int>    passkey position (-1 - random)\n");
}

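// Bundles everything needed to serve one model: the mmap'd checkpoint data,
// the model itself, reusable inference activations, the sampler, and the
// tokenizer. Member order matters: later members are constructed from
// earlier ones in the initializer list.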
struct Session {
  Session(const std::string& checkpoint_dir, bool lock_model_weights, int context, uint64_t sampler_seed):
    model_data(checkpoint_dir, lock_model_weights),
    model(model_data, context),
    state(model.config),
    sampler(model.config, sampler_seed),
    tokenizer(model_data) {}
  YALMData model_data;
  Model model;
  InferenceState state;
  Sampler sampler;
  Tokenizer tokenizer;
};

struct CompletionArgs {
  std::string prompt;
  int num_steps = 256; // default per usage text; 0 = max_seq_len, -1 = infinite
  float temperature = 1.0;
  float top_p = 0.95;
  // Returns true if args are valid, false otherwise
  bool parse_args(const std::vector<const char*>& args) {
    std::string prompt_path = "";
    for (size_t i = 0; i < args.size();) {
      // do some basic validation
      if (args[i][0] != '-') {
        return false;
      } // must start with dash
      if (strlen(args[i]) != 2) {
        return false;
      } // must be -x (one dash, one letter)

      // read in the args
      if (args[i][1] == 'h') {
        return false;
      } else if (args[i][1] == 'i') {
        if (i + 1 >= args.size()) {
          return false;
        }
        prompt = args[i + 1];
        i += 2;
      } else if (args[i][1] == 't') {
        if (i + 1 >= args.size()) {
          return false;
        }
        temperature = std::stof(args[i + 1]);
        i += 2;
      } else if (args[i][1] == 'p') {
        if (i + 1 >= args.size()) {
          return false;
        }
        top_p = std::stof(args[i + 1]);
        i += 2;
      } else if (args[i][1] == 'f') {
        if (i + 1 >= args.size()) {
          return false;
        }
        prompt_path = args[i + 1];
        i += 2;
      } else if (args[i][1] == 'n') {
        if (i + 1 >= args.size()) {
          return false;
        }
        num_steps = std::stoi(args[i + 1]);
        i += 2;
      } else {
        return false;
      }
    }
    int has_prompt = prompt.size() > 0 ? 1 : 0;
    int has_prompt_path = prompt_path.size() > 0 ? 1 : 0;
    if ((has_prompt + has_prompt_path) != 1) {
      return false;
    } else if (has_prompt_path) {
      std::ifstream file(prompt_path);
      if (!file.is_open()) {
        std::cerr << "Error: could not open file " << prompt_path << std::endl;
        return false;
      }

      std::stringstream buffer;
      buffer << file.rdbuf();
      prompt = buffer.str();
    }
    return true;
  }
};

struct PasskeyArgs {
  int n_junk = 250;     // documented default
  int passkey_pos = -1; // -1 = random
  // Returns true if args are valid, false otherwise
  bool parse_args(const std::vector<const char*>& args) {
    // Like the other parsers, `args` holds only this mode's options, so start at 0.
    for (size_t i = 0; i < args.size();) {
      // do some basic validation
      if (args[i][0] != '-') {
        return false;
      } // must start with dash
      if (strlen(args[i]) != 2) {
        return false;
      } // must be -x (one dash, one letter)

      // read in the args
      if (args[i][1] == 'h') {
        return false;
      } else if (args[i][1] == 'l') {
        if (i + 1 >= args.size()) {
          return false;
        }
        passkey_pos = std::stoi(args[i + 1]);
        i += 2;
      } else if (args[i][1] == 'n') {
        if (i + 1 >= args.size()) {
          return false;
        }
        n_junk = std::stoi(args[i + 1]);
        i += 2;
      } else {
        return false;
      }
    }
    if (passkey_pos != -1 && (passkey_pos >= n_junk || passkey_pos < 0)) {
      std::cerr << "Error: passkey position must be between 0 and " << n_junk - 1 << std::endl;
      return false;
    }
    return true;
  }
};

struct PerplexityArgs {
  std::string prompt;
  bool use_wikitext = false;
  // Returns true if args are valid, false otherwise
  bool parse_args(const std::vector<const char*>& args) {
    std::string prompt_path = "";
    for (size_t i = 0; i < args.size();) {
      // do some basic validation
      if (args[i][0] != '-') {
        return false;
      } // must start with dash
      if (strlen(args[i]) != 2) {
        return false;
      } // must be -x (one dash, one letter)

      // read in the args
      if (args[i][1] == 'h') {
        return false;
      } else if (args[i][1] == 'i') {
        if (i + 1 >= args.size()) {
          return false;
        }
        prompt = args[i + 1];
        i += 2;
      } else if (args[i][1] == 'f') {
        if (i + 1 >= args.size()) {
          return false;
        }
        prompt_path = args[i + 1];
        i += 2;
      } else if (args[i][1] == 'w') {
        use_wikitext = true;
        i += 1;
      } else {
        return false;
      }
    }
    int has_prompt = prompt.size() > 0 ? 1 : 0;
    int has_prompt_path = prompt_path.size() > 0 ? 1 : 0;
    int has_wikitext = use_wikitext ? 1 : 0;
    if ((has_prompt + has_prompt_path + has_wikitext) != 1) {
      std::cerr << "Error: must provide exactly one nonempty -i <input prompt> or -f <input filepath> or -w" << std::endl;
      return false;
    } else if (has_prompt_path) {
      std::ifstream file(prompt_path);
      if (!file.is_open()) {
        std::cerr << "Error: could not open file " << prompt_path << std::endl;
        return false;
      }

      std::stringstream buffer;
      buffer << file.rdbuf();
      prompt = buffer.str();
    }
    return true;
  }
};

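// Tokenize `prompt` and print the token dump plus encoding throughput stats.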
std::vector<int> encode_prompt(const std::string& prompt, Tokenizer& tokenizer) {
  std::vector<int> encoding;
  {
    uint64_t encode_start_ms = get_timestamp_ms();
    encoding = tokenizer.encode(prompt, true);
    uint64_t encode_end_ms = get_timestamp_ms();

    std::cout << tokenizer.encoding_to_debug_string(encoding) << std::endl;
    uint64_t encoding_ms = encode_end_ms - encode_start_ms;
    std::cout << fmt::format(
      "Encoding stats: ({} tokens, throughput: {:.5}tok/s, latency: {:.5}s/tok, total: {:.5}s)\n",
      encoding.size(),
      encoding.size() / (encoding_ms / 1000.0),
      (encoding_ms / 1000.0) / encoding.size(),
      encoding_ms / 1000.0
    ) << std::endl;
  }
  return encoding;
}

void run_completion(
  Session& session,
  const std::string& prompt,
  int num_steps,
  float temperature,
  float top_p
) {
  Model& model = session.model;
  InferenceState& state = session.state;
  Sampler& sampler = session.sampler;
  Tokenizer& tokenizer = session.tokenizer;

  std::cout << "Model active bytes with full context window: " << model.active_bytes(model.config->max_seq_len) << std::endl;
  std::cout << "Model active bytes with no context: " << model.active_bytes(0) << std::endl;

  if (num_steps == 0) {
    // `-n 0` means use the full context length
    num_steps = model.config->max_seq_len;
  }

  {
    ProfileDisabledScope profile_disabled;
    std::cout << "Running warmup..." << std::endl;
    // Do one inference as warmup.
    // On CPU, this ensures all tensors are loaded into memory via mmap.
    model.forward(state, 0, 0);
    std::cout << "Warmup complete" << std::endl;
  }

  std::vector<int> encoding = encode_prompt(prompt, tokenizer);

  uint64_t start_ms = get_timestamp_ms();
  size_t read_bytes = 0;
  // Hydrate KV cache by forwarding model on all prompt tokens and discarding output.
  // This also generates output logits for the last token.
  for (size_t pos = 0; pos < encoding.size(); pos++) {
    ProfileScope scope(fmt::format("fwd_pos_{:03d}_hydrate", pos));
    int token_id = encoding[pos];
    InferenceMode inferMode = pos + 1 == encoding.size() ? 
      InferenceMode::OUTPUT_LOGITS : InferenceMode::HYDRATE_KV_CACHE;
    model.forward(state, token_id, pos, inferMode);
    read_bytes += model.active_bytes(pos);
  }
  uint64_t end_hydrate_ms = get_timestamp_ms();
  // For N steps:
  // - Sample + decode output logits
  // - Forward the model
  for (int i = 0; i < num_steps || num_steps == -1; i++) {
    int token_id = sampler.sample(state, temperature, top_p);
    std::string token_str = tokenizer.decode_one(encoding.back(), token_id);
    std::cout << token_str << std::flush;
    encoding.push_back(token_id);
    if (token_id == tokenizer.eos_id || token_id == tokenizer.eot_id) {
      break;
    }
    ProfileScope scope(fmt::format("fwd_pos_{:03d}_decode", encoding.size() - 1));
    model.forward(state, token_id, encoding.size() - 1);
    read_bytes += model.active_bytes(encoding.size() - 1);
  }
  std::cout << "\n" << std::endl;
  uint64_t end_ms = get_timestamp_ms();
  double elapsed_s = (end_ms - start_ms) / 1000.0;
  std::cout << fmt::format(
    "Generation stats:\n"
    "  {} tokens\n"
    "  throughput: {:.5}tok/s\n"
    "  latency: {:.5}s/tok\n"
    "  hydrate: {:.5}s\n"
    "  bandwidth: {:.5}GB/s\n"
    "  total: {:.5}s\n",
    encoding.size(),
    encoding.size() / elapsed_s,
    elapsed_s / encoding.size(),
    (end_hydrate_ms - start_ms) / 1000.0,
    ((double)read_bytes / 1e9) / elapsed_s,
    elapsed_s
  ) << std::endl;

#if PROFILE_ENABLED
  std::cout << "Profile total times (sec): " << std::endl;
  for (const auto& [key, value] : profile_times()) {
    std::cout << key << ": " << value << std::endl;
  }
#endif
}

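// Pre-tokenized wikitext chunks checked in as comma-separated token id lists
// and pulled in via #include, so `-w` perplexity runs need no tokenizer pass.
// Separate v2/v3 encodings exist because the two model families tokenize
// text differently.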
std::vector<int> V2_ENCODED_WIKITEXT = {
  #include "wikitest.cat.1chunk.v2-encoded.txt"
};

std::vector<int> V3_ENCODED_WIKITEXT = {
  #include "wikitest.cat.1chunk.v3-encoded.txt"
};

void run_perplexity(
  Session& session,
  const std::vector<int>& encoding
) {
  Model& model = session.model;
  InferenceState& state = session.state;
  Sampler& sampler = session.sampler;

  std::cout << "Model active bytes with full context window: " << model.active_bytes(model.config->max_seq_len) << std::endl;

  {
    ProfileDisabledScope profile_disabled;
    std::cout << "Running warmup..." << std::endl;
    // Do one inference as warmup.
    // On CPU, this ensures all tensors are loaded into memory via mmap.
    model.forward(state, 0, 0);
    std::cout << "Warmup complete" << std::endl;
  }

  double sum_logprob = 0.0;
  double ss_logprob = 0.0;
  // Generate output logits for each prompt token and sum the log probs of the
  // observed next tokens to compute perplexity.
  uint64_t start_ms = get_timestamp_ms();
  size_t read_bytes = 0;
  size_t N = encoding.size() - 1;
  for (size_t pos = 0; pos + 1 < encoding.size(); pos++) {
    std::cout << "\r Computing perplexity..." << pos + 1 << "/" << N << std::flush;
    
    int token_id = encoding[pos];
    model.forward(state, token_id, pos);
    read_bytes += model.active_bytes(pos);

    double logprob = std::log(sampler.sample_prob(encoding[pos + 1], state));
    sum_logprob += logprob;
    ss_logprob += logprob * logprob;
  }
  std::cout << std::endl;
  uint64_t end_ms = get_timestamp_ms();
  double elapsed_s = (end_ms - start_ms)/1000.0;
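  // Perplexity is exp(-mean logprob). The error bar is the standard error of
  // the mean logprob propagated through exp(-x) (delta method): with sample
  // variance var = (ss_logprob - sum_logprob^2 / N) / N, the stderr of the
  // mean is sqrt(var / N), and since d/dx exp(-x) = -exp(-x), the perplexity
  // error is approximately perplexity * sqrt(var / N) — the expression below.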
  double perplexity = std::exp(-sum_logprob / N);
  double perplexity_error = perplexity * std::sqrt(
    (ss_logprob - sum_logprob * sum_logprob / N) / N / N
  );
  std::cout << fmt::format(
    "Stats:\n"
    "  {} tokens\n"
    "  perplexity: {:.5} ± {:.5}\n"
    "  throughput: {:.5}tok/s\n"
    "  latency: {:.5}s/tok\n"
    "  bandwidth: {:.5}GB/s\n"
    "  total: {:.5}s\n",
    N,
    perplexity,
    perplexity_error,
    N / elapsed_s,
    elapsed_s / N,
    ((double)read_bytes / 1e9) / elapsed_s,
    elapsed_s
  ) << std::endl;
}

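// Needle-in-a-haystack check: hide a random passkey at one position among
// n_junk repetitions of filler text, hydrate the KV cache on the whole prompt,
// then see whether sampling reproduces the passkey within a few decode steps.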
void run_passkey(
  Session& session,
  const int n_junk,
  const int passkey_pos
) {
  Model& model = session.model;
  InferenceState& state = session.state;
  Sampler& sampler = session.sampler;
  Tokenizer& tokenizer = session.tokenizer;

  std::cout << "Model active bytes with full context window: " << model.active_bytes(model.config->max_seq_len) << std::endl;

  const std::string PROMPT_PREFIX = 
    "There is an important info hidden inside a lot of irrelevant text. "
    "Find it and memorize them. I will quiz you about the important information there.";
  const std::string PROMPT_SUFFIX = " What is the pass key? The pass key is";

  const int passkey = std::rand() % 50000 + 1;
  const int pos = passkey_pos == -1 ? std::rand() % n_junk : passkey_pos;

  std::string prompt = PROMPT_PREFIX;
  for (int i = 0; i < n_junk; i++) {
    if (i % n_junk == pos) {
      prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
    }
    prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
  }
  prompt += PROMPT_SUFFIX;

  std::vector<int> encoding;
  {
    uint64_t encode_start_ms = get_timestamp_ms();
    encoding = tokenizer.encode(prompt, true);
    uint64_t encode_end_ms = get_timestamp_ms();

    uint64_t encoding_ms = encode_end_ms - encode_start_ms;
    std::cout << fmt::format(
      "Encoding stats: ({} tokens, throughput: {:.5}tok/s, latency: {:.5}s/tok, total: {:.5}s)\n",
      encoding.size(),
      encoding.size() / (encoding_ms / 1000.0),
      (encoding_ms / 1000.0) / encoding.size(),
      encoding_ms / 1000.0
    ) << std::endl;
  }

  // Allow max 16 steps to generate passkey
  const size_t MAX_GENERATION_STEPS = 16;

  std::cout << fmt::format(
    "Passkey test:\n"
    "  prompt: {} tokens\n"
    "  passkey: {}\n"
    "  passkey token index: ~{}\n",
    encoding.size(),
    passkey,
    (int)(((float)pos) / n_junk * encoding.size())
  ) << std::endl;

  size_t N = encoding.size();
  for (size_t pos = 0; pos < N; pos++) {
    std::cout << "\r Running passkey test..." << pos + 1 << "/" << N << std::flush;
    int token_id = encoding[pos];
    InferenceMode inferMode = pos + 1 == N ? 
      InferenceMode::OUTPUT_LOGITS : InferenceMode::HYDRATE_KV_CACHE;
    model.forward(state, token_id, pos, inferMode);
  }
  std::cout << std::endl;
  std::cout << PROMPT_SUFFIX << std::flush;
  for (size_t pos = N; pos < N + MAX_GENERATION_STEPS; pos++) {
    int token_id = sampler.sample(state);
    std::string token_str = tokenizer.decode_one(encoding.back(), token_id);
    std::cout << token_str << std::flush;
    encoding.push_back(token_id);
    if (token_id == tokenizer.eos_id || token_id == tokenizer.eot_id) {
      break;
    }
    model.forward(state, token_id, pos);
  }
  std::cout << std::endl;
}

void run_interactive(Session& session) {
  std::string input = "";
  while (true) {
    std::cout << "> " << std::flush;
    std::getline(std::cin, input);
    if (input == "exit") {
      break;
    }
    // Split the input on spaces, keeping double-quoted strings (which may
    // contain spaces) together as a single argument.
    std::vector<std::string> arg_strs;
    std::stringstream ss(input);
    std::string arg;
    while (ss >> arg) {
      if (!arg_strs.empty() && arg_strs.back().starts_with("\"") && !arg_strs.back().ends_with("\"")) {
        // Continuation of an open quoted string
        arg_strs.back() += " " + arg;
        if (arg.ends_with("\"")) {
          // Closing quote reached; strip the surrounding quotes
          arg_strs.back() = arg_strs.back().substr(1, arg_strs.back().size() - 2);
        }
      } else {
        if (arg.size() >= 2 && arg.starts_with("\"") && arg.ends_with("\"")) {
          // Single-word quoted string, e.g. "hello"; strip the quotes
          arg = arg.substr(1, arg.size() - 2);
        }
        arg_strs.push_back(arg);
      }
    }
    if (arg_strs.size() == 0) {
      help_usage_interactive();
      continue;
    }
    std::string mode = arg_strs[0];
    if (std::string("completion").starts_with(mode)) {
      mode = "completion";
    } else if (std::string("passkey").starts_with(mode) && mode != "p") {
      mode = "passkey";
    } else if (std::string("perplexity").starts_with(mode) && mode != "p") {
      mode = "perplexity";
    } else if (std::string("interactive").starts_with(mode)) {
      mode = "interactive";
    } else {
      help_usage_interactive();
      continue;
    }
    std::vector<const char*> args;
    for (size_t i = 1; i < arg_strs.size(); i++) {
      args.push_back(arg_strs[i].c_str());
    }
    if (mode == "completion") {
      CompletionArgs completion_args;
      if (!completion_args.parse_args(args)) {
        help_usage_interactive();
        continue;
      }
      run_completion(session, completion_args.prompt, completion_args.num_steps, completion_args.temperature, completion_args.top_p);
    } else if (mode == "passkey") {
      PasskeyArgs passkey_args;
      if (!passkey_args.parse_args(args)) {
        help_usage_interactive();
        continue;
      }
      run_passkey(session, passkey_args.n_junk, passkey_args.passkey_pos);
    } else if (mode == "perplexity") {
      PerplexityArgs perplexity_args;
      if (!perplexity_args.parse_args(args)) {
        help_usage_interactive();
        continue;
      }
      std::vector<int> encoding;
      if (perplexity_args.use_wikitext) {
        if (session.model_data.metadata.at("arch").get<std::string>() == "DeepseekV3ForCausalLM") {
          encoding = V3_ENCODED_WIKITEXT;
        } else {
          encoding = V2_ENCODED_WIKITEXT;
        }
      } else {
        encoding = encode_prompt(perplexity_args.prompt, session.tokenizer);
      }
      run_perplexity(session, encoding);
    }
  }
}

int main(int argc, char* argv[]) {
  std::vector<const char*> args(argv, argv + argc);
  std::vector<const char*> next_args;

  std::string checkpoint_dir = "";    // e.g. model_weights_dir/
  // Options
  std::string mode = "completion";     // completion, passkey, perplexity, or interactive
  int context = 0;
  bool lock_model_weights = false;

  if (args.size() >= 2) {
    checkpoint_dir = args[1];
  } else {
    error_usage();
  }

  // read in session args first, put everything else in next_args
  for (size_t i = 2; i < args.size();) {
    if (args[i][0] == '-' && strlen(args[i]) == 2) {
      if (args[i][1] == 'h') {
        error_usage();
      } else if (args[i][1] == 'L') {
        lock_model_weights = true;
        i += 1;
      } else if (args[i][1] == 'm') {
        if (i + 1 >= args.size()) {
          error_usage();
        }
        mode = args[i + 1];
        if (std::string("completion").starts_with(mode)) {
          mode = "completion";
        } else if (std::string("passkey").starts_with(mode) && mode != "p") {
          mode = "passkey";
        } else if (std::string("perplexity").starts_with(mode) && mode != "p") {
          mode = "perplexity";
        } else if (std::string("interactive").starts_with(mode)) {
          mode = "interactive";
        } else {
          error_usage();
        }
        i += 2;
      } else if (args[i][1] == 'T') {
        if (i + 1 >= args.size()) {
          error_usage();
        }
        context = std::stoi(args[i + 1]);
        i += 2;
      } else {
        next_args.push_back(args[i]);
        i += 1;
      }
    } else {
      next_args.push_back(args[i]);
      i += 1;
    }
  }

  if (mode == "completion") {
    CompletionArgs completion_args;
    if (!completion_args.parse_args(next_args)) {
      error_usage();
    }
    Session session(checkpoint_dir, lock_model_weights, context, get_timestamp_ms());
    run_completion(session, completion_args.prompt, completion_args.num_steps, completion_args.temperature, completion_args.top_p);
  } else if (mode == "passkey") {
    PasskeyArgs passkey_args;
    if (!passkey_args.parse_args(next_args)) {
      error_usage();
    }
    Session session(checkpoint_dir, lock_model_weights, context, get_timestamp_ms());
    run_passkey(session, passkey_args.n_junk, passkey_args.passkey_pos);
  } else if (mode == "perplexity") {
    PerplexityArgs perplexity_args;
    if (!perplexity_args.parse_args(next_args)) {
      error_usage();
    }
    Session session(checkpoint_dir, lock_model_weights, context, get_timestamp_ms());
    std::vector<int> encoding;
    if (perplexity_args.use_wikitext) {
      if (session.model_data.metadata.at("arch").get<std::string>() == "DeepseekV3ForCausalLM") {
        encoding = V3_ENCODED_WIKITEXT;
      } else {
        encoding = V2_ENCODED_WIKITEXT;
      }
    } else {
      encoding = encode_prompt(perplexity_args.prompt, session.tokenizer);
    }
    run_perplexity(session, encoding);
  } else if (mode == "interactive") {
    if (next_args.size() != 0) {
      error_usage();
    }
    Session session(checkpoint_dir, lock_model_weights, context, get_timestamp_ms());
    run_interactive(session);
  }

  return 0;
}

================================================
FILE: src/model.cpp
================================================
#include "model.h"

#include "json.hpp"
#include <algorithm>
#include <array>
#include <cfloat>
#include "fmt/format.h"
#include <iostream>
#include <limits.h>
#include <string>

#include "immintrin.h"

#include "quant.h"

using json = nlohmann::json;

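// Ceiling division, e.g. cdiv(10, 4) == 3; used to size per-block quantization
// scale tensors whose dimensions may not divide the block size evenly.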
int cdiv(int a, int b) {
  return (a + b - 1) / b;
}

void Config::from_yalm(YALMData& yalm, int context) {
  dim = std::stoi(yalm.metadata.at("dim").get<std::string>());
  hidden_dim = std::stoi(yalm.metadata.at("hidden_dim").get<std::string>());
  n_layers = std::stoi(yalm.metadata.at("n_layers").get<std::string>());
  n_heads = std::stoi(yalm.metadata.at("n_heads").get<std::string>());
  vocab_size = std::stoi(yalm.metadata.at("vocab_size").get<std::string>());
  // mixture of experts
  n_shared_experts = yalm.metadata.contains("n_shared_experts") ? std::stoi(yalm.metadata.at("n_shared_experts").get<std::string>()) : 0;
  n_routed_experts = yalm.metadata.contains("n_routed_experts") ? std::stoi(yalm.metadata.at("n_routed_experts").get<std::string>()) : 0;
  n_active_routed = yalm.metadata.contains("n_active_routed") ? std::stoi(yalm.metadata.at("n_active_routed").get<std::string>()) : 0;
  moe_intermediate_size = yalm.metadata.contains("moe_intermediate_size") ? std::stoi(yalm.metadata.at("moe_intermediate_size").get<std::string>()) : 0;
  routed_scaling_factor = yalm.metadata.contains("routed_scaling_factor") ? std::stof(yalm.metadata.at("routed_scaling_factor").get<std::string>()) : 1.0;
  n_group = yalm.metadata.contains("n_group") ? std::stoi(yalm.metadata.at("n_group").get<std::string>()) : 1;
  norm_topk_prob = yalm.metadata.contains("norm_topk_prob") ? yalm.metadata.at("norm_topk_prob").get<std::string>() == "True" : false;
  std::string scoring_func_str = yalm.metadata.value("scoring_func", "softmax");
  if (scoring_func_str == "softmax") {
    scoring_func = ScoringFunc::SOFTMAX;
  } else if (scoring_func_str == "sigmoid") {
    scoring_func = ScoringFunc::SIGMOID;
  } else {
    std::cerr << "unsupported scoring_func '" << scoring_func_str << "', defaulting to softmax" << std::endl;
    scoring_func = ScoringFunc::SOFTMAX;
  }
  topk_group = yalm.metadata.contains("topk_group") ? std::stoi(yalm.metadata.at("topk_group").get<std::string>()) : 0;
  std::string topk_method_str = yalm.metadata.value("topk_method", "");
  if (topk_method_str == "greedy") {
    topk_method = TopKMethod::GREEDY;
  } else if (topk_method_str == "group_limited_greedy") {
    topk_method = TopKMethod::GROUP_LIMITED_GREEDY;
  } else if (topk_method_str == "noaux_tc") {
    topk_method = TopKMethod::NOAUX_TC;
    assert(false && "TODO: support for Deepseek v3");
  } else {
    std::cerr << "unsupported topk_method '" << topk_method_str << "', defaulting to greedy" << std::endl;
    topk_method = TopKMethod::GREEDY;
  }
  has_moegate_bias = yalm.metadata.at("arch").get<std::string>() == "DeepseekV3ForCausalLM";
  // multi-latent attention
  use_mla = yalm.metadata.contains("use_mla") ? 
    static_cast<bool>(std::stoi(yalm.metadata.at("use_mla").get<std::string>())) : false;
  kv_lora_rank = yalm.metadata.contains("kv_lora_rank") ? std::stoi(yalm.metadata.at("kv_lora_rank").get<std::string>()) : 0;
  q_lora_rank = yalm.metadata.contains("q_lora_rank") ? std::stoi(yalm.metadata.at("q_lora_rank").get<std::string>()) : 0;
  qk_nope_head_dim = yalm.metadata.contains("qk_nope_head_dim") ? std::stoi(yalm.metadata.at("qk_nope_head_dim").get<std::string>()) : 0;
  qk_rope_head_dim = yalm.metadata.contains("qk_rope_head_dim") ? std::stoi(yalm.metadata.at("qk_rope_head_dim").get<std::string>()) : 0;
  v_head_dim = yalm.metadata.contains("v_head_dim") ? std::stoi(yalm.metadata.at("v_head_dim").get<std::string>()) : 0;
  head_dim = qk_nope_head_dim + qk_rope_head_dim;

  max_seq_len = std::stoi(yalm.metadata.at("max_seq_len").get<std::string>());
  if (context) {
    max_seq_len = std::min(max_seq_len, context);
  }

  rope_theta = std::stof(yalm.metadata.at("rope_theta").get<std::string>());
  norm_eps = std::stof(yalm.metadata.value("norm_eps", "1e-5"));

  std::string act_str = yalm.metadata.value("act_type", "gelu");
  if (act_str == "gelu") {
    act = ActivationType::GELU;
  } else if (act_str == "silu") {
    act = ActivationType::SILU;
  } else {
    std::cerr << "unsupported act_type, defaulting to gelu" << std::endl;
    act = ActivationType::GELU;
  }

  std::string norm_type_str = yalm.metadata.value("norm_type", "rmsnorm");
  if (norm_type_str == "rmsnorm") {
    norm_type = LayerNormType::RMSNorm;
  } else {
    std::cerr << "unsupported norm_type, defaulting to rmsnorm" << std::endl;
    norm_type = LayerNormType::RMSNorm;
  }

  first_k_dense_replace = yalm.metadata.contains("first_k_dense_replace") ? 
    std::stoi(yalm.metadata.at("first_k_dense_replace").get<std::string>()) : 0;

  std::string quant = yalm.metadata.at("quant").get<std::string>();
  if (quant == "fp32") {
    weight_quant = Quant::F32;
  } else if (quant == "fp16") {
    weight_quant = Quant::F16;
  } else if (quant == "f8e5m2") {
    weight_quant = Quant::F8E5M2;
  } else if (quant == "q2_k") {
    weight_quant = Quant::Q2_K;
  } else if (quant == "q3_k") {
    weight_quant = Quant::Q3_K;
  } else {
    std::cerr << "FATAL: unsupported quant: " << quant << std::endl;
    assert(false);
  }

  // quantization
  if (yalm.metadata.contains("quantization_block_size_0")) {
    block_size[0] = std::stoi(yalm.metadata.at("quantization_block_size_0").get<std::string>());
    block_size[1] = std::stoi(yalm.metadata.at("quantization_block_size_1").get<std::string>());
  }

  // RoPE scaling (YaRN-style parameters: beta_fast/beta_slow correction range,
  // scale factor, and mscale attention-temperature terms)
  rs_beta_fast = std::stoi(yalm.metadata.at("rope_scaling_beta_fast").get<std::string>());
  rs_beta_slow = std::stoi(yalm.metadata.at("rope_scaling_beta_slow").get<std::string>());
  rs_factor = std::stof(yalm.metadata.at("rope_scaling_factor").get<std::string>());
  rs_mscale = std::stof(yalm.metadata.at("rope_scaling_mscale").get<std::string>());
  rs_mscale_all_dim = std::stof(yalm.metadata.at("rope_scaling_mscale_all_dim").get<std::string>());
  rs_original_max_position_embeddings = std::stoi(yalm.metadata.at("rope_scaling_original_max_position_embeddings").get<std::string>());
}

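// Checkpoint-loading helpers: `get_tensor` looks a tensor up by name, and
// `check_tensor` validates its dtype and shape (by convention here, trailing
// zeros in `shape` denote unused dimensions).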
std::optional<QTensor> check_tensor(const Tensor* tensor, Quant weight_quant, std::array<int, 4> shape, const int debug_line) {
  if (tensor == nullptr) {
    std::cerr << "FATAL: missing tensor at line " << debug_line << std::endl;
    assert(false);
    return std::nullopt;
  }
  return QTensor::from_codec_tensor(*tensor, weight_quant, shape, debug_line);
};

const Tensor* get_tensor(const YALMData& yalm, const std::string& key) {
  auto it = yalm.tensors.find(key);
  if (it == yalm.tensors.end()) {
    std::cerr << "FATAL: missing tensor: " << key << std::endl;
    assert(false);
    return nullptr;
  }
  const Tensor& tensor = it->second;
  return &tensor;
};

Block::Block(
  int layer_i,
  const std::shared_ptr<Config> config,
  const Tensor* rms_att_weight,
  const Tensor* rms_ffn_weight,
  const Tensor* w1,
  const Tensor* s1,
  const Tensor* w2,
  const Tensor* s2,
  const Tensor* w3,
  const Tensor* s3,
  const Tensor* shared_w1,
  const Tensor* shared_s1,
  const Tensor* shared_w2,
  const Tensor* shared_s2,
  const Tensor* shared_w3,
  const Tensor* shared_s3,
  const Tensor* moegate,
  const Tensor* moegate_bias
) : _layer_i(layer_i), _config(config) {
  switch (config->weight_quant) {
    case Quant::F32:
    case Quant::F16:
    case Quant::F8E5M2:
    case Quant::Q2_K:
    case Quant::Q3_K: {
      break;
    }
    default: {
      std::cerr << "FATAL: unsupported weight quantization: " << quant_to_string(config->weight_quant) << std::endl;
      assert(false);
      break;
    }
  }

  _rms_att_weight = check_tensor(
    rms_att_weight, Quant::F32, {config->dim, 0, 0, 0}, __LINE__
  );
  _rms_ffn_weight = check_tensor(
    rms_ffn_weight, Quant::F32, {config->dim, 0, 0, 0}, __LINE__
  );

  bool need_block_scales = _config->weight_quant == Quant::F8E5M2;
  int b0 = config->block_size[0];
  int b1 = config->block_size[1];

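  // For F8E5M2, each weight matrix carries one float32 scale per (b0 x b1)
  // block, so an (M, N) weight has a (cdiv(M, b0), cdiv(N, b1)) scale tensor;
  // the shape checks below follow that layout.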
  if (config->n_routed_experts > 0 && layer_i >= config->first_k_dense_replace) {
    _moegate = check_tensor(
      moegate, Quant::F32, {config->n_routed_experts, config->dim, 0, 0}, __LINE__
    );
    if (moegate_bias != nullptr) {
      _moegate_bias = check_tensor(
        moegate_bias, Quant::F32, {config->n_routed_experts, 0, 0, 0}, __LINE__
      );
    }
    _w1 = check_tensor(
      w1, config->weight_quant, {config->n_routed_experts, config->moe_intermediate_size, config->dim, 0}, __LINE__
    );
    _w2 = check_tensor(
      w2, config->weight_quant, {config->n_routed_experts, config->dim, config->moe_intermediate_size, 0}, __LINE__
    );
    _w3 = check_tensor(
      w3, config->weight_quant, {config->n_routed_experts, config->moe_intermediate_size, config->dim, 0}, __LINE__
    );
    if (need_block_scales) {
      _s1 = check_tensor(
        s1, Quant::F32,
        {config->n_routed_experts, cdiv(config->moe_intermediate_size, b0), cdiv(config->dim, b1), 0},
        __LINE__
      );
      _s2 = check_tensor(
        s2, Quant::F32,
        {config->n_routed_experts, cdiv(config->dim, b0), cdiv(config->moe_intermediate_size, b1), 0},
        __LINE__
      );
      _s3 = check_tensor(
        s3, Quant::F32,
        {config->n_routed_experts, cdiv(config->moe_intermediate_size, b0), cdiv(config->dim, b1), 0},
        __LINE__
      );
    }
    if (config->n_shared_experts > 0) {
      _shared_w1 = check_tensor(
        shared_w1, config->weight_quant, {config->n_shared_experts * config->moe_intermediate_size, config->dim, 0}, __LINE__
      );
      _shared_w2 = check_tensor(
        shared_w2, config->weight_quant, {config->dim, config->n_shared_experts * config->moe_intermediate_size, 0}, __LINE__
      );
      _shared_w3 = check_tensor(
        shared_w3, config->weight_quant, {config->n_shared_experts * config->moe_intermediate_size, config->dim, 0}, __LINE__
      );
      if (need_block_scales) {
        _shared_s1 = check_tensor(
          shared_s1, Quant::F32,
          {cdiv(config->n_shared_experts * config->moe_intermediate_size, b0), cdiv(config->dim, b1), 0},
          __LINE__
        );
        _shared_s2 = check_tensor(
          shared_s2, Quant::F32,
          {cdiv(config->dim, b0), cdiv(config->n_shared_experts * config->moe_intermediate_size, b1), 0},
          __LINE__
        );
        _shared_s3 = check_tensor(
          shared_s3, Quant::F32,
          {cdiv(config->n_shared_experts * config->moe_intermediate_size, b0), cdiv(config->dim, b1), 0},
          __LINE__
        );
      }
    }
  } else {
    _w1 = check_tensor(
      w1, config->weight_quant, {config->hidden_dim, config->dim, 0, 0}, __LINE__
    );
    _w2 = check_tensor(
      w2, config->weight_quant, {config->dim, config->hidden_dim, 0, 0}, __LINE__
    );
    _w3 = check_tensor(
      w3, config->weight_quant, {config->hidden_dim, config->dim, 0, 0}, __LINE__
    );
    if (need_block_scales) {
      _s1 = check_tensor(
        s1, Quant::F32,
        {cdiv(config->hidden_dim, b0), cdiv(config->dim, b1), 0, 0},
        __LINE__
      );
      _s2 = check_tensor(
        s2, Quant::F32,
        {cdiv(config->dim, b0), cdiv(config->hidden_dim, b1), 0, 0},
        __LINE__
      );
      _s3 = check_tensor(
        s3, Quant::F32,
        {cdiv(config->hidden_dim, b0), cdiv(config->dim, b1), 0, 0},
        __LINE__
      );
    }
  }
}

Block::~Block() {}

void Block::block(
  InferenceState& s,
  int pos,
  int kv_sink,
  int kv_pos,
  int kv_len
) const {
  if (_device == Device::CPU) {
    switch (_config->weight_quant) {
      case Quant::F32:
        _block_cpu<float>(s, pos, kv_sink, kv_pos, kv_len);
        break;
      case Quant::F16:
#if defined(__AVX2__) && defined(__F16C__)
        _block_cpu<f16_t>(s, pos, kv_sink, kv_pos, kv_len);
#else
        assert(false && "float16 not supported on this platform");
#endif
        break;
      case Quant::F8E5M2:
        _block_cpu<f8e5m2_t>(s, pos, kv_sink, kv_pos, kv_len);
        break;
      case Quant::Q2_K:
        _block_cpu<block_q2_K>(s, pos, kv_sink, kv_pos, kv_len);
        break;
      case Quant::Q3_K:
        _block_cpu<block_q3_K>(s, pos, kv_sink, kv_pos, kv_len);
        break;
      default:
        assert(false && "unsupported weight quantization for cpu");
    }
  }
}

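// Estimate of bytes streamed from memory to process one token at position
// `pos`: only the weights actually touched this pass (for MoE layers, just
// the activated experts). Attention subclasses add their KV cache
// contribution; run_completion divides by elapsed time to report bandwidth.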
double Block::active_bytes(size_t pos) const {
  double bytes_per_weight = bits_per_weight(_config->weight_quant, _config->block_size[0] * _config->block_size[1]) / 8.0;

  double bytes = 0;
  bytes += _rms_att_weight->size;
  bytes += _rms_ffn_weight->size;
  if (_config->n_routed_experts > 0 && _w1->ndim() == 3) {
    bytes += _moegate->size;
    if (_moegate_bias) bytes += _moegate_bias->size;
    // bytes_per_weight accounts for scales and other quantization schemes
    bytes += _config->n_active_routed * 3 * _config->dim * _config->moe_intermediate_size * bytes_per_weight; // w1, w2, w3
  } else {
    bytes += _w1->size + _w2->size + _w3->size; // w1, w2, w3
    if (_s1) {
      bytes += _s1->size;
      bytes += _s2->size;
      bytes += _s3->size;
    }
  }
  if (_config->n_shared_experts > 0) {
    if (_shared_w1) bytes += _shared_w1->size;
    if (_shared_s1) bytes += _shared_s1->size;
    if (_shared_w2) bytes += _shared_w2->size;
    if (_shared_s2) bytes += _shared_s2->size;
    if (_shared_w3) bytes += _shared_w3->size;
    if (_shared_s3) bytes += _shared_s3->size;
  }
  return bytes;
}

BlockMHA::BlockMHA(
  int layer_i,
  const std::shared_ptr<Config> config,
  const Tensor* rms_att_weight,
  const Tensor* rms_q_a_weight,
  const Tensor* rms_kv_a_weight,
  const Tensor* rms_ffn_weight,
  const Tensor* wq,
  const Tensor* sq,
  const Tensor* wq_a,
  const Tensor* sq_a,
  const Tensor* wkv_a,
  const Tensor* skv_a,
  const Tensor* wq_b,
  const Tensor* sq_b,
  const Tensor* wkv_b,
  const Tensor* skv_b,
  const Tensor* wo,
  const Tensor* so,
  const Tensor* w1,
  const Tensor* s1,
  const Tensor* w2,
  const Tensor* s2,
  const Tensor* w3,
  const Tensor* s3,
  const Tensor* shared_w1,
  const Tensor* shared_s1,
  const Tensor* shared_w2,
  const Tensor* shared_s2,
  const Tensor* shared_w3,
  const Tensor* shared_s3,
  const Tensor* moegate,
  const Tensor* moegate_bias
) : Block(layer_i, config, rms_att_weight, rms_ffn_weight, w1, s1, w2, s2, w3, s3, shared_w1, shared_s1, shared_w2, shared_s2, shared_w3, shared_s3, moegate, moegate_bias) {

  bool need_block_scales = _config->weight_quant == Quant::F8E5M2;
  int b0 = config->block_size[0];
  int b1 = config->block_size[1];

  if (config->q_lora_rank > 0) {
    _rms_q_a_weight = check_tensor(
      rms_q_a_weight, Quant::F32, {config->q_lora_rank, 0, 0, 0}, __LINE__
    );
    _wq_a = check_tensor(
      wq_a, config->weight_quant, {config->q_lora_rank, config->dim, 0, 0}, __LINE__
    );
    _wq_b = check_tensor(
      wq_b, config->weight_quant, {config->n_heads * config->head_dim, config->q_lora_rank, 0, 0}, __LINE__
    );
    if (need_block_scales) {
      _sq_a = check_tensor(
        sq_a, Quant::F32,
        {cdiv(config->q_lora_rank, b0), cdiv(config->dim, b1), 0, 0},
        __LINE__
      );
      _sq_b = check_tensor(
        sq_b, Quant::F32,
        {cdiv(config->n_heads * config->head_dim, b0), cdiv(config->q_lora_rank, b1), 0, 0},
        __LINE__
      );
    }
  } else {
    _wq = check_tensor(
      wq, config->weight_quant, {config->n_heads * config->head_dim, config->dim, 0, 0}, __LINE__
    );
    if (need_block_scales) {
      _sq = check_tensor(
        sq, Quant::F32,
        {cdiv(config->n_heads * config->head_dim, b0), cdiv(config->dim, b1), 0, 0},
        __LINE__
      );
    }
  }

  _rms_kv_a_weight = check_tensor(
    rms_kv_a_weight, Quant::F32, {config->kv_lora_rank, 0, 0, 0}, __LINE__ // Assuming kv_lora_rank is correct size here for MHA norm too
  );
  _wkv_a = check_tensor(
    wkv_a, config->weight_quant, {config->kv_lora_rank + config->qk_rope_head_dim, config->dim, 0, 0}, __LINE__
  );
  _wkv_b = check_tensor(
    wkv_b, config->weight_quant, {config->n_heads * (config->head_dim-config->qk_rope_head_dim+config->v_head_dim), config->kv_lora_rank, 0, 0}, __LINE__
  );
  _wo = check_tensor(
    wo, config->weight_quant, {config->dim, config->n_heads * config->v_head_dim, 0, 0}, __LINE__
  );

  if (need_block_scales) {
    _skv_a = check_tensor(
      skv_a, Quant::F32,
      {cdiv(config->kv_lora_rank + config->qk_rope_head_dim, b0), cdiv(config->dim, b1), 0, 0},
      __LINE__
    );
    _skv_b = check_tensor(
      skv_b, Quant::F32,
      {cdiv(config->n_heads * (config->head_dim-config->qk_rope_head_dim+config->v_head_dim), b0), cdiv(config->kv_lora_rank, b1), 0, 0},
      __LINE__
    );
    _so = check_tensor(
      so, Quant::F32,
      {cdiv(config->dim, b0), cdiv(config->n_heads * config->v_head_dim, b1), 0, 0},
      __LINE__
    );
  }

  _key_cache = new f16_t[config->max_seq_len * config->n_heads * config->head_dim]();
  _value_cache = new f16_t[config->max_seq_len * config->n_heads * config->v_head_dim]();
}

BlockMHA::~BlockMHA() {
  if (_device == Device::CPU) {
    delete[] _key_cache;
    delete[] _value_cache;
  }
}

double BlockMHA::active_bytes(size_t pos) const {
  double bytes = Block::active_bytes(pos);
  
  // Add active bytes for attention and KV cache
  if (_wq) bytes += _wq->size;
  if (_sq) bytes += _sq->size;
  if (_wq_a) bytes += _wq_a->size;
  if (_sq_a) bytes += _sq_a->size;
  if (_wkv_a) bytes += _wkv_a->size;
  if (_skv_a) bytes += _skv_a->size;
  if (_wo) bytes += _wo->size;
  if (_so) bytes += _so->size;
  if (_wq_b) bytes += _wq_b->size;
  if (_sq_b) bytes += _sq_b->size;
  if (_wkv_b) bytes += _wkv_b->size;
  if (_skv_b) bytes += _skv_b->size;
  
  size_t kv_len = std::min(static_cast<size_t>(_config->max_seq_len), pos + 1);
  size_t kv_entry_size = sizeof(f16_t);
  // Keys are head_dim wide but values are v_head_dim wide, matching the cache allocations above
  bytes += kv_len * _config->n_heads * (_config->head_dim + _config->v_head_dim) * kv_entry_size; // key_cache + value_cache
  return bytes;
}

void BlockMHA::attention_impl(
  InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
) const {
  switch (_config->weight_quant) {
    case Quant::F32:
      _attention_impl<float>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::F16:
      _attention_impl<f16_t>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::F8E5M2:
      _attention_impl<f8e5m2_t>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::Q2_K:
      _attention_impl<block_q2_K>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::Q3_K:
      _attention_impl<block_q3_K>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    default:
      assert(false && "unsupported weight quantization for mha");
  }
}


BlockMLA::BlockMLA(
  int layer_i,
  const std::shared_ptr<Config> config,
  const Tensor* rms_att_weight,
  const Tensor* rms_q_a_weight,
  const Tensor* rms_kv_a_weight,
  const Tensor* rms_ffn_weight,
  const Tensor* wq_a,
  const Tensor* sq_a,
  const Tensor* wkv_a,
  const Tensor* skv_a,
  const Tensor* wo,
  const Tensor* so,
  const Tensor* wc,
  const Tensor* sc,
  const Tensor* wq_rope_b,
  const Tensor* sq_rope_b,
  const Tensor* wv_b,
  const Tensor* sv_b,
  const Tensor* w1,
  const Tensor* s1,
  const Tensor* w2,
  const Tensor* s2,
  const Tensor* w3,
  const Tensor* s3,
  const Tensor* shared_w1,
  const Tensor* shared_s1,
  const Tensor* shared_w2,
  const Tensor* shared_s2,
  const Tensor* shared_w3,
  const Tensor* shared_s3,
  const Tensor* moegate,
  const Tensor* moegate_bias
) : Block(layer_i, config, rms_att_weight, rms_ffn_weight, w1, s1, w2, s2, w3, s3, shared_w1, shared_s1, shared_w2, shared_s2, shared_w3, shared_s3, moegate, moegate_bias) {

  bool need_block_scales = _config->weight_quant == Quant::F8E5M2;
  int b0 = config->block_size[0];
  int b1 = config->block_size[1];

  _rms_q_a_weight = check_tensor(
    rms_q_a_weight, Quant::F32, {config->q_lora_rank, 0, 0, 0}, __LINE__
  );
  _rms_kv_a_weight = check_tensor(
    rms_kv_a_weight, Quant::F32, {config->kv_lora_rank, 0, 0, 0}, __LINE__ // Only norm latent part
  );

  _wq_a = check_tensor(
    wq_a, config->weight_quant, {config->q_lora_rank, config->dim, 0, 0}, __LINE__
  );
  _wkv_a = check_tensor(
    wkv_a, config->weight_quant, {config->kv_lora_rank + config->qk_rope_head_dim, config->dim, 0, 0}, __LINE__
  );
  _wc = check_tensor(
    wc, config->weight_quant, {config->n_heads * config->kv_lora_rank, config->q_lora_rank, 0, 0}, __LINE__
  );
  _wq_rope_b = check_tensor(
    wq_rope_b, config->weight_quant, {config->n_heads * config->qk_rope_head_dim, config->q_lora_rank, 0, 0}, __LINE__
  );
  _wv_b = check_tensor(
      wv_b, config->weight_quant, {config->n_heads * config->v_head_dim, config->kv_lora_rank, 0, 0}, __LINE__
  );
  // Reshape _wv_b from 2D to 3D
  _wv_b = QTensor(_wv_b->quant, {config->n_heads, config->v_head_dim, config->kv_lora_rank}, _wv_b->data, _wv_b->size);
  _wo = check_tensor(
    wo, config->weight_quant, {config->dim, config->n_heads * config->v_head_dim, 0, 0}, __LINE__
  );

  if (need_block_scales) {
    _sq_a = check_tensor(
      sq_a, Quant::F32,
      {cdiv(config->q_lora_rank, b0), cdiv(config->dim, b1), 0, 0},
      __LINE__
    );
    _skv_a = check_tensor(
      skv_a, Quant::F32,
      {cdiv(config->kv_lora_rank + config->qk_rope_head_dim, b0), cdiv(config->dim, b1), 0, 0},
      __LINE__
    );
    _sc = check_tensor(
      sc, Quant::F32,
      {cdiv(config->n_heads * config->kv_lora_rank, b0), cdiv(config->q_lora_rank, b1), 0, 0},
      __LINE__
    );
    _sq_rope_b = check_tensor(
      sq_rope_b, Quant::F32,
      {cdiv(config->n_heads * config->qk_rope_head_dim, b0), cdiv(config->q_lora_rank, b1), 0, 0},
      __LINE__
    );
    _sv_b = check_tensor(
      sv_b, Quant::F32,
      {cdiv(config->n_heads * config->v_head_dim, b0), cdiv(config->kv_lora_rank, b1), 0, 0},
      __LINE__
    );
    _so = check_tensor(
      so, Quant::F32,
      {cdiv(config->dim, b0), cdiv(config->n_heads * config->v_head_dim, b1), 0, 0},
      __LINE__
    );
  }

  _kv_nope_cache = new f16_t[config->max_seq_len * config->kv_lora_rank]();
  _kv_rope_cache = new f16_t[config->max_seq_len * config->qk_rope_head_dim]();
}

BlockMLA::~BlockMLA() {
  if (_device == Device::CPU) {
    delete[] _kv_nope_cache;
    delete[] _kv_rope_cache;
  }
}

double BlockMLA::active_bytes(size_t pos) const {
  double bytes = Block::active_bytes(pos);

  bytes += _rms_q_a_weight->size;
  bytes += _rms_kv_a_weight->size;
  if (_wq_a) bytes += _wq_a->size;
  if (_sq_a) bytes += _sq_a->size;
  if (_wkv_a) bytes += _wkv_a->size;
  if (_skv_a) bytes += _skv_a->size;
  if (_wo) bytes += _wo->size;
  if (_so) bytes += _so->size;
  if (_wc) bytes += _wc->size;
  if (_sc) bytes += _sc->size;
  if (_wq_rope_b) bytes += _wq_rope_b->size;
  if (_sq_rope_b) bytes += _sq_rope_b->size;
  if (_wv_b) bytes += _wv_b->size;
  if (_sv_b) bytes += _sv_b->size;
  size_t kv_len = std::min(static_cast<size_t>(_config->max_seq_len), pos + 1);
  size_t kv_entry_size = sizeof(f16_t);
  bytes += kv_len * _config->kv_lora_rank * kv_entry_size; // kv_nope_cache
  bytes += kv_len * _config->qk_rope_head_dim * kv_entry_size; // kv_rope_cache
  return bytes;
}

void BlockMLA::attention_impl(
  InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
) const {
  switch (_config->weight_quant) {
    case Quant::F32:
      _attention_impl<float>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::F16:
      _attention_impl<f16_t>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::F8E5M2:
      _attention_impl<f8e5m2_t>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::Q2_K:
      _attention_impl<block_q2_K>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    case Quant::Q3_K:
      _attention_impl<block_q3_K>(s, pos, kv_sink, kv_pos, kv_len);
      break;
    default:
      assert(false && "unsupported weight quantization for mla");
  }
}

InferenceState::InferenceState(const std::shared_ptr<Config> config): 
  _config(config) {
  assert(config);
  _x = new float[config->dim]();
  _xb = new float[config->dim]();
  _xb2 = new float[std::max({
    config->dim, 
    config->n_heads * config->v_head_dim, 
    config->n_heads * config->kv_lora_rank
  })]();
  _hb = new float[std::max({
    config->dim, 
    config->hidden_dim
  })]();
  _hb2 = new float[config->hidden_dim]();
  if (config->q_lora_rank > 0) {
    _q_a = new float[config->q_lora_rank]();
  }
  _q = new float[config->n_heads * config->head_dim]();
  _kv_a = new float[config->kv_lora_rank + config->qk_rope_head_dim]();
  _kv_b = new float[config->n_heads * (config->head_dim-config->qk_rope_head_dim+config->v_head_dim)]();
  _ropebuf = new float[config->n_heads * config->qk_rope_head_dim]();
  _k = new float[config->n_heads * config->head_dim]();
  _v = new float[config->n_heads * config->v_head_dim]();
  _att = new float[config->n_heads * config->max_seq_len]();
  _logits = new float[config->vocab_size]();
  _logit_indices = new int[config->vocab_size]();
  for (int i = 0; i < config->vocab_size; i++){
    _logit_indices[i] = i;
  }
  if (config->use_mla) {
    _q_c = new float[config->n_heads * config->kv_lora_rank]();
    _q_rope = new float[config->n_heads * config->qk_rope_head_dim]();
  }
  if (config->n_routed_experts > 0) {
    _moe_weights = new float[config->n_routed_experts]();
    _active_experts = new int[config->n_active_routed]();
    _active_experts_weights = new float[config->n_active_routed]();
  }
  // Scratch buffer for activation quantization (block_q8_K), used by the
  // K-quant matmuls; sized for the largest activation vector ever quantized.
  // TODO: consider dynamically resizing based on inputs
  size_t aqb_nitems = std::max({
    config->dim,
    config->moe_intermediate_size,
    config->n_heads * config->v_head_dim,
    config->n_heads * config->kv_lora_rank,
    config->hidden_dim
  });
  // Round up so sizes that are not a multiple of QK_K still fit
  size_t aqb_nblocks = (aqb_nitems + QK_K - 1) / QK_K;
  _aqb = new uint8_t[aqb_nblocks * sizeof(block_q8_K)]();
}

InferenceState::~InferenceState() {
  if (_device == Device::CPU) {
    delete[] _x;
    delete[] _xb;
    delete[] _xb2;
    delete[] _hb;
    delete[] _hb2;
    if (_q_a != nullptr) {
      delete[] _q_a;
    }
    delete[] _q;
    delete[] _kv_a;
    delete[] _kv_b;
    delete[] _ropebuf;
    delete[] _k;
    delete[] _v;
    delete[] _att;
    delete[] _logits;
    delete[] _logit_indices;
    if (_moe_weights != nullptr) {
      delete[] _moe_weights;
      delete[] _active_experts;
      delete[] _active_experts_weights;
    }
    delete[] _aqb;
  }
}

Model::Model(YALMData& yalm, int context) {
  config = std::make_shared<Config>();
  config->from_yalm(yalm, context);
  std::cout << "loading model with quant: " << quant_to_string(config->weight_quant) << std::endl;

  bool need_weight_scales = config->weight_quant == Quant::F8E5M2;
  bool use_mla = config->use_mla;
  int b0 = config->block_size[0];
  int b1 = config->block_size[1];

  token_embedding_table = check_tensor(
    get_tensor(yalm, "model.embed.weight"),
    config->weight_quant,
    {config->vocab_size, config->dim, 0, 0},
    __LINE__
  );
  if (need_weight_scales) {
    token_embedding_scale = check_tensor(
      get_tensor(yalm, "model.embed.scale"),
      Quant::F32,
      {cdiv(config->vocab_size, b0), cdiv(config->dim, b1), 0, 0},
      __LINE__
    );
  }

  for (int i = 0; i < config->n_layers; ++i) {
    const Tensor* p_rms_att_weight = get_tensor(yalm, fmt::format("model.layers.{}.attn.norm.weight", i));
    const Tensor* p_rms_ffn_weight = get_tensor(yalm, fmt::format("model.layers.{}.mlp.norm.weight", i));
    const Tensor* p_w1 = get_tensor(yalm, fmt::format("model.layers.{}.mlp.w1.weight", i));
    const Tensor* p_s1 = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.mlp.w1.scale", i)) : nullptr;
    const Tensor* p_w2 = get_tensor(yalm, fmt::format("model.layers.{}.mlp.w2.weight", i));
    const Tensor* p_s2 = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.mlp.w2.scale", i)) : nullptr;
    const Tensor* p_w3 = get_tensor(yalm, fmt::format("model.layers.{}.mlp.w3.weight", i));
    const Tensor* p_s3 = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.mlp.w3.scale", i)) : nullptr;
    const Tensor* p_shared_w1 = (i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w1.weight", i)) : nullptr;
    const Tensor* p_shared_s1 = (need_weight_scales && i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w1.scale", i)) : nullptr;
    const Tensor* p_shared_w2 = (i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w2.weight", i)) : nullptr;
    const Tensor* p_shared_s2 = (need_weight_scales && i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w2.scale", i)) : nullptr;
    const Tensor* p_shared_w3 = (i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w3.weight", i)) : nullptr;
    const Tensor* p_shared_s3 = (need_weight_scales && i >= config->first_k_dense_replace && config->n_shared_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.shared_mlp.w3.scale", i)) : nullptr;
    const Tensor* p_moegate = (i >= config->first_k_dense_replace && config->n_routed_experts > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.moegate.weight", i)) : nullptr;
    const Tensor* p_moegate_bias = (i >= config->first_k_dense_replace && config->n_routed_experts > 0 && config->has_moegate_bias) ? get_tensor(yalm, fmt::format("model.layers.{}.moegate.bias", i)) : nullptr;

    const Tensor* p_rms_q_a_weight = config->q_lora_rank > 0 ? get_tensor(yalm, fmt::format("model.layers.{}.attn.q_a_norm.weight", i)) : nullptr;
    const Tensor* p_rms_kv_a_weight = get_tensor(yalm, fmt::format("model.layers.{}.attn.kv_a_norm.weight", i));
    const Tensor* p_wq_a = config->q_lora_rank > 0 ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_a.weight", i)) : nullptr;
    const Tensor* p_sq_a = (need_weight_scales && config->q_lora_rank > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_a.scale", i)) : nullptr;
    const Tensor* p_wkv_a = get_tensor(yalm, fmt::format("model.layers.{}.attn.wkv_a.weight", i));
    const Tensor* p_skv_a = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wkv_a.scale", i)) : nullptr;
    const Tensor* p_wo = get_tensor(yalm, fmt::format("model.layers.{}.attn.wo.weight", i));
    const Tensor* p_so = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wo.scale", i)) : nullptr;


    if (use_mla) {
      const Tensor* p_wc = get_tensor(yalm, fmt::format("model.layers.{}.attn.wc.weight", i));
      const Tensor* p_sc = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wc.scale", i)) : nullptr;
      const Tensor* p_wq_rope_b = get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_rope_b.weight", i));
      const Tensor* p_sq_rope_b = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_rope_b.scale", i)) : nullptr;
      const Tensor* p_wv_b = get_tensor(yalm, fmt::format("model.layers.{}.attn.wv_b.weight", i));
      const Tensor* p_sv_b = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wv_b.scale", i)) : nullptr;

      blocks.emplace_back(std::make_shared<BlockMLA>(
        i, config,
        p_rms_att_weight, p_rms_q_a_weight, p_rms_kv_a_weight, p_rms_ffn_weight,
        p_wq_a, p_sq_a, p_wkv_a, p_skv_a, p_wo, p_so,
        p_wc, p_sc, p_wq_rope_b, p_sq_rope_b, p_wv_b, p_sv_b,
        p_w1, p_s1, p_w2, p_s2, p_w3, p_s3,
        p_shared_w1, p_shared_s1, p_shared_w2, p_shared_s2, p_shared_w3, p_shared_s3,
        p_moegate, p_moegate_bias
      ));
    } else {
      const Tensor* p_wq = config->q_lora_rank == 0 ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq.weight", i)) : nullptr;
      const Tensor* p_sq = (need_weight_scales && config->q_lora_rank == 0) ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq.scale", i)) : nullptr;
      const Tensor* p_wq_b = config->q_lora_rank > 0 ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_b.weight", i)) : nullptr;
      const Tensor* p_sq_b = (need_weight_scales && config->q_lora_rank > 0) ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wq_b.scale", i)) : nullptr;
      const Tensor* p_wkv_b = get_tensor(yalm, fmt::format("model.layers.{}.attn.wkv_b.weight", i));
      const Tensor* p_skv_b = need_weight_scales ? get_tensor(yalm, fmt::format("model.layers.{}.attn.wkv_b.scale", i)) : nullptr;

      blocks.emplace_back(std::make_shared<BlockMHA>(
        i, config,
        p_rms_att_weight, p_rms_q_a_weight, p_rms_kv_a_weight, p_rms_ffn_weight,
        p_wq, p_sq, p_wq_a, p_sq_a, p_wkv_a, p_skv_a,
        p_wq_b, p_sq_b, p_wkv_b, p_skv_b, p_wo, p_so,
        p_w1, p_s1, p_w2, p_s2, p_w3, p_s3,
        p_shared_w1, p_shared_s1, p_shared_w2, p_shared_s2, p_shared_w3, p_shared_s3,
        p_moegate, p_moegate_bias
      ));
    }
  }

  rms_final_weight = check_tensor(
    get_tensor(yalm, "model.norm.weight"),
    Quant::F32,
    {config->dim, 0, 0, 0},
    __LINE__
  );
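  // Checkpoints without a dedicated output head tie the classifier weights to
  // the token embedding table.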
  bool tie_word_embeddings = yalm.tensors.count("model.output.weight") == 0;
  if (tie_word_embeddings) {
    wcls = token_embedding_table;
    scls = token_embedding_scale;
  } else {
    wcls = check_tensor(
      get_tensor(yalm, "model.output.weight"),
      config->weight_quant,
      {config->vocab_size, config->dim, 0, 0},
      __LINE__
    );
    if (need_weight_scales) {
      scls = check_tensor(
        get_tensor(yalm, "model.output.scale"),
        Quant::F32,
        {cdiv(config->vocab_size, b0), cdiv(config->dim, b1), 0, 0},
        __LINE__
      );
    }
  }
}

void Model::forward(InferenceState& s, int token, int pos, InferenceMode mode) {
  if (s.device() != _device) {
    std::cerr << "FATAL: inference state device mismatch" << std::endl;
    assert(false);
    return;
  }
  if (_device == Device::CPU) {
    _forward_cpu(s, token, pos, mode);
  }
}

double Model::active_bytes(size_t pos) const {
  double bytes_per_weight = bits_per_weight(config->weight_quant, config->block_size[0] * config->block_size[1]) / 8.0;

  double bytes = 0;
  bytes += config->dim * bytes_per_weight; // 1 row of token_embedding_table
  // blocks
  for (auto& block : blocks) {
    bytes += block->active_bytes(pos);
  }
  bytes += rms_final_weight->size;
  bytes += wcls->size;
  if (scls) {
    bytes += scls->size;
  }

  return bytes;
}

#if DEBUG_MODEL
DebugTensor::DebugTensor(const std::vector<float>& data) {
  data_f32 = data;
  data_type = DataType::F32;
}
DebugTensor::DebugTensor(const std::vector<f16_t>& data) {
  data_f16 = data;
  data_type = DataType::F16;
}

float DebugTensor::max_err(const DebugTensor& other) const {
  if (data_type != other.data_type) {
    return -1;
  }
  if (data_type == DataType::F32) {
    float max_err = 0;
    for (size_t i = 0; i < data_f32.size(); i++) {
      max_err = std::max(max_err, std::abs(data_f32[i] - other.data_f32[i]));
    }
    return max_err;
  } else {
#if defined(__F16C__)
    float max_err = 0;
    for (size_t i = 0; i < data_f16.size(); i++) {
      max_err = std::max(max_err, std::abs(_cvtsh_ss(data_f16[i]) - _cvtsh_ss(other.data_f16[i])));
    }
    return max_err;
#else
  assert(false && "float16 not supported on this platform");
#endif
  }
}
#endif

================================================
FILE: src/model.h
================================================
#pragma once

#include "codec.h"

#include <memory>
#include <vector>
#include <map>
#include <optional>

#include "quant.h"

#define DEBUG_MODEL 0

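// Number of initial tokens kept pinned in the KV cache as "attention sinks"
// when the cache rolls as a sliding window over a longer context
// (StreamingLLM-style); threaded through each block as `kv_sink`.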
constexpr int KV_SINKS = 2;

enum class ActivationType {
  GELU,
  SILU,
};

enum class LayerNormType {
  RMSNorm,
};

enum class TopKMethod {
  GREEDY,
  GROUP_LIMITED_GREEDY,
  NOAUX_TC,
};

enum class ScoringFunc {
  SOFTMAX,
  SIGMOID,
};

enum class Device {
  CPU,
};

enum class InferenceMode {
  HYDRATE_KV_CACHE, // only hydrate the KV cache and don't compute output logits
  OUTPUT_LOGITS // set InferenceState logits to logits for the next token
};

int cdiv(int a, int b);

struct Config {
  int dim;                  // transformer input & output dimension
  int hidden_dim;           // dimension of hidden layer in feedforward network (dense blocks only)
  int n_layers;             // number of layers
  int n_heads;              // number of attention heads
  int vocab_size;           // vocabulary size
  int max_seq_len;          // max sequence length
  float rope_theta;         // RoPE theta
  float norm_eps;           // epsilon for layer normalization
  ActivationType act;       // activation function
  LayerNormType norm_type;  // norm type
  int first_k_dense_replace; // how many blocks to keep the dense FFN (when sparse MoE is default)
  // mixture of experts
  int n_shared_experts;
  int n_routed_experts;
  int n_active_routed;
  int moe_intermediate_size;
  float routed_scaling_factor;
  int n_group;
  bool norm_topk_prob;
  ScoringFunc scoring_func;
  int topk_group;
  TopKMethod topk_method;
  bool has_moegate_bias;
  // multi-latent attention
  bool use_mla; // if false, use naive implementation of multi-latent attention
  int kv_lora_rank;
  int q_lora_rank;
  int qk_nope_head_dim;
  int qk_rope_head_dim;
  int v_head_dim;
  int head_dim;             // dimension of each attention head, equal to qk_nope_head_dim + qk_rope_head_dim
  // Data type of the weights according to config, used
  // to sanity-check tensor dtypes at initialization time.
  Quant weight_quant;
  // Block size for weight quantization if present
  // If weights are quantized but block size is (0, 0), then we are using
  // per-tensor quantization.
  std::array<int, 2> block_size = {0, 0};
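  // For example, with block_size = {128, 128} a (vocab_size, dim) quantized
  // weight carries a (cdiv(vocab_size, 128), cdiv(dim, 128)) float32 scale
  // tensor; see the scale-tensor shape checks in model.cpp.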
  // RoPE scaling
  int rs_beta_fast;
  int rs_beta_slow;
  float rs_factor;
  float rs_mscale;
  float rs_mscale_all_dim;
  int rs_original_max_position_embeddings;

  // If nonzero `context` is supplied, max sequence length is limited to `context`.
  void from_yalm(YALMData& yalm, int context = 0);
};

// Buffer for all state used during a forward pass.
// Members are reused across subsequent blocks and passes.
// This lets us avoid allocations during inference.
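//
// Typical lifecycle (a sketch; `model`, `token`, and `pos` are assumed to
// come from the surrounding application code):
// ```
// InferenceState state(model.config); // all buffers are allocated here, once
// model.forward(state, token, pos);   // no allocations on this path
// float* logits = state.logits();     // (vocab_size,) - ready for sampling
// ```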
struct InferenceState {
  InferenceState(const std::shared_ptr<Config> config);
  ~InferenceState();

  // current activations
  float* x() const { return _x; }
  float* xb() const { return _xb; }
  float* xb(int head) const { return _xb + _config->head_dim * head; }
  // TODO: do we need xb2?
  float* xb2() const { return _xb2; }
  float* xb2(int head, int head_size) const { return _xb2 + head_size * head; }
  float* hb() const { return _hb; }
  float* hb2() const { return _hb2; }
  float* q_a() const { return _q_a; }
  float* q() const { return _q; }
  float* q(int head) const { return _q + _config->head_dim * head; }
  float* kv_a() const { return _kv_a; }
  float* kv_b() const { return _kv_b; }
  float* kv_b(int head) const { return _kv_b + (_config->head_dim - _config->qk_rope_head_dim + _config->v_head_dim) * head; }
  float* ropebuf() const { return _ropebuf; }
  float* k() const { return _k; }
  float* k(int head) const { return _k + _config->head_dim * head; }
  float* v() const { return _v; }
  float* v(int head) const { return _v + _config->v_head_dim * head; }
  float* att() const { return _att; }
  float* att(int head) const { return _att + _config->max_seq_len * head; }
  // MLA only
  float* q_c() const { return _q_c; }
  float* q_c(int head) const { return _q_c + _config->kv_lora_rank * head; }
  float* q_rope() const { return _q_rope; }
  float* q_rope(int head) const { return _q_rope + _config->qk_rope_head_dim * head; }
  // mixture of experts
  float* moe_weights() const { return _moe_weights; }
  float* active_experts_weights() const { return _active_experts_weights; }
  int* active_experts() const { return _active_experts; }
  // LM head
  float* logits() const { return _logits; }
  int* logit_indices() const { return _logit_indices; }
  // activation quantization buffer
  void* aqb() const { return _aqb; }

  Device device() const { return _device; }
  InferenceMode mode() const { return _mode; }
  void set_mode(InferenceMode mode) { _mode = mode; }

private:
  std::shared_ptr<Config> _config;
  Device _device = Device::CPU;
  InferenceMode _mode = InferenceMode::OUTPUT_LOGITS;

  // current activations
  float* _x = nullptr;         // (dim,) - latest activation
  float* _xb = nullptr;        // (dim,) - activation inside a residual branch
  float* _xb2 = nullptr;       // (max{dim, n_heads * v_head_dim, n_heads * kv_lora_rank},) - activation inside a residual branch (second slot)
  float* _hb = nullptr;        // (max{dim, hidden_dim},) - buffer for hidden dimension in feedforward network
  float* _hb2 = nullptr;       // (hidden_dim,) - buffer for hidden dimension in feedforward network (second slot)
  float* _q_a = nullptr;       // (q_lora_rank,) - compressed (latent) query vector for latest timestamp
  float* _q = nullptr;         // (n_heads * head_dim,) - query vectors for latest timestamp
  float* _kv_a = nullptr;      // (kv_lora_rank + qk_rope_head_dim,) - compressed (latent) key-value vector for latest timestamp
  float* _kv_b = nullptr;      // (n_heads * (head_dim-qk_rope_head_dim+v_head_dim),) - uncompressed key-value vector for latest timestamp
  float* _ropebuf = nullptr;   // (n_heads * qk_rope_head_dim,) - buffer for rope
  float* _k = nullptr;         // (n_heads * head_dim,) - key vectors for latest timestamp
  float* _v = nullptr;         // (n_heads * v_head_dim,) - value vectors for latest timestamp
  float* _att = nullptr;       // (n_heads, seq_len) - buffer for attention scores
  // MLA only
  float* _q_c = nullptr;       // (n_heads * kv_lora_rank,) - transformed and compressed query vector for latest timestamp
  float* _q_rope = nullptr;    // (n_heads * qk_rope_head_dim,) - RoPE-transformed query vector for latest timestamp
  // mixture of experts
  float* _moe_weights = nullptr; // (n_routed_experts,) - buffer for expert weights, decided by router
  float* _active_experts_weights = nullptr; // (n_active_experts,) - buffer for weights of top K experts (active experts)
  int* _active_experts = nullptr; // (n_active_experts,) - buffer for indices of top K experts (active experts)
  
  // LM head
  float* _logits = nullptr;    // (vocab_size,) - final output logits
  int* _logit_indices = nullptr; // (vocab_size,) - logit indices (for use by top-p sampler)

  // activation quantization buffer
  uint8_t* _aqb = nullptr; // buffer for quantized activations
};

/* Transformer Block Base */
struct Block {
  Block(
    int layer_i,
    const std::shared_ptr<Config> config,
    const Tensor* rms_att_weight,
    const Tensor* rms_ffn_weight,
    const Tensor* w1,
    const Tensor* s1,
    const Tensor* w2,
    const Tensor* s2,
    const Tensor* w3,
    const Tensor* s3,
    const Tensor* shared_w1,
    const Tensor* shared_s1,
    const Tensor* shared_w2,
    const Tensor* shared_s2,
    const Tensor* shared_w3,
    const Tensor* shared_s3,
    const Tensor* moegate,
    const Tensor* moegate_bias
  );
  virtual ~Block();

  float* rms_att_weight() const { return _rms_att_weight ? static_cast<float*>(_rms_att_weight->data) : nullptr; }
  float* rms_ffn_weight() const { return _rms_ffn_weight ? static_cast<float*>(_rms_ffn_weight->data) : nullptr; }
  std::optional<QTensor> w1() const { return _w1; }
  std::optional<QTensor> w2() const { return _w2; }
  std::optional<QTensor> w3() const { return _w3; }
  std::optional<QTensor> moegate() const { return _moegate; }
  std::optional<QTensor> moegate_bias() const { return _moegate_bias; }
  std::optional<QTensor> shared_w1() const { return _shared_w1; }
  std::optional<QTensor> shared_w2() const { return _shared_w2; }
  std::optional<QTensor> shared_w3() const { return _shared_w3; }

  // Compute forward pass for this block and update the inference state accordingly.
  // PRECONDITIONS:
  // - `s.x()` contains the input to the block. Output will also go here.
  // - Block KV cache is hydrated.
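  //
  // For reference, a sketch of how the caller can derive these arguments when
  // the KV cache is used as a ring buffer with KV_SINKS pinned sink tokens
  // (illustrative only; the authoritative derivation lives in the forward pass):
  // ```
  // int kv_sink = pos >= max_seq_len ? KV_SINKS : 0;
  // int kv_pos = kv_sink + (pos - kv_sink) % (max_seq_len - kv_sink);
  // int kv_len = pos >= max_seq_len ? max_seq_len : pos + 1;
  // ```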
  void block(
    InferenceState& s,  // inference state
    int pos,            // index of the current token in the sequence
    int kv_sink,        // number of sink tokens currently in the KV cache
    int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
    int kv_len          // number of tokens in the kv cache that we will attend over
  ) const;

  virtual double active_bytes(size_t pos) const;

protected:
  virtual void attention_impl(
    InferenceState& s,  // inference state
    int pos,            // index of the current token in the sequence
    int kv_sink,        // number of sink tokens currently in the KV cache
    int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
    int kv_len          // number of tokens in the kv cache that we will attend over
  ) const = 0;

  template <typename T>
  void _block_cpu(
    InferenceState& s,  // inference state
    int pos,            // index of the current token in the sequence
    int kv_sink,        // number of sink tokens currently in the KV cache
    int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
    int kv_len          // number of tokens in the kv cache that we will attend over
  ) const;

  int _layer_i = 0;

  std::shared_ptr<Config> _config;
  Device _device = Device::CPU;

  // weights for norms
  std::optional<QTensor> _rms_att_weight = std::nullopt; // (dim) rmsnorm weights for attention input
  std::optional<QTensor> _rms_ffn_weight = std::nullopt; // (dim) rmsnorm weights for ffn input

  // weights for ffn
  std::optional<QTensor> _w1 = std::nullopt; // (n_routed_experts?, moe_intermediate_size, dim) or (hidden_dim, dim)
  std::optional<QTensor> _s1 = std::nullopt;
  std::optional<QTensor> _w2 = std::nullopt; // (n_routed_experts?, dim, moe_intermediate_size) or (dim, hidden_dim)
  std::optional<QTensor> _s2 = std::nullopt;
  std::optional<QTensor> _w3 = std::nullopt; // (n_routed_experts?, moe_intermediate_size, dim) or (hidden_dim, dim)
  std::optional<QTensor> _s3 = std::nullopt;
  std::optional<QTensor> _shared_w1 = std::nullopt; // (n_shared_experts?, moe_intermediate_size, dim)
  std::optional<QTensor> _shared_s1 = std::nullopt;
  std::optional<QTensor> _shared_w2 = std::nullopt; // (n_shared_experts?, dim, moe_intermediate_size)
  std::optional<QTensor> _shared_s2 = std::nullopt;
  std::optional<QTensor> _shared_w3 = std::nullopt; // (n_shared_experts?, moe_intermediate_size, dim)
  std::optional<QTensor> _shared_s3 = std::nullopt;
  // weights for mixture of experts router if present
  std::optional<QTensor> _moegate = std::nullopt; // (n_routed_experts?, dim)
  std::optional<QTensor> _moegate_bias = std::nullopt;
};

/* Transformer Block - Multi-Head Attention */
struct BlockMHA : public Block {
  BlockMHA(
    int layer_i,
    const std::shared_ptr<Config> config,
    const Tensor* rms_att_weight,
    const Tensor* rms_q_a_weight,
    const Tensor* rms_kv_a_weight,
    const Tensor* rms_ffn_weight,
    const Tensor* wq,
    const Tensor* sq,
    const Tensor* wq_a,
    const Tensor* sq_a,
    const Tensor* wkv_a,
    const Tensor* skv_a,
    const Tensor* wq_b,
    const Tensor* sq_b,
    const Tensor* wkv_b,
    const Tensor* skv_b,
    const Tensor* wo,
    const Tensor* so,
    const Tensor* w1,
    const Tensor* s1,
    const Tensor* w2,
    const Tensor* s2,
    const Tensor* w3,
    const Tensor* s3,
    const Tensor* shared_w1,
    const Tensor* shared_s1,
    const Tensor* shared_w2,
    const Tensor* shared_s2,
    const Tensor* shared_w3,
    const Tensor* shared_s3,
    const Tensor* moegate,
    const Tensor* moegate_bias
  );
  ~BlockMHA() override;

  float* rms_q_a_weight() const { return _rms_q_a_weight ? static_cast<float*>(_rms_q_a_weight->data) : nullptr; }
  float* rms_kv_a_weight() const { return _rms_kv_a_weight ? static_cast<float*>(_rms_kv_a_weight->data) : nullptr; }
  std::optional<QTensor> wq() const { return _wq; }
  std::optional<QTensor> wq_a() const { return _wq_a; }
  std::optional<QTensor> wq_b() const { return _wq_b; }
  std::optional<QTensor> wkv_a() const { return _wkv_a; }
  std::optional<QTensor> wkv_b() const { return _wkv_b; }
  std::optional<QTensor> wo() const { return _wo; }
  f16_t* key_cache() const { return _key_cache; }
  f16_t* key_cache(int pos) const { return _key_cache + pos * _config->head_dim * _config->n_heads; }
  f16_t* value_cache() const { return _value_cache; }
  f16_t* value_cache(int pos) const { return _value_cache + pos * _config->v_head_dim * _config->n_heads; }

  double active_bytes(size_t pos) const override;

protected:
  void attention_impl(
    InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
  ) const override;

private:
  template <typename T>
  void _attention_impl(
    InferenceState& s,  // inference state
    int pos,            // index of the current token in the sequence
    int kv_sink,        // number of sink tokens currently in the KV cache
    int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
    int kv_len          // number of tokens in the kv cache that we will attend over
  ) const;

  std::optional<QTensor> _rms_q_a_weight = std::nullopt; // (q_lora_rank) rmsnorm weights
  std::optional<QTensor> _rms_kv_a_weight = std::nullopt; // (kv_lora_rank + qk_rope_head_dim)

  // weights for self-attention matmuls
  std::optional<QTensor> _wq = std::nullopt; // (n_heads * head_dim, dim)
  std::optional<QTensor> _sq = std::nullopt;
  std::optional<QTensor> _wq_a = std::nullopt; // (q_lora_rank, dim)
  std::optional<QTensor> _sq_a = std::nullopt;
  std::optional<QTensor> _wkv_a = std::nullopt; // (kv_lora_rank + qk_rope_head_dim, dim)
  std::optional<QTensor> _skv_a = std::nullopt;
  std::optional<QTensor> _wo = std::nullopt; // (dim, n_heads * v_head_dim)
  std::optional<QTensor> _so = std::nullopt;
  std::optional<QTensor> _wq_b = std::nullopt; // (n_heads * head_dim, q_lora_rank)
  std::optional<QTensor> _sq_b = std::nullopt;
  std::optional<QTensor> _wkv_b = std::nullopt; // (n_heads * (head_dim-qk_rope_head_dim+v_head_dim), kv_lora_rank)
  std::optional<QTensor> _skv_b = std::nullopt;

  // MHA kv cache
  f16_t* _key_cache = nullptr;   // (seq_len, n_heads * head_dim)
  f16_t* _value_cache = nullptr; // (seq_len, n_heads * v_head_dim)
};

/* Transformer Block - Multi-Latent Attention */
struct BlockMLA : public Block {
  BlockMLA(
    int layer_i,
    const std::shared_ptr<Config> config,
    const Tensor* rms_att_weight,
    const Tensor* rms_q_a_weight,
    const Tensor* rms_kv_a_weight,
    const Tensor* rms_ffn_weight,
    const Tensor* wq_a,
    const Tensor* sq_a,
    const Tensor* wkv_a,
    const Tensor* skv_a,
    const Tensor* wo,
    const Tensor* so,
    const Tensor* wc,
    const Tensor* sc,
    const Tensor* wq_rope_b,
    const Tensor* sq_rope_b,
    const Tensor* wv_b,
    const Tensor* sv_b,
    const Tensor* w1,
    const Tensor* s1,
    const Tensor* w2,
    const Tensor* s2,
    const Tensor* w3,
    const Tensor* s3,
    const Tensor* shared_w1,
    const Tensor* shared_s1,
    const Tensor* shared_w2,
    const Tensor* shared_s2,
    const Tensor* shared_w3,
    const Tensor* shared_s3,
    const Tensor* moegate,
    const Tensor* moegate_bias
  );
  ~BlockMLA() override;

  float* rms_q_a_weight() const { return _rms_q_a_weight ? static_cast<float*>(_rms_q_a_weight->data) : nullptr; }
  float* rms_kv_a_weight() const { return _rms_kv_a_weight ? static_cast<float*>(_rms_kv_a_weight->data) : nullptr; }
  std::optional<QTensor> wq_a() const { return _wq_a; }
  std::optional<QTensor> wkv_a() const { return _wkv_a; }
  std::optional<QTensor> wo() const { return _wo; }
  std::optional<QTensor> wc() const { return _wc; }
  std::optional<QTensor> wq_rope_b() const { return _wq_rope_b; }
  std::optional<QTensor> wv_b() const { return _wv_b; }
  f16_t* kv_nope_cache() const { return _kv_nope_cache; }
  f16_t* kv_nope_cache(int pos) const { return _kv_nope_cache + pos * _config->kv_lora_rank; }
  f16_t* kv_rope_cache() const { return _kv_rope_cache; }
  f16_t* kv_rope_cache(int pos) const { return _kv_rope_cache + pos * _config->qk_rope_head_dim; }

  double active_bytes(size_t pos) const override;

protected:
  void attention_impl(
    InferenceState& s, int pos, int kv_sink, int kv_pos, int kv_len
  ) const override;
private:
  template <typename T>
  void _attention_impl(
    InferenceState& s,  // inference state
    int pos,            // index of the current token in the sequence
    int kv_sink,        // number of sink tokens currently in the KV cache
    int kv_pos,         // index of the current token in the kv cache, must be in [0..kv_len) since kv cache is a ring buffer
    int kv_len          // number of tokens in the kv cache that we will attend over
  ) const;

  // weights for norms
  std::optional<QTensor> _rms_q_a_weight = std::nullopt; // (q_lora_rank) rmsnorm weights
  std::optional<QTensor> _rms_kv_a_weight = std::nullopt; // (kv_lora_rank + qk_rope_head_dim)

  // weights for self-attention matmuls
  std::optional<QTensor> _wq_a = std::nullopt; // (q_lora_rank, dim)
  std::optional<QTensor> _sq_a = std::nullopt;
  std::optional<QTensor> _wkv_a = std::nullopt; // (kv_lora_rank + qk_rope_head_dim, dim)
  std::optional<QTensor> _skv_a = std::nullopt;
  std::optional<QTensor> _wo = std::nullopt; // (dim, n_heads * v_head_dim)
  std::optional<QTensor> _so = std::nullopt;
  std::optional<QTensor> _wc = std::nullopt; // (n_heads * kv_lora_rank, q_lora_rank)
  std::optional<QTensor> _sc = std::nullopt;
  std::optional<QTensor> _wq_rope_b = std::nullopt; // (n_heads * qk_rope_head_dim, q_lora_rank)
  std::optional<QTensor> _sq_rope_b = std::nullopt;
  std::optional<QTensor> _wv_b = std::nullopt; // (n_heads, v_head_dim, kv_lora_rank)
  std::optional<QTensor> _sv_b = std::nullopt;

  // MLA kv cache
  f16_t* _kv_nope_cache = nullptr; // (seq_len, kv_lora_rank)
  f16_t* _kv_rope_cache = nullptr; // (seq_len, qk_rope_head_dim)
};

struct Model {
  std::shared_ptr<Config> config;

  std::vector<std::shared_ptr<Block>> blocks;
  
  // token embedding table
  std::optional<QTensor> token_embedding_table = std::nullopt; // (vocab_size, dim)
  std::optional<QTensor> token_embedding_scale = std::nullopt; // (ceil(vocab_size / block_size[0]), ceil(dim / block_size[1]))
  // final norm
  std::optional<QTensor> rms_final_weight = std::nullopt; // (dim,)
  // classifier weights for the logits, on the last layer
  std::optional<QTensor> wcls = std::nullopt; // (vocab_size, dim)
  std::optional<QTensor> scls = std::nullopt;

  Model(YALMData& yalm, int context = 0);
  
  void forward(InferenceState& s, int token, int pos, InferenceMode mode = InferenceMode::OUTPUT_LOGITS);

  double active_bytes(size_t pos) const;
private:
  void _forward_cpu(InferenceState& s, int token, int pos, InferenceMode mode);
  void _copy_embedding(InferenceState& s, int token);

  Device _device = Device::CPU;
};

#if DEBUG_MODEL
struct DebugTensor {
  enum struct DataType {
    F32,
    F16,
  };

  DebugTensor() = default;
  DebugTensor(const std::vector<float>& data);
  DebugTensor(const std::vector<f16_t>& data);
  DebugTensor& operator=(const DebugTensor& other) = default;
  float max_err(const DebugTensor& other) const;

  std::vector<float> data_f32;
  std::vector<f16_t> data_f16;
  DataType data_type;
};
std::map<std::string, DebugTensor>& debug_map_cpu();
void dump_debug_map(const std::string& filename);
void dump_debug_map_as_safetensors(const std::string& filename);
#endif

////////////////////////////////////////
// Exposed for tests
////////////////////////////////////////
void attn(
  float* xout,    // (dim,) - output vector
  float* atth,    // (kv_len,) - scratch space to hold attention scores of the sequence
  const float* qh,      // (head_dim,) - query vector for this head
  const f16_t* kh,      // (kv_len, n_heads, head_dim) - buffer containing key vectors of the sequence for all KV heads
  const f16_t* vh,      // (kv_len, n_heads, v_head_dim) - buffer containing value vectors of the sequence for all KV heads
  int head_dim,   // size of the "key-space"
  int v_head_dim, // size of the "value-space"
  int n_heads, // number of attention heads
  int kv_len      // number of tokens of the sequence we will attend over
);

void mha_cpu(
  float* xout,  // (n_heads, head_dim)
  float* att,   // (n_heads, max_seq_len)
  f16_t* kb,    // (max_seq_len, n_heads, head_dim)
  f16_t* vb,    // (max_seq_len, n_heads, v_head_dim)
  float* q,     // (n_heads, head_dim)
  int head_dim, int v_head_dim, int kv_len, int max_seq_len, int n_heads
);

void matmul_unscaled(float* xout, float* x, const QTensor& w);

void ffn_cpu(
  float* xout, float* x, 
  float* w1, float* w2, float* w3, 
  int hidden_dim, int dim,
  ActivationType act
);
////////////////////////////////////////

================================================
FILE: src/profile.cpp
================================================
#include "profile.h"

#include <vector>

static bool _profile_enabled = true;
static std::vector<std::string> _profile_scopes;
static std::map<std::string, double> _profile_times;

void set_profile_enabled(bool enabled) {
  _profile_enabled = enabled;
}

bool get_profile_enabled() {
  return _profile_enabled;
}

const std::map<std::string, double>& profile_times() {
  return _profile_times;
}

#if PROFILE_ENABLED
ProfileScope::ProfileScope(std::string name) {
  _profile_scopes.push_back(name);
  _start = omp_get_wtime();
}

ProfileScope::ProfileScope(const char* name) : 
  ProfileScope(std::string(name)) {}

ProfileScope::~ProfileScope() {
  double end = omp_get_wtime();
  double duration = end - _start;
  if (_profile_enabled) {
    std::string key = "";
    for (const auto& scope : _profile_scopes) {
      key += scope + ".";
    }
    _profile_times[key] += duration;
  }
  _profile_scopes.pop_back();
}
#else
ProfileScope::ProfileScope(std::string) {}
ProfileScope::ProfileScope(const char*) {}
ProfileScope::~ProfileScope() {}
#endif

ProfileDisabledScope::ProfileDisabledScope() {
  _was_enabled = get_profile_enabled();
  set_profile_enabled(false);
}

ProfileDisabledScope::~ProfileDisabledScope() {
  set_profile_enabled(_was_enabled);
}

================================================
FILE: src/profile.h
================================================
#pragma once

#include <omp.h>
#include <map>
#include <string>

#define PROFILE_ENABLED 0

// Toggle aggregation of profile scopes at runtime.
// This does not disable profile instrumentation; change PROFILE_ENABLED and recompile for that.
void set_profile_enabled(bool enabled);
bool get_profile_enabled();
const std::map<std::string, double>& profile_times();
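
// Example: print the aggregated timings after a run (a sketch; the keys are
// the dot-joined scope names produced by PROFILE_BLOCK):
// ```
// for (const auto& [scope, seconds] : profile_times()) {
//   printf("%s %.3fs\n", scope.c_str(), seconds);
// }
// ```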

#if PROFILE_ENABLED
// This macro can be used to profile a block of code.
// Example:
// ```
// {
//   PROFILE_BLOCK(my_block);
//   // code to profile...
// }
// ```
// The execution time will be accumulated in the profile_times map under the
// dot-joined stack of enclosing scopes (e.g. `my_block.`, or `outer.my_block.`
// when nested inside another PROFILE_BLOCK).
// `my_block` need not be a declared name; any stringizable token works.
#define PROFILE_BLOCK(name) \
  ProfileScope profile_scope(#name)
#else
#define PROFILE_BLOCK(name)
#endif

// This macro can be used to profile a single statement.
// Example:
// ```
// PROFILE(my_statement);
// ```
// The execution time will be accumulated in the profile_times map under the
// stringized statement, scoped the same way as PROFILE_BLOCK keys above.
// `my_statement` should be a valid C++ statement or expression.
#define PROFILE(X) do { \
  PROFILE_BLOCK(X); \
  X; \
} while(0)

struct ProfileScope {
  ProfileScope(std::string name);
  ProfileScope(const char* name);
  ~ProfileScope();
private:
  double _start;
};

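// RAII guard that suppresses profile aggregation for its lifetime (e.g. to
// exclude warmup work from profile_times()), restoring the previous setting
// on destruction. Example:
// ```
// {
//   ProfileDisabledScope no_profile;
//   // ...work here is not aggregated into profile_times()
// }
// ```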
struct ProfileDisabledScope {
  ProfileDisabledScope();
  ~ProfileDisabledScope();
private:
  bool _was_enabled;
};

================================================
FILE: src/quant.cpp
================================================
/*
K-quants adapted from llama.cpp

MIT License

Copyright (c) 2023-2024 The ggml authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

*/

#include "quant.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

#define GROUP_MAX_EPS 1e-15f

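// Round-to-nearest-int via the float mantissa trick: adding 1.5 * 2^23
// (12582912.f) pushes the value into a binade where the low 23 mantissa bits
// hold the rounded integer biased by 2^22; the mask-and-subtract below
// recovers it. Only valid for |fval| <= 2^22 - 1, hence the assert.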
static inline int nearest_int(float fval) {
  assert(fabsf(fval) <= 4194303.f);
  float val = fval + 12582912.f;
  int i; memcpy(&i, &val, sizeof(int));
  return (i & 0x007fffff) - 0x00400000;
}

// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

#if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
  __m128 res = _mm256_extractf128_ps(x, 1);
  res = _mm_add_ps(res, _mm256_castps256_ps128(x));
  res = _mm_add_ps(res, _mm_movehl_ps(res, res));
  res = _mm_add_ss(res, _mm_movehdup_ps(res));
  return _mm_cvtss_f32(res);
}

// shuffles to pick the required scales in dot products
static inline __m256i get_scale_shuffle_q3k(int i) {
  static const uint8_t k_shuffle[128] = {
      0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
      4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
      8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
    12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
  };
  return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
#endif

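// Finds a (scale, min) pair for asymmetric quantization of n values to the
// levels q = clamp(round((x - min) / scale), 0, nmax), minimizing the
// weighted error sum_i weights[i] * err_i, where err_i is |diff| if use_mad
// is set and diff^2 otherwise. Starts from the naive range-based scale, then
// tries nstep candidate scales around it, refitting (scale, min) by weighted
// least squares for each candidate and keeping the best. Returns the scale;
// writes the negated minimum to *the_min and the chosen levels to L.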
static float make_qkx2_quants(int n, int nmax, const float * __restrict__ x, const float * __restrict__ weights,
    uint8_t * __restrict__ L, float * __restrict__ the_min, uint8_t * __restrict__ Laux,
    float rmin, float rdelta, int nstep, bool use_mad) {
  float min = x[0];
  float max = x[0];
  float sum_w = weights[0];
  float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
  // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
  for (volatile int i = 1; i < n; ++i) {
#else
  for (int i = 1; i < n; ++i) {
#endif
    if (x[i] < min) min = x[i];
    if (x[i] > max) max = x[i];
    float w = weights[i];
    sum_w += w;
    sum_x += w * x[i];
  }
  if (min > 0) min = 0;
  if (max == min) {
    for (int i = 0; i < n; ++i) L[i] = 0;
    *the_min = -min;
    return 0.f;
  }
  float iscale = nmax/(max - min);
  float scale = 1/iscale;
  float best_mad = 0;
  for (int i = 0; i < n; ++i) {
    int l = nearest_int(iscale*(x[i] - min));
    L[i] = std::max(0, std::min(nmax, l));
    float diff = scale * L[i] + min - x[i];
    diff = use_mad ? fabsf(diff) : diff * diff;
    float w = weights[i];
    best_mad += w * diff;
  }
  if (nstep < 1) {
    *the_min = -min;
    return scale;
  }
  for (int is = 0; is <= nstep; ++is) {
    iscale = (rmin + rdelta*is + nmax)/(max - min);
    float sum_l = 0, sum_l2 = 0, sum_xl = 0;
    for (int i = 0; i < n; ++i) {
      int l = nearest_int(iscale*(x[i] - min));
      l = std::max(0, std::min(nmax, l));
      Laux[i] = l;
      float w = weights[i];
      sum_l += w*l;
      sum_l2 += w*l*l;
      sum_xl += w*l*x[i];
    }
    float D = sum_w * sum_l2 - sum_l * sum_l;
    if (D > 0) {
      float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
      float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
      if (this_min > 0) {
        this_min = 0;
        this_scale = sum_xl / sum_l2;
      }
      float mad = 0;
      for (int i = 0; i < n; ++i) {
        float diff = this_scale * Laux[i] + this_min - x[i];
        diff = use_mad ? fabsf(diff) : diff * diff;
        float w = weights[i];
        mad += w * diff;
      }
      if (mad < best_mad) {
        for (int i = 0; i < n; ++i) {
          L[i] = Laux[i];
        }
        best_mad = mad;
        scale = this_scale;
        min = this_min;
      }
    }
  }
  *the_min = -min;
  return scale;
}

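// Quantizes k floats (k % QK_K == 0) into the 2-bit k-quant format: each
// QK_K-element super-block is split into 16-element sub-blocks whose
// per-sub-block (scale, min) pairs are themselves quantized to 4 bits each
// (packed as the low/high nibbles of y[i].scales[j]) against the fp16
// super-block scales y[i].d and y[i].dmin.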
void quantize_row_q2_K_ref(const float * __restrict__ x, block_q2_K * __restrict__ y, int64_t k) {
  assert(k % QK_K == 0);
  const int nb = k / QK_K;

  uint8_t L[QK_K];
  uint8_t Laux[16];
  float   weights[16];
  float mins[QK_K/16];
  float scales[QK_K/16];

  const float q4scale = 15.f;

  for (int i = 0; i < nb; i++) {
    float max_scale = 0; // as we are deducting the min, scales are always positive
    float max_min = 0;
    for (int j = 0; j < QK_K/16; ++j) {
      for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
      scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
      float scale = scales[j];
      if (scale > max_scale) {
        max_scale = scale;
      }
      float min = mins[j];
      if (min > max_min) {
        max_min = min;
      }
    }

    if (max_scale > 0) {
      float iscale = q4scale/max_scale;
      for (int j = 0; j < QK_K/16; ++j) {
        int l = nearest_int(iscale*scales[j]);
        y[i].scales[j] = l;
      }
      y[i].d = float_to_half(max_scale/q4scale);
    } else {
      for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
      y[i].d = float_to_half(0.f);
    }
    if (max_min > 0) {
      float iscale = q4scale/max_min;
      for (int j = 0; j < QK_K/16; ++j) {
        int l = nearest_int(iscale*mins[j]);
        y[i].scales[j] |= (l << 4);
      }
      y[i].dmin = float_to_half(max_min/q4scale);
    } else {
      y[i].dmin = float_to_half(0.f);
    }
    for (int j = 0; j < QK_K/16; ++j) {
      const float d = half_to_float(y[i].d) * (y[i].scales[j] & 0xF);
      if (!d) continue;
      const float dm = half_to_float(y[i].dmin) * (y[i].scales[j] >> 4);
      for (in
================================================
SYMBOL INDEX (781 symbols across 24 files)
================================================

FILE: convert.py
  class BlockQuant (line 26) | class BlockQuant:
  class KQuant (line 32) | class KQuant:
  class Metadata (line 46) | class Metadata:
    method __init__ (line 47) | def __init__(self, config, tokenizer_config, quant, n_layers, use_mla,...
    method to_dict (line 123) | def to_dict(self):
  function gpt2_bytes_to_unicode (line 175) | def gpt2_bytes_to_unicode():
  function load_tokens (line 187) | def load_tokens(tokenizer_path, vocab_size):
  function per_tensor_quantize (line 216) | def per_tensor_quantize(tensor: torch.Tensor, dtype: torch.dtype) -> Tup...
  function per_tensor_dequantize (line 246) | def per_tensor_dequantize(qweight: torch.Tensor, scale: torch.Tensor) ->...
  function blockwise_dequantize (line 250) | def blockwise_dequantize(qweight: torch.Tensor, scale: torch.Tensor, blo...
  function blockwise_quantize (line 262) | def blockwise_quantize(weight: torch.Tensor, block_size: torch.Tensor, d...
  function per_expert_blockwise_quantize (line 277) | def per_expert_blockwise_quantize(expert_weights: torch.Tensor, block_si...
  function per_expert_k_quantize (line 288) | def per_expert_k_quantize(expert_weights: torch.Tensor, method: Literal[...
  function load_weights (line 296) | def load_weights(model_files: List[str], metadata: Metadata, tie_word_em...

FILE: quantizer.cpp
  function quantize_q2_k (line 4) | torch::Tensor quantize_q2_k(torch::Tensor& input) {
  function quantize_q3_k (line 36) | torch::Tensor quantize_q3_k(torch::Tensor& input) {
  function PYBIND11_MODULE (line 68) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: quantizer.py
  function k_quantize (line 5) | def k_quantize(tensor: torch.Tensor, method: Literal["q2_k", "q3_k"]) ->...

FILE: setup.py
  class BinaryDistribution (line 6) | class BinaryDistribution(Distribution):
    method has_ext_modules (line 7) | def has_ext_modules(self):

FILE: src/codec.cpp
  function quant_to_string (line 13) | std::string quant_to_string(Quant quant) {
  function string_to_quant (line 24) | std::optional<Quant> string_to_quant(const std::string& quant_str) {
  function bits_per_weight (line 40) | double bits_per_weight(Quant quant, size_t blockwise_quant_size) {
  function CodecDType (line 55) | CodecDType quant_to_codec_dtype(Quant quant) {
  function is_k_quant (line 66) | bool is_k_quant(Quant quant) {
  function codec_dtype_to_string (line 70) | std::string codec_dtype_to_string(CodecDType dtype) {
  function string_to_codec_dtype (line 85) | std::optional<CodecDType> string_to_codec_dtype(const std::string& dtype...
  function codec_dtype_size (line 109) | size_t codec_dtype_size(CodecDType dtype) {
  function QTensor (line 166) | QTensor QTensor::from_codec_tensor(const Tensor& tensor, Quant weight_qu...
  type stat (line 269) | struct stat
  type dirent (line 342) | struct dirent

FILE: src/codec.h
  type f16_t (line 19) | typedef uint16_t f16_t;
  type f8e5m2_t (line 20) | typedef uint8_t f8e5m2_t;
  function half_to_float (line 23) | inline float half_to_float(f16_t x) {
  function f16_t (line 26) | inline f16_t float_to_half(float x) {
  function half_to_float (line 30) | inline float half_to_float(f16_t x) {
  function f16_t (line 34) | inline f16_t float_to_half(float x) {
  function float8e5m2_to_float (line 40) | inline float float8e5m2_to_float(f8e5m2_t x) {
  function f8e5m2_t (line 49) | [[maybe_unused]] inline f8e5m2_t float_to_float8e5m2(float x) {
  type class (line 62) | enum class
  type class (line 79) | enum class
  type Tensor (line 95) | struct Tensor {
  type QTensor (line 107) | struct QTensor {
  type YALMData (line 122) | struct YALMData {

FILE: src/debug.h
  type BinaryDumper (line 6) | struct BinaryDumper {

FILE: src/infer.cpp
  function copy_debug_tensor (line 19) | static std::vector<T> copy_debug_tensor(T* x, size_t size) {
  function save_debug_tensor (line 27) | static void save_debug_tensor(const std::string& name, T* x, size_t size) {
  function dump_debug_map (line 30) | void dump_debug_map(const std::string& filename) {
  function dump_debug_map_as_safetensors (line 73) | void dump_debug_map_as_safetensors(const std::string& filename) {
  function _matmul (line 121) | static void _matmul(
  function _matmul (line 161) | static void _matmul(
  function _matmul (line 238) | static void _matmul(
  function _matmul (line 315) | static void _matmul(
  function _matmul (line 348) | static void _matmul(
  function matmul (line 381) | static void matmul(
  function matmul_unscaled (line 419) | void matmul_unscaled(float* xout, float* x, const QTensor& w) {
  function matmul_expert (line 423) | static void matmul_expert(
  function softmax (line 472) | static void softmax(float* o, float* x, int size) {
  function sigmoid (line 489) | inline float sigmoid(float x) {
  function moe_gate (line 493) | static void moe_gate(
  function rmsnorm (line 601) | static void rmsnorm(float* o, float* x, float* weight, int size, float e...
  function layernorm (line 613) | [[maybe_unused]] static void layernorm(float* o, float* x, float* weight...
  function gelu (line 636) | inline float gelu(float x) {
  function silu (line 640) | inline float silu(float x) {
  function clip (line 644) | inline float clip(float x, float v) {
  function rope (line 648) | static void rope(float* buf, float* vec, int d, int head_dim, int pos, f...
  function rope_v3 (line 670) | static void rope_v3(float* vec, int d, int head_dim, int pos, float thet...
  function rope (line 687) | static void rope(float* buf, f16_t* vec, int d, int head_dim, int pos, f...
  function rope_v3 (line 709) | static void rope_v3(f16_t* vec, int d, int head_dim, int pos, float thet...
  function attn (line 728) | void attn(
  function attn_mla (line 766) | void attn_mla(
  function mha_cpu (line 1143) | void mha_cpu(
  function ffn_cpu (line 1166) | void ffn_cpu(

FILE: src/main.cpp
  function error_usage (line 18) | void error_usage() {
  function help_usage_interactive (line 45) | void help_usage_interactive() {
  type Session (line 71) | struct Session {
    method Session (line 72) | Session(const std::string& checkpoint_dir, bool lock_model_weights, in...
  type CompletionArgs (line 85) | struct CompletionArgs {
    method parse_args (line 91) | bool parse_args(const std::vector<const char*>& args) {
  type PasskeyArgs (line 158) | struct PasskeyArgs {
    method parse_args (line 162) | bool parse_args(const std::vector<const char*>& args) {
  type PerplexityArgs (line 199) | struct PerplexityArgs {
    method parse_args (line 203) | bool parse_args(const std::vector<const char*>& args) {
  function encode_prompt (line 257) | std::vector<int> encode_prompt(const std::string& prompt, Tokenizer& tok...
  function run_completion (line 277) | void run_completion(
  function run_perplexity (line 371) | void run_perplexity(
  function run_passkey (line 433) | void run_passkey(
  function run_interactive (line 514) | void run_interactive(Session& session) {
  function main (line 594) | int main(int argc, char* argv[]) {

FILE: src/model.cpp
  function cdiv (line 18) | int cdiv(int a, int b) {
  function check_tensor (line 129) | std::optional<QTensor> check_tensor(const Tensor* tensor, Quant weight_q...
  function Tensor (line 138) | const Tensor* get_tensor(const YALMData& yalm, const std::string& key) {

FILE: src/model.h
  type class (line 16) | enum class
  function LayerNormType (line 21) | enum class LayerNormType {
  type class (line 25) | enum class
  function ScoringFunc (line 31) | enum class ScoringFunc {
  function Block (line 276) | struct BlockMHA : public Block {
  function Block (line 366) | struct BlockMLA : public Block {
  type Model (line 455) | struct Model {
  type DebugTensor (line 482) | struct DebugTensor {

FILE: src/profile.cpp
  function set_profile_enabled (line 9) | void set_profile_enabled(bool enabled) {
  function get_profile_enabled (line 13) | bool get_profile_enabled() {

FILE: src/profile.h
  type ProfileScope (line 42) | struct ProfileScope {
  type ProfileDisabledScope (line 50) | struct ProfileDisabledScope {

FILE: src/quant.cpp
  function nearest_int (line 34) | static inline int nearest_int(float fval) {
  function hsum_float_8 (line 46) | static inline float hsum_float_8(const __m256 x) {
  function __m256i (line 55) | static inline __m256i get_scale_shuffle_q3k(int i) {
  function quantize_row_q2_K_ref (line 147) | void quantize_row_q2_K_ref(const float * __restrict__ x, block_q2_K * __...
  function dequantize_row_q2_K (line 217) | void dequantize_row_q2_K(const block_q2_K * __restrict__ x, float * __re...
  function make_q3_quants (line 249) | static float make_q3_quants(int n, int nmax, const float * __restrict__ ...
  function quantize_row_q3_K_ref (line 308) | void quantize_row_q3_K_ref(const float * __restrict__ x, block_q3_K * __...
  function dequantize_row_q3_K (line 384) | void dequantize_row_q3_K(const block_q3_K * __restrict__ x, float * __re...
  function ggml_vec_dot_q3_K_q8_K (line 434) | void ggml_vec_dot_q3_K_q8_K(int n, float * __restrict__ s, const void * ...
  function quantize_row_q8_K_ref (line 616) | void quantize_row_q8_K_ref(const float * __restrict__ x, block_q8_K * __...
  function dequantize_row_q8_K (line 655) | void dequantize_row_q8_K(const block_q8_K * __restrict__ x, float * __re...
  function ggml_vec_dot_q2_K_q8_K (line 666) | void ggml_vec_dot_q2_K_q8_K(

FILE: src/quant.h
  type block_q2_K (line 41) | typedef struct {
  type block_q3_K (line 70) | typedef struct {
  type block_q8_K (line 104) | typedef struct {

FILE: src/sampler.h
  type Sampler (line 7) | struct Sampler {

FILE: src/test.cpp
  function floatEquals (line 13) | bool floatEquals(float a, float b, float epsilon = 1e-5) {
  function arrayEquals (line 17) | bool arrayEquals(const std::vector<float>& a, const std::vector<float>& ...
  function assertArrayEquals (line 29) | void assertArrayEquals(const std::vector<float>& actual, const std::vect...
  function assertArrayEquals (line 46) | void assertArrayEquals(float* actual, const std::vector<float>& expected...
  function float_array_to_half (line 54) | std::vector<f16_t> float_array_to_half(const std::vector<float>& data) {
  function float_array_to_float8e5m2 (line 62) | std::vector<f8e5m2_t> float_array_to_float8e5m2(const std::vector<float>...
  function test_attn (line 70) | void test_attn() {
  function test_matmul (line 128) | void test_matmul() {
  function fill_random (line 188) | void fill_random(float* data, size_t N, unsigned long seed, float scale_...
  function fill_random (line 196) | void fill_random(f16_t* data, size_t N, unsigned long seed, float scale_...
  function mem_bench (line 218) | void mem_bench() {
  type ThreadData (line 252) | struct alignas(64) ThreadData {
  function mem_bench2_thread (line 257) | void mem_bench2_thread(uint32_t* data, size_t start_idx, size_t elements...
  function mem_bench2 (line 264) | void mem_bench2() {
  function main (line 312) | int main(int argc, char* argv[]) {

FILE: src/time_utils.cpp
  function get_timestamp_ms (line 5) | uint64_t get_timestamp_ms() {

FILE: src/tokenizer.h
  type TokenTrie (line 10) | struct TokenTrie
  type TokenTrie (line 12) | struct TokenTrie {
  type Tokenizer (line 51) | struct Tokenizer {

FILE: vendor/fmt/base.h
  function const_check (line 346) | struct monostate {
  function T (line 417) | auto convert_for_visit(T value) -> T {
  function int128_opt (line 424) | enum class int128_opt {}
  function uint128_opt (line 425) | enum class uint128_opt {}
  function Int (line 432) | auto to_unsigned(Int value) -> make_unsigned_t<Int> {
  function typename (line 445) | typename T::value_type*> {}
  function namespace (line 478) | namespace adl {
  function OutputIt (line 498) | struct accessor : OutputIt {
  function FMT_EXPORT (line 516) | FMT_EXPORT
  function operator (line 560) | constexpr auto operator[](size_t pos) const noexcept -> const Char& {
  function FMT_CONSTEXPR (line 564) | FMT_CONSTEXPR void remove_prefix(size_t n) noexcept {
  function FMT_CONSTEXPR (line 569) | FMT_CONSTEXPR auto starts_with(basic_string_view<Char> sv) const noexcept
  function FMT_CONSTEXPR (line 573) | FMT_CONSTEXPR auto starts_with(Char c) const noexcept -> bool {
  function FMT_CONSTEXPR (line 576) | FMT_CONSTEXPR auto starts_with(const Char* s) const -> bool {
  function FMT_CONSTEXPR (line 581) | FMT_CONSTEXPR auto compare(basic_string_view other) const -> int {
  function FMT_EXPORT (line 614) | FMT_EXPORT
  function true_type (line 616) | struct is_char<char> : std::true_type {}
  function namespace (line 618) | namespace detail {
  function operator (line 648) | constexpr operator basic_string_view<Char>() const {
  function type (line 657) | enum class type {
  function throw_format_error (line 737) | inline void throw_format_error(
  function FMT_EXPORT (line 752) | FMT_EXPORT
  function namespace (line 813) | namespace detail {
  function FMT_CONSTEXPR (line 900) | FMT_CONSTEXPR auto data() noexcept -> T* { return ptr_; }
  function clear (line 904) | void clear() { size_ = 0; }
  function FMT_CONSTEXPR (line 908) | FMT_CONSTEXPR void try_resize(size_t count) {
  function FMT_CONSTEXPR (line 917) | FMT_CONSTEXPR void try_reserve(size_t new_capacity) {
  function FMT_CONSTEXPR (line 921) | FMT_CONSTEXPR void push_back(const T& value) {
  function count (line 950) | struct buffer_traits {
  function flush (line 983) | void flush() {
  function flush (line 1024) | void flush() {
  function friend (line 1158) | friend auto get_container(basic_appender app) -> detail::buffer<T>& {
  function FMT_CONSTEXPR (line 1171) | FMT_CONSTEXPR basic_appender(detail::buffer<T>& buf) : buffer_(&buf) {}
  function operator (line 1177) | auto operator*() -> basic_appender& { return *this; }
  function namespace (line 1184) | namespace detail {
  function decltype (line 1223) | auto has_const_formatter_impl(T*)
  type view (line 1269) | struct view {}
  type unformattable (line 1302) | struct unformattable {}
  function unformattable (line 1303) | struct unformattable_char : unformattable {}
  function unformattable (line 1304) | struct unformattable_pointer : unformattable {}
  function FMT_ALWAYS_INLINE (line 1346) | constexpr FMT_ALWAYS_INLINE value() : no_value() {}
  function FMT_ALWAYS_INLINE (line 1367) | FMT_ALWAYS_INLINE value(const void* val) : pointer(val) {}
  function FMT_ALWAYS_INLINE (line 1368) | FMT_ALWAYS_INLINE value(const named_arg_info<char_type>* args, size_t size)
  function format_custom_arg (line 1393) | void format_custom_arg(void* arg,
  function FMT_MAP_API (line 1434) | FMT_MAP_API auto map(signed char val) -> int { return val; }
  function FMT_MAP_API (line 1435) | FMT_MAP_API auto map(unsigned char val) -> unsigned { return val; }
  function FMT_MAP_API (line 1436) | FMT_MAP_API auto map(short val) -> int { return val; }
  function FMT_MAP_API (line 1437) | FMT_MAP_API auto map(unsigned short val) -> unsigned { return val; }
  function FMT_MAP_API (line 1438) | FMT_MAP_API auto map(int val) -> int { return val; }
  function FMT_MAP_API (line 1439) | FMT_MAP_API auto map(unsigned val) -> unsigned { return val; }
  function FMT_MAP_API (line 1440) | FMT_MAP_API auto map(long val) -> long_type { return val; }
  function FMT_MAP_API (line 1441) | FMT_MAP_API auto map(unsigned long val) -> ulong_type { return val; }
  function FMT_MAP_API (line 1442) | FMT_MAP_API auto map(long long val) -> long long { return val; }
  function FMT_MAP_API (line 1443) | FMT_MAP_API auto map(unsigned long long val) -> unsigned long long {
  function FMT_MAP_API (line 1446) | FMT_MAP_API auto map(int128_opt val) -> int128_opt { return val; }
  function FMT_MAP_API (line 1447) | FMT_MAP_API auto map(uint128_opt val) -> uint128_opt { return val; }
  function FMT_MAP_API (line 1448) | FMT_MAP_API auto map(bool val) -> bool { return val; }
  function char_type (line 1452) | auto map(T val) -> char_type {
  function FMT_MAP_API (line 1467) | FMT_MAP_API auto map(float val) -> float { return val; }
  function FMT_MAP_API (line 1468) | FMT_MAP_API auto map(double val) -> double { return val; }
  function FMT_MAP_API (line 1469) | FMT_MAP_API auto map(long double val) -> long double { return val; }
  function Char (line 1476) | auto map(const T& val) -> basic_string_view<Char> {
  function unformattable_char (line 1482) | auto map(const T&) -> unformattable_char {
  function FMT_MAP_API (line 1486) | FMT_MAP_API auto map(void* val) -> const void* { return val; }
  function FMT_MAP_API (line 1487) | FMT_MAP_API auto map(const void* val) -> const void* { return val; }
  function FMT_MAP_API (line 1488) | FMT_MAP_API auto map(volatile void* val) -> const void* {
  function FMT_MAP_API (line 1491) | FMT_MAP_API auto map(const volatile void* val) -> const void* {
  function FMT_MAP_API (line 1494) | FMT_MAP_API auto map(std::nullptr_t val) -> const void* { return val; }
  function unformattable_pointer (line 1505) | auto map(const T&) -> unformattable_pointer {
  function decltype (line 1518) | auto map(const T& val) -> decltype(FMT_DECLTYPE_THIS map(U())) {
  function decltype (line 1544) | auto map(T& val) -> decltype(FMT_DECLTYPE_THIS do_map(val)) {
  function decltype (line 1549) | auto map(const T& named_arg)
  type is_output_iterator (line 1572) | struct is_output_iterator
  function class (line 1580) | class locale_ref {
  function make_descriptor (line 1604) | constexpr unsigned long long make_descriptor() {
  type type_is_unformattable_for (line 1613) | struct type_is_unformattable_for {
  type type_is_unformattable_for (line 1616) | struct type_is_unformattable_for
  function Context (line 1659) | inline auto make_arg(T& val) -> basic_format_arg<Context> {
  type format_arg_store (line 1716) | struct format_arg_store
  function FMT_BEGIN_EXPORT (line 1722) | FMT_BEGIN_EXPORT
  function FMT_CONSTEXPR (line 1870) | FMT_CONSTEXPR auto type(int index) const -> detail::type {
  function FMT_CONSTEXPR (line 1966) | FMT_CONSTEXPR auto arg(int id) const -> format_arg { return args_.get(id...
  function format_arg (line 1967) | auto arg(string_view name) -> format_arg { return args_.get(name); }
  function FMT_CONSTEXPR (line 1968) | FMT_CONSTEXPR auto arg_id(string_view name) -> int {
  function advance_to (line 1977) | void advance_to(iterator) {}
  function namespace (line 2065) | namespace align {
  function presentation_type (line 2128) | enum class presentation_type : unsigned char {
  function FMT_CONSTEXPR (line 2267) | FMT_CONSTEXPR inline auto parse_align(char c) -> align_t {
  function FMT_CONSTEXPR (line 2325) | FMT_CONSTEXPR void on_auto() {
  function FMT_CONSTEXPR (line 2330) | FMT_CONSTEXPR void on_index(int id) {
  function state (line 2377) | enum class state { start, align, sign, hash, zero, width, precision, loc...
  function FMT_CONSTEXPR (line 2571) | FMT_CONSTEXPR void on_auto() { arg_id = handler.on_arg_id(); }
  function FMT_CONSTEXPR (line 2572) | FMT_CONSTEXPR void on_index(int id) { arg_id = handler.on_arg_id(id); }
  function adapter (line 2585) | auto adapter = id_adapter{handler, 0};
  function FMT_CONSTEXPR (line 2660) | FMT_CONSTEXPR auto parse_format_specs(ParseContext& ctx)
  function FMT_CONSTEXPR (line 2682) | FMT_CONSTEXPR inline auto check_char_specs(const format_specs& specs) ->...
  function on_text (line 2737) | void on_text(const Char*, const Char*) {}
  function FMT_CONSTEXPR (line 2740) | FMT_CONSTEXPR auto on_arg_id(int id) -> int {
  function FMT_CONSTEXPR (line 2743) | FMT_CONSTEXPR auto on_arg_id(basic_string_view<Char> id) -> int {
  function FMT_CONSTEXPR (line 2755) | FMT_CONSTEXPR void on_replacement_field(int id, const Char* begin) {
  function on_error (line 2766) | void on_error(const char* message) {
  type compile_string (line 2772) | struct compile_string {}
  function check_format_string (line 2779) | void check_format_string(const S&) {
  function report_truncation (line 2797) | inline void report_truncation(bool truncated) {
  function char (line 2806) | struct vformat_args<char> {
  function vprint_mojibake (line 2816) | inline void vprint_mojibake(FILE*, string_view, format_args, bool) {}
  function FMT_ALWAYS_INLINE (line 2873) | FMT_ALWAYS_INLINE basic_format_string(const S& s) : str_(s) {
  function string_view (line 2899) | inline auto runtime(string_view s) -> string_view { return s; }
  function runtime_format_string (line 2911) | inline auto runtime(string_view s) -> runtime_format_string<> { return {...

FILE: vendor/fmt/format-inl.h
  function FMT_BEGIN_NAMESPACE (line 29) | FMT_BEGIN_NAMESPACE
  function FMT_FUNC (line 131) | FMT_FUNC void report_error(const char* message) {
  function FMT_FUNC (line 145) | FMT_FUNC auto format_facet<std::locale>::do_put(
  function FMT_FUNC (line 152) | FMT_FUNC auto vsystem_error(int error_code, string_view fmt, format_args...
  function namespace (line 158) | namespace detail {
  function noexcept (line 199) | inline auto floor_log10_pow2_minus_log10_4_over_3(int e) noexcept -> int {
  function FMT_INLINE_VARIABLE (line 204) | FMT_INLINE_VARIABLE constexpr struct {
  function noexcept (line 247) | inline auto divide_by_10_to_kappa_plus_1(uint32_t n) noexcept -> uint32_t {
  function noexcept (line 252) | inline auto divide_by_10_to_kappa_plus_1(uint64_t n) noexcept -> uint64_t {
  function float (line 260) | struct cache_accessor<float> {
  function noexcept (line 314) | static auto compute_delta(const cache_entry_type& cache, int beta) noexcept
  function carrier_uint (line 331) | static auto compute_left_endpoint_for_shorter_interval_case(
  function carrier_uint (line 338) | static auto compute_right_endpoint_for_shorter_interval_case(
  function carrier_uint (line 345) | static auto compute_round_up_for_shorter_interval_case(
  function double (line 354) | struct cache_accessor<double> {
  function bigint (line 1374) | struct formatter<detail::bigint> {
  function FMT_FUNC (line 1400) | FMT_FUNC detail::utf8_to_utf16::utf8_to_utf16(string_view s) {
  function FMT_FUNC (line 1415) | FMT_FUNC void format_system_error(detail::buffer<char>& out, int error_c...
  function FMT_FUNC (line 1426) | FMT_FUNC void report_system_error(int error_code,
  function FMT_FUNC (line 1431) | FMT_FUNC auto vformat(string_view fmt, format_args args) -> std::string {
  function namespace (line 1439) | namespace detail {
  function operator (line 1474) | operator F*() const { return file_; }
  function unget (line 1485) | void unget(char c) {
  function flush (line 1490) | void flush() { fflush(this->file_); }
  function init_buffer (line 1508) | void init_buffer() {
  function span (line 1522) | auto get_write_buffer() const -> span<char> {
  function advance_write_buffer (line 1527) | void advance_write_buffer(size_t size) { this->file_->_IO_write_ptr += s...
  function flush (line 1535) | void flush() { fflush_unlocked(this->file_); }
  function init_buffer (line 1553) | void init_buffer() {
  function advance_write_buffer (line 1572) | void advance_write_buffer(size_t size) {
  function init_buffer (line 1595) | void init_buffer() {}
  function unget (line 1610) | void unget(char c) {
  function grow (line 1648) | static void grow(buffer<char>& base, size_t) {
  function FMT_FUNC (line 1675) | FMT_FUNC auto write_console(int, string_view) -> bool { return false; }
  function FMT_FUNC (line 1681) | FMT_FUNC bool write_console(int fd, string_view text) {
  function FMT_FUNC (line 1690) | FMT_FUNC void vprint_mojibake(std::FILE* f, string_view fmt, format_args...
  function FMT_FUNC (line 1699) | FMT_FUNC void print(std::FILE* f, string_view text) {
  function FMT_FUNC (line 1711) | FMT_FUNC void vprint_buffered(std::FILE* f, string_view fmt, format_args...
  function FMT_FUNC (line 1717) | FMT_FUNC void vprint(std::FILE* f, string_view fmt, format_args args) {
  function FMT_FUNC (line 1724) | FMT_FUNC void vprintln(std::FILE* f, string_view fmt, format_args args) {
  function FMT_FUNC (line 1731) | FMT_FUNC void vprint(string_view fmt, format_args args) {
  function namespace (line 1735) | namespace detail {
  function FMT_FUNC (line 1774) | FMT_FUNC auto is_printable(uint32_t cp) -> bool {

FILE: vendor/fmt/format.h
  function namespace (line 109) | namespace std {
  function FMT_BEGIN_NAMESPACE (line 119) | FMT_BEGIN_NAMESPACE
  function FMT_BEGIN_NAMESPACE (line 202) | FMT_BEGIN_NAMESPACE
  function namespace (line 278) | namespace detail {
  function FMT_CONSTEXPR (line 380) | FMT_CONSTEXPR auto operator>>(int shift) const -> uint128_fallback {
  function FMT_CONSTEXPR (line 385) | FMT_CONSTEXPR auto operator<<(int shift) const -> uint128_fallback {
  function FMT_CONSTEXPR (line 390) | FMT_CONSTEXPR auto operator>>=(int shift) -> uint128_fallback& {
  function uint128_fallback (line 405) | uint64_t n) noexcept -> uint128_fallback& {
  function To (line 455) | auto bit_cast(const From& from) -> To {
  function FMT_INLINE (line 493) | FMT_INLINE void assume(bool condition) {
  function Char (line 509) | auto get_data(std::basic_string<Char>& s) -> Char* {
  function FMT_NOINLINE (line 594) | FMT_NOINLINE auto copy_noinline(InputIt begin, InputIt end,
  function FMT_CONSTEXPR (line 616) | FMT_CONSTEXPR inline auto utf8_decode(const char* s, uint32_t* c, int* e)
  function for_each_codepoint (line 658) | void for_each_codepoint(string_view s, F f) {
  function FMT_CONSTEXPR (line 694) | FMT_CONSTEXPR inline auto compute_width(string_view s) -> size_t {
  function true_type (line 750) | struct is_integral<int128_opt> : std::true_type {}
  function true_type (line 751) | struct is_integral<uint128_t> : std::true_type {}
  function FMT_BEGIN_EXPORT (line 814) | FMT_BEGIN_EXPORT
  function FMT_CONSTEXPR20 (line 924) | FMT_CONSTEXPR20 void resize(size_t count) { this->try_resize(count); }
  function reserve (line 927) | void reserve(size_t new_capacity) { this->try_reserve(new_capacity); }
  function FMT_END_EXPORT (line 942) | FMT_END_EXPORT
  function namespace (line 961) | namespace detail_exported {
  function generic_context (line 1010) | constexpr auto arg(int id) const -> basic_format_arg<generic_context> {
  function generic_context (line 1013) | auto arg(basic_string_view<Char> name) -> basic_format_arg<generic_conte...
  function FMT_CONSTEXPR (line 1016) | FMT_CONSTEXPR auto arg_id(basic_string_view<Char> name) -> int {
  function advance_to (line 1025) | void advance_to(iterator it) {
  function class (line 1032) | class loc_value {
  function char (line 1116) | constexpr auto digits2(size_t value) -> const char* {
  function Char (line 1126) | auto sign(Sign s) -> Char {
  function FMT_CONSTEXPR (line 1148) | FMT_CONSTEXPR inline auto count_digits(uint128_opt n) -> int {
  function FMT_CONSTEXPR20 (line 1176) | FMT_CONSTEXPR20 inline auto count_digits(uint64_t n) -> int {
  function FMT_INLINE (line 1203) | FMT_INLINE auto do_count_digits(uint32_t n) -> int {
  function FMT_CONSTEXPR20 (line 1226) | FMT_CONSTEXPR20 inline auto count_digits(uint32_t n) -> int {
  function wchar_t (line 1254) | inline auto thousands_sep(locale_ref loc) -> thousands_sep_result<wchar_...
  function Char (line 1260) | auto decimal_point(locale_ref loc) -> Char {
  function wchar_t (line 1263) | inline auto decimal_point(locale_ref loc) -> wchar_t {
  function copy2 (line 1277) | void copy2(Char* dst, const char* src) {
  function class (line 1355) | class utf8_to_utf16 {
  function namespace (line 1465) | namespace dragonbox {
  function double (line 1517) | struct float_info<double> {
  function FMT_CONSTEXPR (line 1671) | FMT_CONSTEXPR inline auto multiply(uint64_t lhs, uint64_t rhs) -> uint64...
  function FMT_CONSTEXPR (line 1688) | FMT_CONSTEXPR inline auto operator*(fp x, fp y) -> fp {
  function write (line 1756) | auto write = [=](reserve_iterator<OutputIt> it) {
  function FMT_CONSTEXPR (line 1939) | FMT_CONSTEXPR write_int_data(int num_digits, unsigned prefix,
  type next_state (line 1988) | struct next_state {
  function locale_ref (line 2218) | locale_ref loc)
  function class (line 2240) | class counting_iterator {
  function FMT_CONSTEXPR (line 2275) | FMT_CONSTEXPR auto operator*() const -> value_type { return {}; }
  function float_format (line 2381) | enum class float_format : unsigned char {
  type big_decimal_fp (line 2443) | struct big_decimal_fp {
  function write (line 2587) | auto write = [=](iterator it) {
  function signbit (line 2707) | bool signbit(T value) {
  function class (line 2727) | class bigint {
  function FMT_CONSTEXPR20 (line 2836) | FMT_CONSTEXPR20 auto num_bigits() const -> int {
  function FMT_CONSTEXPR20 (line 2907) | FMT_CONSTEXPR20 void assign_pow10(int exp) {
  function FMT_CONSTEXPR20 (line 2926) | FMT_CONSTEXPR20 void square() {
  function FMT_CONSTEXPR20 (line 2956) | FMT_CONSTEXPR20 void align(const bigint& other) {
  function FMT_CONSTEXPR20 (line 2969) | FMT_CONSTEXPR20 auto divmod_assign(const bigint& divisor) -> int {
  type dragon (line 2984) | enum dragon {
  function f (line 3582) | auto f = big_decimal_fp{buffer.data(), static_cast<int>(buffer.size()), ...
  function OutputIt (line 3588) | auto write(OutputIt out, T value, format_specs specs,
  function FMT_INLINE (line 3746) | FMT_INLINE auto operator()(T value) -> iterator {
  type precision_checker (line 3770) | struct precision_checker {
  function value (line 3822) | statically_named_arg(const T& v) : value(v) {}
  function str_ (line 3943) | format_int(long value)
  function str_ (line 3945) | format_int(long long value)
  function str_ (line 3947) | format_int(unsigned value)
  function str_ (line 3949) | format_int(unsigned long value)
  function str_ (line 3951) | format_int(unsigned long long value)
  function FMT_CONSTEXPR20 (line 3955) | FMT_CONSTEXPR20 auto size() const -> size_t {
  function namespace (line 4036) | namespace enums {
  function class (line 4043) | class bytes {
  function bytes (line 4052) | struct formatter<bytes> {
  function FMT_CONSTEXPR (line 4146) | FMT_CONSTEXPR auto parse(basic_format_parse_context<Char>& ctx)
  function string (line 4192) | auto to_string(T value) -> std::string {
  function FMT_END_EXPORT (line 4215) | FMT_END_EXPORT
  function namespace (line 4329) | inline namespace literals {

FILE: vendor/format.cc
  function FMT_BEGIN_NAMESPACE (line 10) | FMT_BEGIN_NAMESPACE

FILE: vendor/json.hpp
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 247) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 258) | NLOHMANN_JSON_NAMESPACE_END
  type would_call_std_ (line 2814) | struct would_call_std_
  type value_t (line 2872) | enum class value_t : std::uint8_t
  function NLOHMANN_JSON_NAMESPACE_END (line 2937) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 3030) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 3076) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 3267) | NLOHMANN_JSON_NAMESPACE_BEGIN
  class json_pointer (line 3416) | class json_pointer
  type ordered_map (line 3427) | struct ordered_map
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 3438) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 4230) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 4358) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_END (line 4590) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 4636) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 4644) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 4659) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 5174) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 5356) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 5404) | NLOHMANN_JSON_NAMESPACE_BEGIN
  type adl_serializer (line 5832) | struct adl_serializer
    method from_json (line 5837) | static auto from_json(BasicJsonType && j, TargetType& val) noexcept(
    method from_json (line 5847) | static auto from_json(BasicJsonType && j) noexcept(
    method to_json (line 5857) | static auto to_json(BasicJsonType& j, TargetType && val) noexcept(
  function set_subtype (line 5938) | void set_subtype(subtype_type subtype_) noexcept
  function subtype_type (line 5946) | constexpr subtype_type subtype() const noexcept
  function has_subtype (line 5953) | constexpr bool has_subtype() const noexcept
  function clear_subtype (line 5960) | void clear_subtype() noexcept
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 5999) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 6171) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function json_sax_dom_parser (line 6813) | explicit json_sax_dom_parser(BasicJsonType& r, const bool allow_exceptio...
  function json_sax_dom_parser (line 6818) | json_sax_dom_parser(const json_sax_dom_parser&) = delete;
  function json_sax_dom_parser (line 6819) | json_sax_dom_parser(json_sax_dom_parser&&) = default;
  function null (line 6824) | bool null()
  function boolean (line 6830) | bool boolean(bool val)
  function number_integer (line 6836) | bool number_integer(number_integer_t val)
  function number_unsigned (line 6842) | bool number_unsigned(number_unsigned_t val)
  function number_float (line 6848) | bool number_float(number_float_t val, const string_t& /*unused*/)
  function string (line 6854) | bool string(string_t& val)
  function binary (line 6860) | bool binary(binary_t& val)
  function start_object (line 6866) | bool start_object(std::size_t len)
  function key (line 6878) | bool key(string_t& val)
  function end_object (line 6888) | bool end_object()
  function start_array (line 6898) | bool start_array(std::size_t len)
  function end_array (line 6910) | bool end_array()
  function parse_error (line 6921) | bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/,
  function is_errored (line 6933) | constexpr bool is_errored() const
  class json_sax_dom_callback_parser (line 6982) | class json_sax_dom_callback_parser
    method json_sax_dom_callback_parser (line 6993) | json_sax_dom_callback_parser(BasicJsonType& r,
    method json_sax_dom_callback_parser (line 7002) | json_sax_dom_callback_parser(const json_sax_dom_callback_parser&) = de...
    method json_sax_dom_callback_parser (line 7003) | json_sax_dom_callback_parser(json_sax_dom_callback_parser&&) = default;
    method json_sax_dom_callback_parser (line 7004) | json_sax_dom_callback_parser& operator=(const json_sax_dom_callback_pa...
    method json_sax_dom_callback_parser (line 7005) | json_sax_dom_callback_parser& operator=(json_sax_dom_callback_parser&&...
    method null (line 7008) | bool null()
    method boolean (line 7014) | bool boolean(bool val)
    method number_integer (line 7020) | bool number_integer(number_integer_t val)
    method number_unsigned (line 7026) | bool number_unsigned(number_unsigned_t val)
    method number_float (line 7032) | bool number_float(number_float_t val, const string_t& /*unused*/)
    method string (line 7038) | bool string(string_t& val)
    method binary (line 7044) | bool binary(binary_t& val)
    method start_object (line 7050) | bool start_object(std::size_t len)
    method key (line 7068) | bool key(string_t& val)
    method end_object (line 7085) | bool end_object()
    method start_array (line 7121) | bool start_array(std::size_t len)
    method end_array (line 7138) | bool end_array()
    method parse_error (line 7171) | bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/,
    method is_errored (line 7183) | constexpr bool is_errored() const
    method handle_value (line 7205) | std::pair<bool, BasicJsonType*> handle_value(Value&& v, const bool ski...
  class json_sax_acceptor (line 7289) | class json_sax_acceptor
    method null (line 7298) | bool null()
    method boolean (line 7303) | bool boolean(bool /*unused*/)
    method number_integer (line 7308) | bool number_integer(number_integer_t /*unused*/)
    method number_unsigned (line 7313) | bool number_unsigned(number_unsigned_t /*unused*/)
    method number_float (line 7318) | bool number_float(number_float_t /*unused*/, const string_t& /*unused*/)
    method string (line 7323) | bool string(string_t& /*unused*/)
    method binary (line 7328) | bool binary(binary_t& /*unused*/)
    method start_object (line 7333) | bool start_object(std::size_t /*unused*/ = static_cast<std::size_t>(-1))
    method key (line 7338) | bool key(string_t& /*unused*/)
    method end_object (line 7343) | bool end_object()
    method start_array (line 7348) | bool start_array(std::size_t /*unused*/ = static_cast<std::size_t>(-1))
    method end_array (line 7353) | bool end_array()
    method parse_error (line 7358) | bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/...
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 7397) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function reset (line 8692) | void reset() noexcept
  function char_int_type (line 8709) | char_int_type get()
  function unget (line 8746) | void unget()
  function add (line 8773) | void add(char_int_type c)
  function number_unsigned_t (line 8790) | constexpr number_unsigned_t get_number_unsigned() const noexcept
  function number_float_t (line 8796) | constexpr number_float_t get_number_float() const noexcept
  function string_t (line 8802) | string_t& get_string()
  function position_t (line 8812) | constexpr position_t get_position() const noexcept
  function get_token_string (line 8820) | std::string get_token_string() const
  function JSON_HEDLEY_RETURNS_NON_NULL (line 8844) | JSON_HEDLEY_RETURNS_NON_NULL
  function skip_bom (line 8858) | bool skip_bom()
  function skip_whitespace (line 8872) | void skip_whitespace()
  function token_type (line 8881) | token_type scan()
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 9030) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 9170) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 12195) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 12719) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 12835) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 12890) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function pointer (line 13192) | pointer operator->() const
  function iter_impl (line 13234) | iter_impl operator++(int)& // NOLINT(cert-dcl21-cpp)
  function iter_impl (line 13245) | iter_impl& operator++()
  function iter_impl (line 13285) | iter_impl operator--(int)& // NOLINT(cert-dcl21-cpp)
  function iter_impl (line 13296) | iter_impl& operator--()
  function iter_impl (line 13444) | iter_impl& operator+=(difference_type i)
  function iter_impl (line 13481) | iter_impl& operator-=(difference_type i)
  function iter_impl (line 13490) | iter_impl operator+(difference_type i) const
  function friend (line 13501) | friend iter_impl operator+(difference_type i, const iter_impl& it)
  function iter_impl (line 13512) | iter_impl operator-(difference_type i) const
  function difference_type (line 13523) | difference_type operator-(const iter_impl& other) const
  function reference (line 13552) | reference operator[](difference_type n) const
  function reference (line 13606) | reference value() const
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 13641) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 13774) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 13835) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_BASIC_JSON_TPL_DECLARATION (line 13855) | NLOHMANN_BASIC_JSON_TPL_DECLARATION
  function json_pointer (line 13867) | explicit json_pointer(const string_t& s = "")
  function string_t (line 13873) | string_t to_string() const
  function friend (line 13894) | friend std::ostream& operator<<(std::ostream& o, const json_pointer& ptr)
  function json_pointer (line 13903) | json_pointer& operator/=(const json_pointer& ptr)
  function json_pointer (line 13913) | json_pointer& operator/=(string_t token)
  function json_pointer (line 13921) | json_pointer& operator/=(std::size_t array_idx)
  function friend (line 13928) | friend json_pointer operator/(const json_pointer& lhs,
  function friend (line 13936) | friend json_pointer operator/(const json_pointer& lhs, string_t token) /...
  function friend (line 13943) | friend json_pointer operator/(const json_pointer& lhs, std::size_t array...
  function json_pointer (line 13950) | json_pointer parent_pointer() const
  function pop_back (line 13964) | void pop_back()
  function string_t (line 13976) | const string_t& back() const
  function push_back (line 13988) | void push_back(const string_t& token)
  function push_back (line 13995) | void push_back(string_t&& token)
  function empty (line 14002) | bool empty() const noexcept
  function BasicJsonType (line 14079) | BasicJsonType& get_and_create(BasicJsonType& j) const
  function BasicJsonType (line 14159) | BasicJsonType& get_unchecked(BasicJsonType* ptr) const
  function BasicJsonType (line 14227) | BasicJsonType& get_checked(BasicJsonType* ptr) const
  function BasicJsonType (line 14285) | const BasicJsonType& get_unchecked(const BasicJsonType* ptr) const
  function BasicJsonType (line 14334) | const BasicJsonType& get_checked(const BasicJsonType* ptr) const
  function contains (line 14383) | bool contains(const BasicJsonType* ptr) const
  function split (line 14471) | static std::vector<string_t> split(const string_t& reference_string)
  function BasicJsonType (line 14611) | static BasicJsonType
  function convert (line 14640) | json_pointer<string_t> convert() const&
  function convert (line 14647) | json_pointer<string_t> convert()&&
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 14814) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 14939) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 15061) | NLOHMANN_JSON_NAMESPACE_END
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 16928) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function NLOHMANN_JSON_NAMESPACE_END (line 18025) | NLOHMANN_JSON_NAMESPACE_END
  function hex_bytes (line 18685) | static std::string hex_bytes(std::uint8_t byte)
  function is_negative_number (line 18696) | bool is_negative_number(NumberType x)
  function is_negative_number (line 18702) | bool is_negative_number(NumberType /*unused*/)
  function dump_integer (line 18722) | void dump_integer(NumberType x)
  function dump_float (line 18807) | void dump_float(number_float_t x)
  function dump_float (line 18828) | void dump_float(number_float_t x, std::true_type /*is_ieee_single_or_dou...
  function dump_float (line 18836) | void dump_float(number_float_t x, std::false_type /*is_ieee_single_or_do...
  function decode (line 18908) | static std::uint8_t decode(std::uint8_t& state, std::uint32_t& codep, co...
  function number_unsigned_t (line 18948) | number_unsigned_t remove_sign(number_unsigned_t x)
  function number_unsigned_t (line 18963) | inline number_unsigned_t remove_sign(number_integer_t x) noexcept
  function ordered_map (line 19050) | ordered_map() noexcept(noexcept(Container())) : Container{} {}
  function ordered_map (line 19051) | explicit ordered_map(const Allocator& alloc) noexcept(noexcept(Container...
  function ordered_map (line 19053) | ordered_map(It first, It last, const Allocator& alloc = Allocator())
  function ordered_map (line 19055) | ordered_map(std::initializer_list<value_type> init, const Allocator& all...
  function emplace (line 19058) | std::pair<iterator, bool> emplace(const key_type& key, T&& t)
  function emplace (line 19073) | std::pair<iterator, bool> emplace(KeyType && key, T && t)
  function T (line 19086) | T& operator[](const key_type& key)
  function T (line 19093) | T & operator[](KeyType && key)
  function T (line 19098) | const T& operator[](const key_type& key) const
  function T (line 19105) | const T & operator[](KeyType && key) const
  function T (line 19110) | T& at(const key_type& key)
  function T (line 19125) | T & at(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-forward)
  function T (line 19138) | const T& at(const key_type& key) const
  function T (line 19153) | const T & at(KeyType && key) const // NOLINT(cppcoreguidelines-missing-s...
  function size_type (line 19166) | size_type erase(const key_type& key)
  function size_type (line 19187) | size_type erase(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-...
  function iterator (line 19206) | iterator erase(iterator pos)
  function iterator (line 19211) | iterator erase(iterator first, iterator last)
  function size_type (line 19264) | size_type count(const key_type& key) const
  function size_type (line 19278) | size_type count(KeyType && key) const // NOLINT(cppcoreguidelines-missin...
  function iterator (line 19290) | iterator find(const key_type& key)
  function iterator (line 19304) | iterator find(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-fo...
  function const_iterator (line 19316) | const_iterator find(const key_type& key) const
  function insert (line 19328) | std::pair<iterator, bool> insert( value_type&& value )
  function insert (line 19333) | std::pair<iterator, bool> insert( const value_type& value )
  function insert (line 19351) | void insert(InputIt first, InputIt last)
  function NLOHMANN_JSON_NAMESPACE_BEGIN (line 19378) | NLOHMANN_JSON_NAMESPACE_BEGIN
  function set_parents (line 20005) | void set_parents()
  function iterator (line 20042) | iterator set_parents(iterator it, typename iterator::difference_type cou...
  function reference (line 20055) | reference set_parent(reference j, std::size_t old_capacity = static_cast...
  function basic_json (line 20117) | basic_json(const value_t v)
  function basic_json (line 20125) | basic_json(std::nullptr_t = nullptr) noexcept // NOLINT(bugprone-excepti...
  function basic_json (line 20137) | basic_json(CompatibleType && val) noexcept(noexcept( // NOLINT(bugprone-...
  function basic_json (line 20151) | basic_json(const BasicJsonType& val)
  function basic_json (line 20204) | basic_json(initializer_list_t init,
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20262) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20273) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20284) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20295) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20306) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 20314) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 20322) | basic_json(size_type cnt, const basic_json& val):
  function basic_json (line 20334) | basic_json(InputIT first, InputIT last)
  function basic_json (line 20443) | basic_json(const JsonRef& ref) : basic_json(ref.moved_or_copied()) {}
  function basic_json (line 20447) | basic_json(const basic_json& other)
  function basic_json (line 20516) | basic_json(basic_json&& other) noexcept
  function basic_json (line 20533) | basic_json& operator=(basic_json other) noexcept (
  function value_t (line 20596) | constexpr value_t type() const noexcept
  function is_primitive (line 20603) | constexpr bool is_primitive() const noexcept
  function is_structured (line 20610) | constexpr bool is_structured() const noexcept
  function is_null (line 20617) | constexpr bool is_null() const noexcept
  function is_boolean (line 20624) | constexpr bool is_boolean() const noexcept
  function is_number (line 20631) | constexpr bool is_number() const noexcept
  function is_number_integer (line 20638) | constexpr bool is_number_integer() const noexcept
  function is_number_unsigned (line 20645) | constexpr bool is_number_unsigned() const noexcept
  function is_number_float (line 20652) | constexpr bool is_number_float() const noexcept
  function is_object (line 20659) | constexpr bool is_object() const noexcept
  function is_array (line 20666) | constexpr bool is_array() const noexcept
  function is_string (line 20673) | constexpr bool is_string() const noexcept
  function is_binary (line 20680) | constexpr bool is_binary() const noexcept
  function is_discarded (line 20687) | constexpr bool is_discarded() const noexcept
  function object_t (line 20718) | object_t* get_impl_ptr(object_t* /*unused*/) noexcept
  function object_t (line 20724) | constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const...
  function array_t (line 20730) | array_t* get_impl_ptr(array_t* /*unused*/) noexcept
  function array_t (line 20736) | constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const n...
  function string_t (line 20742) | string_t* get_impl_ptr(string_t* /*unused*/) noexcept
  function string_t (line 20748) | constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const...
  function boolean_t (line 20754) | boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept
  function boolean_t (line 20760) | constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) con...
  function number_integer_t (line 20766) | number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept
  function number_integer_t (line 20772) | constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /...
  function number_unsigned_t (line 20778) | number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept
  function number_unsigned_t (line 20784) | constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t*...
  function number_float_t (line 20790) | number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept
  function number_float_t (line 20796) | constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unu...
  function binary_t (line 20802) | binary_t* get_impl_ptr(binary_t* /*unused*/) noexcept
  function binary_t (line 20808) | constexpr const binary_t* get_impl_ptr(const binary_t* /*unused*/) const...
  function ReferenceType (line 20825) | static ReferenceType get_ref_impl(ThisType& obj)
  function get_ptr (line 20858) | constexpr auto get_ptr() const noexcept -> decltype(std::declval<const b...
  function ValueType (line 20950) | ValueType get_impl(detail::priority_tag<1> /*unused*/) const noexcept(no...
  function BasicJsonType (line 20975) | BasicJsonType get_impl(detail::priority_tag<2> /*unused*/) const
  function basic_json (line 20998) | basic_json get_impl(detail::priority_tag<3> /*unused*/) const
  function get_impl (line 21011) | constexpr auto get_impl(detail::priority_tag<4> /*unused*/) const noexcept
  function get (line 21087) | auto get() noexcept -> decltype(std::declval<basic_json_t&>().template g...
  function ValueType (line 21100) | ValueType & get_to(ValueType& v) const noexcept(noexcept(
  function ValueType (line 21113) | ValueType & get_to(ValueType& v) const
  function Array (line 21124) | Array get_to(T (&v)[N]) const // NOLINT(cppcoreguidelines-avoid-c-arrays...
  function ReferenceType (line 21136) | ReferenceType get_ref()
  function ReferenceType (line 21147) | ReferenceType get_ref() const
  function binary_t (line 21206) | binary_t& get_binary()
  function binary_t (line 21218) | const binary_t& get_binary() const
  function reference (line 21240) | reference at(size_type idx)
  function const_reference (line 21263) | const_reference at(size_type idx) const
  function reference (line 21286) | reference at(const typename object_t::key_type& key)
  function reference (line 21306) | reference at(KeyType && key)
  function const_reference (line 21324) | const_reference at(const typename object_t::key_type& key) const
  function const_reference (line 21344) | const_reference at(KeyType && key) const
  function reference (line 21362) | reference operator[](size_type idx)
  function const_reference (line 21408) | const_reference operator[](size_type idx) const
  function reference (line 21421) | reference operator[](typename object_t::key_type key)
  function const_reference (line 21443) | const_reference operator[](const typename object_t::key_type& key) const
  function reference (line 21459) | reference operator[](T* key)
  function const_reference (line 21465) | const_reference operator[](T* key) const
  function reference (line 21474) | reference operator[](KeyType && key)
  function const_reference (line 21498) | const_reference operator[](KeyType && key) const
  class ValueType (line 21524) | class ValueType
  function ReturnType (line 21553) | ReturnType value(const typename object_t::key_type& key, ValueType && de...
  function ValueType (line 21579) | ValueType value(KeyType && key, const ValueType& default_value) const
  function ReturnType (line 21606) | ReturnType value(KeyType && key, ValueType && default_value) const
  function ValueType (line 21629) | ValueType value(const json_pointer& ptr, const ValueType& default_value)...
  function ReturnType (line 21654) | ReturnType value(const json_pointer& ptr, ValueType && default_value) const
  function ValueType (line 21678) | ValueType value(const ::nlohmann::json_pointer<BasicJsonType>& ptr, cons...
  function ReturnType (line 21689) | ReturnType value(const ::nlohmann::json_pointer<BasicJsonType>& ptr, Val...
  function reference (line 21696) | reference front()
  function const_reference (line 21703) | const_reference front() const
  function reference (line 21710) | reference back()
  function const_reference (line 21719) | const_reference back() const
  function IteratorType (line 21731) | IteratorType erase(IteratorType pos)
  function IteratorType (line 21801) | IteratorType erase(IteratorType first, IteratorType last)
  function erase_internal (line 21869) | private:
  function size_type (line 21885) | size_type erase_internal(KeyType && key)
  function size_type (line 21917) | size_type erase(KeyType && key)
  function erase (line 21924) | void erase(const size_type idx)
  function iterator (line 21953) | iterator find(const typename object_t::key_type& key)
  function const_iterator (line 21967) | const_iterator find(const typename object_t::key_type& key) const
  function iterator (line 21983) | iterator find(KeyType && key)
  function const_iterator (line 21999) | const_iterator find(KeyType && key) const
  function size_type (line 22013) | size_type count(const typename object_t::key_type& key) const
  function size_type (line 22023) | size_type count(KeyType && key) const
  function contains (line 22031) | bool contains(const typename object_t::key_type& key) const
  function contains (line 22040) | bool contains(KeyType && key) const
  function contains (line 22047) | bool contains(const json_pointer& ptr) const
  function contains (line 22054) | bool contains(const typename ::nlohmann::json_pointer<BasicJsonType>& pt...
  function iterator (line 22070) | iterator begin() noexcept
  function const_iterator (line 22079) | const_iterator begin() const noexcept
  function const_iterator (line 22086) | const_iterator cbegin() const noexcept
  function iterator (line 22095) | iterator end() noexcept
  function const_iterator (line 22104) | const_iterator end() const noexcept
  function const_iterator (line 22111) | const_iterator cend() const noexcept
  function reverse_iterator (line 22120) | reverse_iterator rbegin() noexcept
  function const_reverse_iterator (line 22127) | const_reverse_iterator rbegin() const noexcept
  function reverse_iterator (line 22134) | reverse_iterator rend() noexcept
  function const_reverse_iterator (line 22141) | const_reverse_iterator rend() const noexcept
  function const_reverse_iterator (line 22148) | const_reverse_iterator crbegin() const noexcept
  function const_reverse_iterator (line 22155) | const_reverse_iterator crend() const noexcept
  function iterator_wrapper (line 22167) | static iteration_proxy<iterator> iterator_wrapper(reference ref) noexcept
  function iterator_wrapper (line 22178) | static iteration_proxy<const_iterator> iterator_wrapper(const_reference ...
  function items (line 22185) | iteration_proxy<iterator> items() noexcept
  function items (line 22192) | iteration_proxy<const_iterator> items() const noexcept
  function empty (line 22208) | bool empty() const noexcept
  function size_type (line 22247) | size_type size() const noexcept
  function size_type (line 22286) | size_type max_size() const noexcept
  function clear (line 22329) | void clear() noexcept
  function push_back (line 22390) | void push_back(basic_json&& val)
  function reference (line 22415) | reference operator+=(basic_json&& val)
  function push_back (line 22423) | void push_back(const basic_json& val)
  function reference (line 22447) | reference operator+=(const basic_json& val)
  function push_back (line 22455) | void push_back(const typename object_t::value_type& val)
  function reference (line 22478) | reference operator+=(const typename object_t::value_type& val)
  function push_back (line 22486) | void push_back(initializer_list_t init)
  function reference (line 22502) | reference operator+=(initializer_list_t init)
  function reference (line 22511) | reference emplace_back(Args&& ... args)
  function emplace (line 22536) | std::pair<iterator, bool> emplace(Args&& ... args)
  function iterator (line 22568) | iterator insert_iterator(const_iterator pos, Args&& ... args)
  function iterator (line 22587) | iterator insert(const_iterator pos, const basic_json& val)
  function iterator (line 22607) | iterator insert(const_iterator pos, basic_json&& val)
  function iterator (line 22614) | iterator insert(const_iterator pos, size_type cnt, const basic_json& val)
  function iterator (line 22634) | iterator insert(const_iterator pos, const_iterator first, const_iterator...
  function iterator (line 22665) | iterator insert(const_iterator pos, initializer_list_t ilist)
  function insert (line 22685) | void insert(const_iterator first, const_iterator last)
  function update (line 22710) | void update(const_reference j, bool merge_objects = false)
  function update (line 22717) | void update(const_iterator first, const_iterator last, bool merge_object...
  function swap (line 22764) | void swap(reference other) noexcept (
  function friend (line 22781) | friend void swap(reference left, reference right) noexcept (
  function swap (line 22793) | void swap(array_t& other) // NOLINT(bugprone-exception-escape,cppcoregui...
  function swap (line 22809) | void swap(object_t& other) // NOLINT(bugprone-exception-escape,cppcoregu...
  function swap (line 22825) | void swap(string_t& other) // NOLINT(bugprone-exception-escape,cppcoregu...
  function swap (line 22841) | void swap(binary_t& other) // NOLINT(bugprone-exception-escape,cppcoregu...
  function swap (line 22857) | void swap(typename binary_t::container_type& other) // NOLINT(bugprone-e...
  function else (line 22946) | else if(compares_unordered(lhs, rhs))\
  function compares_unordered (line 22975) | bool compares_unordered(const_reference rhs, bool inverse = false) const...
  function friend (line 23088) | friend bool operator==(const_reference lhs, const_reference rhs) noexcept
  function friend (line 23120) | friend bool operator!=(const_reference lhs, const_reference rhs) noexcept
  function friend (line 23177) | friend bool operator<=(const_reference lhs, const_reference rhs) noexcept
  function friend (line 23206) | friend bool operator>(const_reference lhs, const_reference rhs) noexcept
  function friend (line 23236) | friend bool operator>=(const_reference lhs, const_reference rhs) noexcept
  function friend (line 23277) | friend std::ostream& operator<<(std::ostream& o, const basic_json& j)
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23316) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23330) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 23344) | static basic_json parse(detail::span_input_adapter&& i,
  function accept (line 23357) | static bool accept(InputType&& i,
  function accept (line 23366) | static bool accept(IteratorType first, IteratorType last,
  function accept (line 23374) | static bool accept(detail::span_input_adapter&& i,
  function sax_parse (line 23384) | static bool sax_parse(InputType&& i, SAX* sax,
  function sax_parse (line 23399) | static bool sax_parse(IteratorType first, IteratorType last, SAX* sax,
  function sax_parse (line 23418) | static bool sax_parse(detail::span_input_adapter&& i, SAX* sax,
  function JSON_HEDLEY_RETURNS_NON_NULL (line 23459) | JSON_HEDLEY_RETURNS_NON_NULL
  type data (line 23491) | struct data
    method data (line 23499) | data(const value_t v)
    method data (line 23504) | data(size_type cnt, const basic_json& val)
    method data (line 23510) | data() noexcept = default;
    method data (line 23511) | data(data&&) noexcept = default;
    method data (line 23512) | data(const data&) noexcept = delete;
    method data (line 23513) | data& operator=(data&&) noexcept = delete;
    method data (line 23514) | data& operator=(const data&) noexcept = delete;
  function to_cbor (line 23548) | static void to_cbor(const basic_json& j, detail::output_adapter<std::uin...
  function to_cbor (line 23555) | static void to_cbor(const basic_json& j, detail::output_adapter<char> o)
  function to_msgpack (line 23562) | static std::vector<std::uint8_t> to_msgpack(const basic_json& j)
  function to_msgpack (line 23571) | static void to_msgpack(const basic_json& j, detail::output_adapter<std::...
  function to_msgpack (line 23578) | static void to_msgpack(const basic_json& j, detail::output_adapter<char> o)
  function to_ubjson (line 23585) | static std::vector<std::uint8_t> to_ubjson(const basic_json& j,
  function to_ubjson (line 23596) | static void to_ubjson(const basic_json& j, detail::output_adapter<std::u...
  function to_ubjson (line 23604) | static void to_ubjson(const basic_json& j, detail::output_adapter<char> o,
  function to_bjdata (line 23612) | static std::vector<std::uint8_t> to_bjdata(const basic_json& j,
  function to_bjdata (line 23623) | static void to_bjdata(const basic_json& j, detail::output_adapter<std::u...
  function to_bjdata (line 23631) | static void to_bjdata(const basic_json& j, detail::output_adapter<char> o,
  function to_bson (line 23639) | static std::vector<std::uint8_t> to_bson(const basic_json& j)
  function to_bson (line 23648) | static void to_bson(const basic_json& j, detail::output_adapter<std::uin...
  function to_bson (line 23655) | static void to_bson(const basic_json& j, detail::output_adapter<char> o)
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23663) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23679) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 23695) | static basic_json from_cbor(const T* ptr, std::size_t len,
  function basic_json (line 23705) | static basic_json from_cbor(detail::span_input_adapter&& i,
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23721) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23736) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 23751) | static basic_json from_msgpack(const T* ptr, std::size_t len,
  function basic_json (line 23760) | static basic_json from_msgpack(detail::span_input_adapter&& i,
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23775) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23790) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 23805) | static basic_json from_ubjson(const T* ptr, std::size_t len,
  function basic_json (line 23814) | static basic_json from_ubjson(detail::span_input_adapter&& i,
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23829) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23844) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23859) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 23874) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function basic_json (line 23889) | static basic_json from_bson(const T* ptr, std::size_t len,
  function basic_json (line 23898) | static basic_json from_bson(detail::span_input_adapter&& i,
  function reference (line 23920) | reference operator[](const json_pointer& ptr)
  function reference (line 23927) | reference operator[](const ::nlohmann::json_pointer<BasicJsonType>& ptr)
  function const_reference (line 23934) | const_reference operator[](const json_pointer& ptr) const
  function const_reference (line 23941) | const_reference operator[](const ::nlohmann::json_pointer<BasicJsonType>...
  function reference (line 23948) | reference at(const json_pointer& ptr)
  function reference (line 23955) | reference at(const ::nlohmann::json_pointer<BasicJsonType>& ptr)
  function const_reference (line 23962) | const_reference at(const json_pointer& ptr) const
  function const_reference (line 23969) | const_reference at(const ::nlohmann::json_pointer<BasicJsonType>& ptr) c...
  function basic_json (line 23976) | basic_json flatten() const
  function basic_json (line 23985) | basic_json unflatten() const
  function patch_inplace (line 24001) | void patch_inplace(const basic_json& json_patch)
  function basic_json (line 24272) | basic_json patch(const basic_json& json_patch) const
  function JSON_HEDLEY_WARN_UNUSED_RESULT (line 24281) | JSON_HEDLEY_WARN_UNUSED_RESULT
  function merge_patch (line 24424) | void merge_patch(const basic_json& apply_patch)
  function NLOHMANN_BASIC_JSON_TPL_DECLARATION (line 24455) | NLOHMANN_BASIC_JSON_TPL_DECLARATION
  function NLOHMANN_JSON_NAMESPACE_END (line 24492) | NLOHMANN_JSON_NAMESPACE_END

Condensed preview: 35 files, each showing its path, character count, and a truncated content snippet (the full structured content is 1,706K chars).
[
  {
    "path": ".gitignore",
    "chars": 188,
    "preview": "env/\n\n# build intermediates\n.vscode/\nbuild/\n__pycache__/\n*.egg-info/\n*.cpython-312-x86_64-linux-gnu.so\n\n# profiling tool"
  },
  {
    "path": "LICENSE.md",
    "chars": 4814,
    "preview": "MIT License\n\nCopyright (c) 2025 Andrew Chan\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "Makefile",
    "chars": 2153,
    "preview": "MAKEFLAGS+=-r -j\n\nUNAME=$(shell uname)\n\nBUILD=build\nASM_DIR=$(BUILD)/asm\n\n# compile .c, .cpp, .cu files\nSOURCES=$(filter"
  },
  {
    "path": "README.md",
    "chars": 7335,
    "preview": "This is an CPU-only inference implementation for the DeepSeek family of large language models written in C++, based on ["
  },
  {
    "path": "convert.py",
    "chars": 27150,
    "preview": "# Converts a model consisting of a huggingface config.json, tokenizer.json, and .safetensors weights into a .yalm file,\n"
  },
  {
    "path": "pyproject.toml",
    "chars": 287,
    "preview": "[build-system]\nrequires = [\"setuptools>=42\", \"wheel\", \"torch>=2.0.0\", \"ninja\", \"numpy\"]\nbuild-backend = \"setuptools.buil"
  },
  {
    "path": "quantizer.cpp",
    "chars": 2704,
    "preview": "#include <torch/extension.h>\n#include \"quant.h\"\n\ntorch::Tensor quantize_q2_k(torch::Tensor& input) {\n  // Row-major quan"
  },
  {
    "path": "quantizer.py",
    "chars": 627,
    "preview": "import torch\nimport quantizer_cpp\nfrom typing import Literal\n\ndef k_quantize(tensor: torch.Tensor, method: Literal[\"q2_k"
  },
  {
    "path": "setup.py",
    "chars": 879,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension\nfrom setuptools.dist imp"
  },
  {
    "path": "src/codec.cpp",
    "chars": 10700,
    "preview": "#include \"codec.h\"\n\n#include \"quant.h\"\n\n#include \"fmt/format.h\"\n\n#include <fcntl.h>\n#include <iostream>\n#include <sys/mm"
  },
  {
    "path": "src/codec.h",
    "chars": 3880,
    "preview": "#pragma once\n\n#include \"json.hpp\"\n\n#include <array>\n#include <cstdint>\n#include <string>\n#include <unordered_map>\n#inclu"
  },
  {
    "path": "src/debug.cpp",
    "chars": 1163,
    "preview": "#include \"debug.h\"\n#include \"model.h\"\n\ntemplate <typename T>\nbool BinaryDumper::save(const std::string& filename, const "
  },
  {
    "path": "src/debug.h",
    "chars": 357,
    "preview": "#include <fstream>\n#include <vector>\n#include <cstdint>\n#include <iostream>\n\nstruct BinaryDumper {\n  // Save T array to "
  },
  {
    "path": "src/infer.cpp",
    "chars": 45803,
    "preview": "#include \"model.h\"\n\n#include <assert.h>\n#include <cfloat>\n#include <math.h>\n\n#include \"quant.h\"\n#include \"profile.h\"\n\n#i"
  },
  {
    "path": "src/main.cpp",
    "chars": 23122,
    "preview": "#include <cmath>\n#include <cstdint>\n#include <cstdlib>\n#include <fstream>\n#include <iostream>\n#include <sstream>\n#includ"
  },
  {
    "path": "src/model.cpp",
    "chars": 35546,
    "preview": "#include \"model.h\"\n\n#include \"json.hpp\"\n#include <algorithm>\n#include <array>\n#include <cfloat>\n#include \"fmt/format.h\"\n"
  },
  {
    "path": "src/model.h",
    "chars": 21699,
    "preview": "#pragma once\n\n#include \"codec.h\"\n\n#include <memory>\n#include <vector>\n#include <map>\n#include <optional>\n\n#include \"quan"
  },
  {
    "path": "src/profile.cpp",
    "chars": 1268,
    "preview": "#include \"profile.h\"\n\n#include <vector>\n\nstatic bool _profile_enabled = true;\nstatic std::vector<std::string> _profile_s"
  },
  {
    "path": "src/profile.h",
    "chars": 1353,
    "preview": "#include <omp.h>\n#include <map>\n#include <string>\n\n#define PROFILE_ENABLED 0\n\n// Toggle aggregation of profile scopes at"
  },
  {
    "path": "src/quant.cpp",
    "chars": 25783,
    "preview": "/*\nK-quants adapted from llama.cpp\n\nMIT License\n\nCopyright (c) 2023-2024 The ggml authors\n\nPermission is hereby granted,"
  },
  {
    "path": "src/quant.h",
    "chars": 5212,
    "preview": "/*\nK-quants adapted from llama.cpp\n\nMIT License\n\nCopyright (c) 2023-2024 The ggml authors\n\nPermission is hereby granted,"
  },
  {
    "path": "src/sampler.cpp",
    "chars": 2094,
    "preview": "#include \"sampler.h\"\n\n#include <algorithm>\n#include <cfloat>\n#include <cstdlib>\n\nSampler::Sampler(const std::shared_ptr<"
  },
  {
    "path": "src/sampler.h",
    "chars": 634,
    "preview": "#pragma once\n\n#include \"model.h\"\n\n#include <memory>\n\nstruct Sampler {\n  int vocab_size;\n\n  Sampler(const std::shared_ptr"
  },
  {
    "path": "src/test.cpp",
    "chars": 10814,
    "preview": "#include <iostream>\n#include <memory>\n#include <omp.h>\n#include <random>\n#include <thread>\n#include <vector>\n\n#include \""
  },
  {
    "path": "src/time_utils.cpp",
    "chars": 207,
    "preview": "#include \"time_utils.h\"\n\n#include <chrono>\n\nuint64_t get_timestamp_ms() {\n  return std::chrono::duration_cast<std::chron"
  },
  {
    "path": "src/time_utils.h",
    "chars": 62,
    "preview": "#pragma once\n\n#include <cstdint>\n\nuint64_t get_timestamp_ms();"
  },
  {
    "path": "src/tokenizer.cpp",
    "chars": 3505,
    "preview": "#include \"tokenizer.h\"\n\nTokenizer::Tokenizer(const YALMData& data) {\n  this->bos_id = std::stoi(data.metadata.at(\"bos_to"
  },
  {
    "path": "src/tokenizer.h",
    "chars": 2042,
    "preview": "#pragma once\n\n#include \"codec.h\"\n\n#include <memory>\n#include <string>\n#include <vector>\n#include <unordered_map>\n\nstruct"
  },
  {
    "path": "src/wikitest.cat.1chunk.v2-encoded.txt",
    "chars": 72822,
    "preview": "        100000,    207,    185,    403,   7940,  88819,    367,    403,    207,\n           185,    207,    185,   7940, "
  },
  {
    "path": "src/wikitest.cat.1chunk.v3-encoded.txt",
    "chars": 72823,
    "preview": "             0,    539,    438,  10498, 102771,    402,    438,  54921,  10498,\n        102771,    402,    344,    411, "
  },
  {
    "path": "vendor/fmt/base.h",
    "chars": 104877,
    "preview": "// Formatting library for C++ - the base API for char/UTF-8\n//\n// Copyright (c) 2012 - present, Victor Zverovich\n// All "
  },
  {
    "path": "vendor/fmt/format-inl.h",
    "chars": 80577,
    "preview": "// Formatting library for C++ - implementation\n//\n// Copyright (c) 2012 - 2016, Victor Zverovich\n// All rights reserved."
  },
  {
    "path": "vendor/fmt/format.h",
    "chars": 163131,
    "preview": "/*\n  Formatting library for C++\n\n  Copyright (c) 2012 - present, Victor Zverovich\n\n  Permission is hereby granted, free "
  },
  {
    "path": "vendor/format.cc",
    "chars": 1333,
    "preview": "// Formatting library for C++\n//\n// Copyright (c) 2012 - 2016, Victor Zverovich\n// All rights reserved.\n//\n// For the li"
  },
  {
    "path": "vendor/json.hpp",
    "chars": 919974,
    "preview": "//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11"
  }
]
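
For programmatic use, the condensed preview above is ordinary JSON and can be loaded directly. A minimal sketch in Python, assuming the array has been saved to a file named condensed_preview.json (a hypothetical filename; the keys path, chars, and preview are taken from the entries above):

import json

# Load the condensed preview array. "condensed_preview.json" is a
# hypothetical filename for the JSON array shown above.
with open("condensed_preview.json", "r", encoding="utf-8") as f:
    entries = json.load(f)

# Each entry carries "path", "chars" (full file size), and a truncated "preview".
total_chars = sum(e["chars"] for e in entries)
print(f"{len(entries)} files, {total_chars:,} chars total")

# List the largest files first; vendor/json.hpp (919,974 chars) dominates.
for e in sorted(entries, key=lambda e: e["chars"], reverse=True)[:5]:
    print(f"{e['chars']:>9,}  {e['path']}")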

About this extraction

This page contains the full source code of the andrewkchan/deepseek.cpp GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (1.6 MB), approximately 484.3k tokens, and a symbol index with 781 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
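
Since the symbol index above follows a regular line format ("<indent><kind> <name> (line <N>) | <snippet>"), a small parser can recover structured records from it. A minimal sketch in Python; the regex and helper below are illustrative assumptions, not part of the GitExtract tooling:

import re

# One index line looks like:
#   "  function unget (line 1485) | void unget(char c) {"
# Methods nested under a class are indented two extra spaces.
INDEX_LINE = re.compile(
    r"^(?P<indent>\s*)"
    r"(?P<kind>function|method|class|type)\s+"
    r"(?P<name>\S+)\s+"
    r"\(line (?P<line>\d+)\)\s*\|\s*"
    r"(?P<snippet>.*)$"
)

def parse_symbol(line):
    # Returns a dict for an index line, or None for blanks and FILE: headers.
    m = INDEX_LINE.match(line)
    if m is None:
        return None
    record = m.groupdict()
    record["line"] = int(record["line"])
    record["nested"] = len(record["indent"]) > 2  # method vs. top-level symbol
    return record

print(parse_symbol("  function unget (line 1485) | void unget(char c) {"))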

Extracted by GitExtract, a free GitHub repo-to-text converter for AI, built by Nikandr Surkov.
