SYMBOL INDEX (51 symbols across 5 files) FILE: inference/convert.py function main (line 33) | def main(hf_ckpt_path, save_path, n_experts, mp): FILE: inference/fp8_cast_bf16.py function main (line 12) | def main(fp8_path, bf16_path): FILE: inference/generate.py function sample (line 14) | def sample(logits, temperature: float = 1.0): function generate (line 31) | def generate( function main (line 81) | def main( FILE: inference/kernel.py function act_quant_kernel (line 10) | def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scal... function act_quant (line 38) | def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optiona... function weight_dequant_kernel (line 61) | def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.cons... function weight_dequant (line 89) | def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 1... function fp8_gemm_kernel (line 120) | def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, function fp8_gemm (line 175) | def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: t... FILE: inference/model.py class ModelArgs (line 20) | class ModelArgs: class ParallelEmbedding (line 89) | class ParallelEmbedding(nn.Module): method __init__ (line 97) | def __init__(self, vocab_size: int, dim: int): method forward (line 107) | def forward(self, x: torch.Tensor) -> torch.Tensor: function linear (line 131) | def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.T... class Linear (line 166) | class Linear(nn.Module): method __init__ (line 179) | def __init__(self, in_features: int, out_features: int, bias: bool = F... method forward (line 195) | def forward(self, x: torch.Tensor) -> torch.Tensor: class ColumnParallelLinear (line 208) | class ColumnParallelLinear(Linear): method __init__ (line 218) | def __init__(self, in_features: int, out_features: int, bias: bool = F... method forward (line 223) | def forward(self, x: torch.Tensor) -> torch.Tensor: class RowParallelLinear (line 237) | class RowParallelLinear(Linear): method __init__ (line 247) | def __init__(self, in_features: int, out_features: int, bias: bool = F... method forward (line 252) | def forward(self, x: torch.Tensor) -> torch.Tensor: class RMSNorm (line 270) | class RMSNorm(nn.Module): method __init__ (line 278) | def __init__(self, dim: int, eps: float = 1e-6): method forward (line 284) | def forward(self, x: torch.Tensor): function precompute_freqs_cis (line 297) | def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: function apply_rotary_emb (line 378) | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.... class MLA (line 396) | class MLA(nn.Module): method __init__ (line 412) | def __init__(self, args: ModelArgs): method forward (line 446) | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Te... class MLP (line 500) | class MLP(nn.Module): method __init__ (line 509) | def __init__(self, dim: int, inter_dim: int): method forward (line 522) | def forward(self, x: torch.Tensor) -> torch.Tensor: class Gate (line 535) | class Gate(nn.Module): method __init__ (line 549) | def __init__(self, args: ModelArgs): method forward (line 566) | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: class Expert (line 601) | class Expert(nn.Module): method __init__ (line 610) | def __init__(self, dim: int, inter_dim: int): method forward (line 623) | def forward(self, x: torch.Tensor) -> torch.Tensor: class MoE (line 636) | class MoE(nn.Module): method __init__ (line 649) | def __init__(self, args: ModelArgs): method forward (line 669) | def forward(self, x: torch.Tensor) -> torch.Tensor: class Block (line 696) | class Block(nn.Module): method __init__ (line 706) | def __init__(self, layer_id: int, args: ModelArgs): method forward (line 720) | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Te... class Transformer (line 738) | class Transformer(nn.Module): method __init__ (line 750) | def __init__(self, args: ModelArgs): method forward (line 773) | def forward(self, tokens: torch.Tensor, start_pos: int = 0):