SYMBOL INDEX (125 symbols across 13 files) FILE: src/mistral_inference/args.py class VisionEncoderArgs (line 13) | class VisionEncoderArgs: class TransformerArgs (line 30) | class TransformerArgs(Serializable): method __post_init__ (line 54) | def __post_init__(self) -> None: class MambaArgs (line 63) | class MambaArgs(Serializable): method __post_init__ (line 75) | def __post_init__(self) -> None: FILE: src/mistral_inference/cache.py function get_cache_sizes (line 13) | def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Opt... class CacheInputMetadata (line 28) | class CacheInputMetadata: function interleave_list (line 54) | def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> L... function unrotate (line 59) | def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor: class CacheView (line 70) | class CacheView: method __init__ (line 71) | def __init__( method update (line 83) | def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None: method interleave_kv (line 94) | def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[t... method max_seq_len (line 120) | def max_seq_len(self) -> int: method key (line 124) | def key(self) -> torch.Tensor: method value (line 128) | def value(self) -> torch.Tensor: method prefill (line 132) | def prefill(self) -> bool: method mask (line 136) | def mask(self) -> AttentionBias: class BufferCache (line 140) | class BufferCache: method __init__ (line 146) | def __init__( method get_view (line 172) | def get_view(self, layer_id: int, metadata: CacheInputMetadata) -> Cac... method reset (line 176) | def reset(self) -> None: method init_kvseqlens (line 179) | def init_kvseqlens(self, batch_size: int) -> None: method device (line 183) | def device(self) -> torch.device: method to (line 186) | def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache": method update_seqlens (line 193) | def update_seqlens(self, seqlens: List[int]) -> None: method get_input_metadata (line 197) | def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMet... method _get_input_metadata_layer (line 225) | def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int... FILE: src/mistral_inference/generate.py function generate_mamba (line 12) | def generate_mamba( function generate (line 44) | def generate( function sample (line 151) | def sample(logits: torch.Tensor, temperature: float, top_p: float) -> to... function sample_top_p (line 161) | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: FILE: src/mistral_inference/lora.py class LoraArgs (line 13) | class LoraArgs(Serializable): method __post_init__ (line 17) | def __post_init__(self) -> None: class LoRALinear (line 22) | class LoRALinear(nn.Module): method __init__ (line 35) | def __init__( method forward (line 71) | def forward(self, x: torch.Tensor) -> torch.Tensor: method _load_from_state_dict (line 76) | def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: st... class LoRALoaderMixin (line 92) | class LoRALoaderMixin: method load_lora (line 93) | def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0)... method _load_lora_state_dict (line 103) | def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tenso... FILE: src/mistral_inference/main.py function is_torchrun (line 36) | def is_torchrun() -> bool: function load_tokenizer (line 41) | def load_tokenizer(model_path: Path) -> MistralTokenizer: function get_model_cls (line 60) | def get_model_cls(model_path: str) -> Union[Type[Mamba], Type[Transforme... function pad_and_convert_to_tensor (line 67) | def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: in... function _get_multimodal_input (line 77) | def _get_multimodal_input() -> Tuple[UserMessage, bool]: function interactive (line 102) | def interactive( function demo (line 203) | def demo( function mistral_chat (line 268) | def mistral_chat() -> None: function mistral_demo (line 272) | def mistral_demo() -> None: FILE: src/mistral_inference/mamba.py class Mamba (line 23) | class Mamba(ModelBase, nn.Module): method __init__ (line 24) | def __init__(self, args: MambaArgs): method dtype (line 46) | def dtype(self) -> torch.dtype: method device (line 50) | def device(self) -> torch.device: method forward (line 53) | def forward( method from_folder (line 64) | def from_folder( FILE: src/mistral_inference/model.py class ModelBase (line 11) | class ModelBase(nn.Module, ABC): method __init__ (line 12) | def __init__(self) -> None: method dtype (line 17) | def dtype(self) -> torch.dtype: method device (line 22) | def device(self) -> torch.device: method forward (line 26) | def forward( method from_folder (line 36) | def from_folder( FILE: src/mistral_inference/moe.py class MoeArgs (line 11) | class MoeArgs(Serializable): class MoeLayer (line 16) | class MoeLayer(nn.Module): method __init__ (line 17) | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args... method forward (line 24) | def forward(self, inputs: torch.Tensor) -> torch.Tensor: FILE: src/mistral_inference/rope.py function precompute_freqs_cis (line 6) | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: function apply_rotary_emb (line 13) | def apply_rotary_emb( function precompute_freqs_cis_2d (line 26) | def precompute_freqs_cis_2d( FILE: src/mistral_inference/transformer.py class SimpleInputMetadata (line 22) | class SimpleInputMetadata: method from_seqlens (line 27) | def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleI... class Transformer (line 33) | class Transformer(ModelBase, LoRALoaderMixin): method __init__ (line 34) | def __init__( method dtype (line 101) | def dtype(self) -> torch.dtype: method device (line 105) | def device(self) -> torch.device: method freqs_cis (line 109) | def freqs_cis(self) -> torch.Tensor: method embed_vision_language_features (line 122) | def embed_vision_language_features(self, input_ids: torch.Tensor, imag... method forward_partial (line 163) | def forward_partial( method forward (line 221) | def forward( method load_state_dict (line 244) | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool ... method from_folder (line 298) | def from_folder( FILE: src/mistral_inference/transformer_layers.py function repeat_kv (line 16) | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, di... function maybe_lora (line 22) | def maybe_lora( class Attention (line 31) | class Attention(nn.Module): method __init__ (line 32) | def __init__( method forward (line 56) | def forward( class FeedForward (line 96) | class FeedForward(nn.Module): method __init__ (line 97) | def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs]... method forward (line 105) | def forward(self, x: torch.Tensor) -> torch.Tensor: class RMSNorm (line 109) | class RMSNorm(torch.nn.Module): method __init__ (line 110) | def __init__(self, dim: int, eps: float = 1e-6): method _norm (line 115) | def _norm(self, x: torch.Tensor) -> torch.Tensor: method forward (line 118) | def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock (line 123) | class TransformerBlock(nn.Module): method __init__ (line 124) | def __init__( method forward (line 158) | def forward( FILE: src/mistral_inference/vision_encoder.py function position_meshgrid (line 12) | def position_meshgrid( class VisionTransformer (line 31) | class VisionTransformer(nn.Module): method __init__ (line 32) | def __init__(self, args: VisionEncoderArgs): method max_patches_per_side (line 50) | def max_patches_per_side(self) -> int: method device (line 54) | def device(self) -> torch.device: method freqs_cis (line 58) | def freqs_cis(self) -> torch.Tensor: method forward (line 72) | def forward( class VisionLanguageAdapter (line 105) | class VisionLanguageAdapter(nn.Module): method __init__ (line 106) | def __init__(self, in_dim: int, out_dim: int, bias: bool = True): method forward (line 116) | def forward(self, x: torch.Tensor) -> torch.Tensor: class VisionTransformerBlocks (line 120) | class VisionTransformerBlocks(nn.Module): method __init__ (line 121) | def __init__(self, args: VisionEncoderArgs): method forward (line 136) | def forward( class PatchMerger (line 147) | class PatchMerger(nn.Module): method __init__ (line 152) | def __init__( method forward (line 166) | def forward(self, x: torch.Tensor, image_sizes: list[tuple[int, int]])... method permute (line 180) | def permute( function get_sub_grids (line 206) | def get_sub_grids( FILE: tests/test_generate.py class DebugTokenizer (line 12) | class DebugTokenizer: method bos_id (line 14) | def bos_id(self) -> int: method eos_id (line 18) | def eos_id(self) -> int: method pad_id (line 22) | def pad_id(self) -> int: method encode (line 25) | def encode(self, s: str, bos: bool = True) -> List[int]: method decode (line 32) | def decode(self, t: List[int]) -> str: function test_generation_transformer (line 36) | def test_generation_transformer() -> None: function test_generation_pixtral (line 72) | def test_generation_pixtral() -> None: function test_generation_pixtral_patch_merger (line 121) | def test_generation_pixtral_patch_merger() -> None: function test_generation_mamba (line 174) | def test_generation_mamba() -> None: function test_chunks_transformer (line 199) | def test_chunks_transformer() -> None: