SYMBOL INDEX (3886 symbols across 505 files) FILE: .runpod/src/handler.py function handler (line 20) | async def handler(job): FILE: .runpod/src/train.py function train (line 8) | async def train(config_path: str, gpu_id: str = "0", preprocess: bool = ... FILE: .runpod/src/utils.py function get_output_dir (line 10) | def get_output_dir(run_id): function make_valid_config (line 15) | def make_valid_config(input_args): function set_config_env_vars (line 40) | def set_config_env_vars(args: dict): FILE: benchmarks/bench_entropy.py function entropy_from_logits_original (line 20) | def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int =... function _clean_gpu (line 33) | def _clean_gpu(): function profile_time (line 41) | def profile_time(fn, logits, n_iters=BENCH_ITERS): function profile_memory (line 60) | def profile_memory(fn, logits, n_iters=MEM_ITERS): function fmt (line 77) | def fmt(values, unit=""): function benchmark_contiguous (line 83) | def benchmark_contiguous(): function benchmark_noncontiguous (line 138) | def benchmark_noncontiguous(): FILE: benchmarks/bench_scattermoe_lora.py function _resolve_config (line 46) | def _resolve_config(spec): function _clean (line 79) | def _clean(): function _bench (line 85) | def _bench(fn, warmup=WARMUP, iters=ITERS): function _setup (line 100) | def _setup(num_experts, K, N, T, top_k, R): function _call_fwd (line 117) | def _call_fwd(x, W, sei, ssi, top_k, lA, lB): function _call_base (line 130) | def _call_base(x, W, sei, ssi, top_k): function _call_dx (line 140) | def _call_dx(dy, W, sei, ssi, lA, lB): function _call_bwd (line 155) | def _call_bwd(dy, gx, lA, lB, eo, num_experts): function main (line 170) | def main(): FILE: benchmarks/bench_selective_logsoftmax.py function _clean_gpu (line 22) | def _clean_gpu(): function profile_time (line 30) | def profile_time(fn, args, n_iters=BENCH_ITERS): function profile_memory (line 47) | def profile_memory(fn, args, n_iters=MEM_ITERS): function fmt (line 64) | def fmt(values, unit=""): function benchmark_forward (line 70) | def benchmark_forward(): function benchmark_backward (line 123) | def benchmark_backward(): FILE: cicd/cleanup.py function cleanup (line 13) | def cleanup(): function main (line 18) | def main(): FILE: cicd/e2e_tests.py function cicd_pytest (line 14) | def cicd_pytest(): function main (line 19) | def main(): FILE: cicd/multigpu.py function run_cmd (line 63) | def run_cmd(cmd: str, run_folder: str): function cicd_pytest (line 79) | def cicd_pytest(): function main (line 84) | def main(): FILE: cicd/single_gpu.py function run_cmd (line 64) | def run_cmd(cmd: str, run_folder: str): FILE: docs/scripts/generate_config_docs.py class QuartoGenerator (line 20) | class QuartoGenerator: method __init__ (line 23) | def __init__(self): method _get_direct_fields (line 28) | def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]: method _is_pydantic_model (line 46) | def _is_pydantic_model(self, type_obj) -> bool: method _extract_nested_type (line 50) | def _extract_nested_type(self, field_type) -> Any: method _extract_all_pydantic_models_from_type (line 126) | def _extract_all_pydantic_models_from_type( method _get_nested_models (line 186) | def _get_nested_models( method _build_inheritance_map (line 216) | def _build_inheritance_map(self, child_class: Type[BaseModel]): method _wrap_comment (line 237) | def _wrap_comment(self, text: str, width: int = 88) -> list[str]: method _extract_type_from_source (line 247) | def _extract_type_from_source( method _get_type_from_class_source (line 263) | def _get_type_from_class_source(self, class_obj: type, field_name: str... method _extract_field_groups_from_all_classes (line 285) | def _extract_field_groups_from_all_classes( method _extract_field_groups_from_source (line 319) | def _extract_field_groups_from_source( method _generate_field_documentation (line 438) | def _generate_field_documentation( method generate_qmd (line 591) | def generate_qmd( function main (line 736) | def main(): FILE: docs/scripts/generate_examples_docs.py function slugify (line 20) | def slugify(name: str) -> str: function read_allowlist (line 27) | def read_allowlist(): function find_readme (line 36) | def find_readme(folder: Path) -> Path | None: function remove_first_h1 (line 44) | def remove_first_h1(md: str) -> tuple[str, str | None]: function rewrite_and_copy_assets (line 68) | def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Pa... function rewrite_readme_links (line 94) | def rewrite_readme_links( function write_qmd (line 209) | def write_qmd(out_path: Path, title: str, body_md: str): function update_quarto_yml (line 215) | def update_quarto_yml(generated: list[tuple[str, str, str]]): function main (line 286) | def main(): FILE: examples/swanlab/custom_trainer_profiling.py class CustomTrainerWithProfiling (line 30) | class CustomTrainerWithProfiling(AxolotlTrainer): method __init__ (line 39) | def __init__(self, *args, **kwargs): method training_step (line 56) | def training_step(self, model, inputs): method compute_loss (line 64) | def compute_loss(self, model, inputs, return_outputs=False): method prediction_step (line 72) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_... method complex_training_step (line 85) | def complex_training_step(self, model, inputs): method _prepare_inputs (line 115) | def _prepare_inputs(self, inputs): method _prepare_input_for_model (line 130) | def _prepare_input_for_model(self, input_ids): method potentially_failing_method (line 147) | def potentially_failing_method(self): method _do_risky_computation (line 162) | def _do_risky_computation(self): class AdvancedProfilingTrainer (line 172) | class AdvancedProfilingTrainer(AxolotlTrainer): method __init__ (line 175) | def __init__(self, *args, **kwargs): method training_step (line 197) | def training_step(self, model, inputs): method _prepare_inputs (line 204) | def _prepare_inputs(self, inputs): method _debug_method (line 211) | def _debug_method(self, data): FILE: scripts/chat_datasets.py function parse_dataset (line 13) | def parse_dataset(dataset=None, split="train"): FILE: setup.py function parse_requirements (line 12) | def parse_requirements(extras_require_map): function get_package_version (line 147) | def get_package_version(): FILE: src/axolotl/cli/args.py class PreprocessCliArgs (line 8) | class PreprocessCliArgs: class TrainerCliArgs (line 29) | class TrainerCliArgs: class VllmServeCliArgs (line 40) | class VllmServeCliArgs: class QuantizeCliArgs (line 109) | class QuantizeCliArgs: class EvaluateCliArgs (line 122) | class EvaluateCliArgs: class InferenceCliArgs (line 131) | class InferenceCliArgs: FILE: src/axolotl/cli/art.py function print_axolotl_text_art (line 22) | def print_axolotl_text_art(): FILE: src/axolotl/cli/checks.py function check_accelerate_default_config (line 16) | def check_accelerate_default_config() -> None: function check_user_token (line 24) | def check_user_token() -> bool: FILE: src/axolotl/cli/cloud/__init__.py function load_cloud_cfg (line 16) | def load_cloud_cfg(cloud_config: Path | str) -> DictDefault: function do_cli_preprocess (line 24) | def do_cli_preprocess( function do_cli_train (line 35) | def do_cli_train( function do_cli_lm_eval (line 66) | def do_cli_lm_eval( FILE: src/axolotl/cli/cloud/base.py class Cloud (line 9) | class Cloud(ABC): method preprocess (line 15) | def preprocess(self, config_yaml: str, *args, **kwargs) -> None: method train (line 19) | def train( FILE: src/axolotl/cli/cloud/baseten/__init__.py class BasetenCloud (line 14) | class BasetenCloud(Cloud): method __init__ (line 17) | def __init__(self, config: dict): method preprocess (line 20) | def preprocess(self, config_yaml: str, *args, **kwargs) -> None: method train (line 26) | def train( FILE: src/axolotl/cli/cloud/modal_.py function run_cmd (line 18) | def run_cmd(cmd: str, run_folder: str, volumes=None): class ModalCloud (line 52) | class ModalCloud(Cloud): method __init__ (line 57) | def __init__(self, config, app=None): method get_env (line 69) | def get_env(self): method get_image (line 84) | def get_image(self): method get_secrets (line 128) | def get_secrets(self): method create_volume (line 140) | def create_volume(self, volume_config): method get_ephemeral_disk_size (line 145) | def get_ephemeral_disk_size(self): method get_preprocess_timeout (line 148) | def get_preprocess_timeout(self): method get_preprocess_memory (line 153) | def get_preprocess_memory(self): method get_preprocess_env (line 161) | def get_preprocess_env(self): method preprocess (line 172) | def preprocess(self, config_yaml: str, *args, **kwargs): method get_train_timeout (line 183) | def get_train_timeout(self): method get_train_gpu (line 188) | def get_train_gpu(self): method get_train_memory (line 208) | def get_train_memory(self): method get_train_env (line 214) | def get_train_env(self, local_dirs=None): method train (line 228) | def train( method lm_eval (line 247) | def lm_eval(self, config_yaml: str): function _preprocess (line 261) | def _preprocess(config_yaml: str, volumes=None): function _train (line 273) | def _train( function _lm_eval (line 307) | def _lm_eval(config_yaml: str, volumes=None): FILE: src/axolotl/cli/config.py function _coerce_value (line 36) | def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any: function check_remote_config (line 97) | def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: function choose_config (line 163) | def choose_config(path: Path) -> str: function prepare_plugins (line 208) | def prepare_plugins(cfg: DictDefault): function plugin_set_cfg (line 223) | def plugin_set_cfg(cfg: DictDefault): function load_cfg (line 230) | def load_cfg( function compute_supports_fp8 (line 349) | def compute_supports_fp8() -> bool: FILE: src/axolotl/cli/delinearize_llama4.py function iter_convert_patched_to_hf (line 15) | def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator: function do_cli (line 71) | def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None: FILE: src/axolotl/cli/evaluate.py function do_evaluate (line 21) | def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: function do_cli (line 44) | def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: FILE: src/axolotl/cli/inference.py function get_multi_line_input (line 32) | def get_multi_line_input() -> str: function do_inference (line 50) | def do_inference( function do_inference_gradio (line 168) | def do_inference_gradio( function do_cli (line 299) | def do_cli( FILE: src/axolotl/cli/main.py function cli (line 43) | def cli(): function preprocess (line 57) | def preprocess(config: str, cloud: Optional[str] = None, **kwargs): function train (line 98) | def train( function evaluate (line 153) | def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs): function inference (line 198) | def inference(ctx: click.Context, config: str, launcher: str, gradio: bo... function merge_sharded_fsdp_weights (line 245) | def merge_sharded_fsdp_weights( function merge_lora (line 282) | def merge_lora(config: str, **kwargs): function fetch (line 299) | def fetch(directory: str, dest: Optional[str]): function vllm_serve (line 318) | def vllm_serve(config: str, **cli_args: VllmServeCliArgs): function quantize (line 328) | def quantize(config: str, **cli_args: QuantizeCliArgs): function delinearize_llama4 (line 337) | def delinearize_llama4(model: str, output: str): function main (line 346) | def main(): FILE: src/axolotl/cli/merge_lora.py function do_merge_lora (line 18) | def do_merge_lora(*, cfg: DictDefault) -> None: function do_cli (line 55) | def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: FILE: src/axolotl/cli/merge_sharded_fsdp_weights.py class BFloat16CastPlanner (line 31) | class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): method commit_tensor (line 34) | def commit_tensor(self, read_item, tensor): function _distributed_checkpoint_to_merged_weights (line 38) | def _distributed_checkpoint_to_merged_weights( function merge_fsdp_weights (line 108) | def merge_fsdp_weights( function do_cli (line 169) | def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): FILE: src/axolotl/cli/preprocess.py function do_preprocess (line 29) | def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: function do_cli (line 99) | def do_cli( FILE: src/axolotl/cli/quantize.py function do_quantize (line 23) | def do_quantize( FILE: src/axolotl/cli/train.py function do_train (line 23) | def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): function do_cli (line 55) | def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): function ray_train_func (line 91) | def ray_train_func(kwargs: dict): FILE: src/axolotl/cli/utils/args.py function _strip_optional_type (line 12) | def _strip_optional_type(field_type: type | str | None): function filter_none_kwargs (line 32) | def filter_none_kwargs(func: Callable) -> Callable: function add_options_from_dataclass (line 53) | def add_options_from_dataclass(config_class: Type[Any]) -> Callable: function _is_pydantic_model (line 91) | def _is_pydantic_model(field_type: type) -> bool: function _get_field_description (line 99) | def _get_field_description(field) -> str | None: function _add_nested_model_options (line 108) | def _add_nested_model_options( function add_options_from_config (line 148) | def add_options_from_config(config_class: Type[BaseModel]) -> Callable: FILE: src/axolotl/cli/utils/diffusion.py function diffusion_inference (line 12) | def diffusion_inference( function _parse_commands (line 91) | def _parse_commands(text: str): function run_diffusion (line 139) | def run_diffusion( function render_html (line 199) | def render_html( function launch_diffusion_gradio_ui (line 260) | def launch_diffusion_gradio_ui( FILE: src/axolotl/cli/utils/fetch.py function _download_file (line 16) | def _download_file( function fetch_from_github (line 70) | def fetch_from_github( FILE: src/axolotl/cli/utils/load.py function load_model_and_tokenizer (line 20) | def load_model_and_tokenizer( FILE: src/axolotl/cli/utils/sweeps.py function generate_sweep_configs (line 9) | def generate_sweep_configs( FILE: src/axolotl/cli/utils/train.py function _add_default_rdzv_args (line 15) | def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]: function build_command (line 46) | def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[... function generate_config_files (line 69) | def generate_config_files(config: str, sweep: str | None) -> Iterator[tu... function launch_training (line 109) | def launch_training( function _launch_cloud_training (line 134) | def _launch_cloud_training( function _launch_accelerate_training (line 157) | def _launch_accelerate_training( function _launch_torchrun_training (line 195) | def _launch_torchrun_training( function _launch_python_training (line 221) | def _launch_python_training(cfg_file: str, kwargs: dict) -> None: FILE: src/axolotl/cli/vllm_serve.py class AxolotlScriptArguments (line 15) | class AxolotlScriptArguments(ScriptArguments): function do_vllm_serve (line 24) | def do_vllm_serve( FILE: src/axolotl/common/datasets.py class TrainDatasetMeta (line 23) | class TrainDatasetMeta: function sample_dataset (line 31) | def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: function load_datasets (line 39) | def load_datasets( function load_preference_datasets (line 102) | def load_preference_datasets( FILE: src/axolotl/convert.py class FileReader (line 7) | class FileReader: method read (line 12) | def read(self, file_path): class FileWriter (line 17) | class FileWriter: method __init__ (line 22) | def __init__(self, file_path): method write (line 25) | def write(self, content): class StdoutWriter (line 30) | class StdoutWriter: method write (line 35) | def write(self, content): class JsonParser (line 40) | class JsonParser: method parse (line 45) | def parse(self, content): class JsonlSerializer (line 49) | class JsonlSerializer: method serialize (line 54) | def serialize(self, data): class JsonToJsonlConverter (line 59) | class JsonToJsonlConverter: method __init__ (line 64) | def __init__(self, file_reader, file_writer, json_parser, jsonl_serial... method convert (line 70) | def convert(self, input_file_path): FILE: src/axolotl/core/builders/base.py class TrainerBuilderBase (line 55) | class TrainerBuilderBase(abc.ABC): method __init__ (line 58) | def __init__(self, cfg, model, tokenizer, processor=None): method model_ref (line 78) | def model_ref(self): method model_ref (line 82) | def model_ref(self, model): method train_dataset (line 86) | def train_dataset(self): method train_dataset (line 90) | def train_dataset(self, dataset): method eval_dataset (line 94) | def eval_dataset(self): method eval_dataset (line 98) | def eval_dataset(self, dataset): method peft_config (line 102) | def peft_config(self): method peft_config (line 106) | def peft_config(self, peft_config): method build (line 110) | def build(self, total_num_steps): method get_callbacks (line 113) | def get_callbacks(self) -> list[TrainerCallback]: method get_post_trainer_create_callbacks (line 182) | def get_post_trainer_create_callbacks(self, trainer): method hook_pre_create_training_args (line 200) | def hook_pre_create_training_args(self, training_arguments_kwargs): method hook_post_create_training_args (line 204) | def hook_post_create_training_args(self, training_arguments): method hook_pre_create_trainer (line 208) | def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls): method hook_post_create_trainer (line 212) | def hook_post_create_trainer(self, trainer): method _configure_warmup_and_logging (line 216) | def _configure_warmup_and_logging( method _configure_precision_settings (line 251) | def _configure_precision_settings(self, training_args_kwargs: dict): method _configure_scheduler (line 261) | def _configure_scheduler(self, training_args_kwargs: dict): method _configure_optimizer (line 273) | def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwa... method _configure_hub_parameters (line 426) | def _configure_hub_parameters(self, training_args_kwargs: dict): method _configure_save_and_eval_strategy (line 439) | def _configure_save_and_eval_strategy(self, training_args_kwargs: dict): method _configure_reporting (line 466) | def _configure_reporting(self, training_args_kwargs: dict): method _configure_torch_compile (line 490) | def _configure_torch_compile(self, training_args_kwargs: dict): method _configure_accelerator_config (line 502) | def _configure_accelerator_config(self, training_args_kwargs: dict): method _configure_gradient_checkpointing (line 510) | def _configure_gradient_checkpointing(self, training_args_kwargs: dict): method _set_base_training_args (line 528) | def _set_base_training_args( FILE: src/axolotl/core/builders/causal.py class HFCausalTrainerBuilder (line 53) | class HFCausalTrainerBuilder(TrainerBuilderBase): method get_callbacks (line 59) | def get_callbacks(self): method get_post_trainer_create_callbacks (line 82) | def get_post_trainer_create_callbacks(self, trainer): method _get_trainer_cls (line 134) | def _get_trainer_cls(self): method build (line 163) | def build(self, total_num_steps): method build_collator (line 454) | def build_collator( FILE: src/axolotl/core/builders/rl.py class HFRLTrainerBuilder (line 24) | class HFRLTrainerBuilder(TrainerBuilderBase): method get_callbacks (line 27) | def get_callbacks(self): method get_post_trainer_create_callbacks (line 35) | def get_post_trainer_create_callbacks(self, trainer): method _get_trainer_cls (line 39) | def _get_trainer_cls(self, trainer_kwargs: dict): method _build_training_arguments (line 96) | def _build_training_arguments(self, total_num_steps): method build (line 206) | def build(self, total_num_steps): FILE: src/axolotl/core/chat/format/chatml.py function format_message (line 11) | def format_message( FILE: src/axolotl/core/chat/format/llama3x.py function format_message (line 11) | def format_message(message: Messages, message_index: Optional[int] = Non... FILE: src/axolotl/core/chat/format/shared.py function wrap_tools (line 8) | def wrap_tools(message: Messages): FILE: src/axolotl/core/chat/messages.py class MessageRoles (line 13) | class MessageRoles(str, Enum): class MessageContentTypes (line 28) | class MessageContentTypes(str, Enum): class SpecialToken (line 41) | class SpecialToken(str, Enum): class ToolCallFunction (line 50) | class ToolCallFunction(BaseModel): class Tool (line 59) | class Tool(BaseModel): class ToolCallContents (line 69) | class ToolCallContents(BaseModel): method __str__ (line 78) | def __str__(self) -> str: class ToolResponseContents (line 85) | class ToolResponseContents(BaseModel): method __str__ (line 94) | def __str__(self) -> str: class MessageContents (line 101) | class MessageContents(BaseModel): method __str__ (line 113) | def __str__(self) -> str: class Messages (line 120) | class Messages(BaseModel): method __str__ (line 131) | def __str__(self) -> str: method tokenized (line 134) | def tokenized( class Chats (line 180) | class Chats(BaseModel): method __str__ (line 187) | def __str__(self) -> str: method tokenized (line 190) | def tokenized( class ChatFormattedChats (line 208) | class ChatFormattedChats(Chats): method model_post_init (line 216) | def model_post_init(self, __context): class PreferenceChats (line 223) | class PreferenceChats(BaseModel): FILE: src/axolotl/core/datasets/chat.py class TokenizedChatDataset (line 13) | class TokenizedChatDataset(Dataset): method __init__ (line 18) | def __init__( FILE: src/axolotl/core/datasets/transforms/chat_builder.py function chat_message_transform_builder (line 9) | def chat_message_transform_builder( FILE: src/axolotl/core/trainers/base.py class AxolotlTrainer (line 64) | class AxolotlTrainer( method axolotl_cfg (line 81) | def axolotl_cfg(self): method axolotl_cfg (line 85) | def axolotl_cfg(self, cfg): method __init__ (line 88) | def __init__( method _create_multipack_sampler (line 109) | def _create_multipack_sampler( method _get_train_sampler (line 150) | def _get_train_sampler( method _get_eval_sampler (line 187) | def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sa... method _get_dataloader (line 219) | def _get_dataloader( method _get_bench_sampler (line 316) | def _get_bench_sampler( method get_bench_dataloader (line 323) | def get_bench_dataloader( method compute_loss (line 344) | def compute_loss( method evaluate (line 399) | def evaluate(self, *args, **kwargs): method orpo_concatenate_inputs (line 404) | def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0,... method orpo_compute_custom_loss (line 455) | def orpo_compute_custom_loss(self, logits, labels): method orpo_compute_logps (line 473) | def orpo_compute_logps( method orpo_compute_loss (line 497) | def orpo_compute_loss( method push_to_hub (line 565) | def push_to_hub(self, *args, **kwargs) -> str: method create_accelerator_and_postprocess (line 578) | def create_accelerator_and_postprocess(self): method additional_accelerator_args (line 587) | def additional_accelerator_args( method log (line 609) | def log(self, logs: dict[str, float], start_time: float | None = None)... method store_metrics (line 672) | def store_metrics( method _save_checkpoint (line 695) | def _save_checkpoint(self, model, trial, **kwargs): method _save (line 717) | def _save(self, output_dir: Optional[str] = None, state_dict=None): FILE: src/axolotl/core/trainers/dpo/__init__.py class DPOStrategy (line 7) | class DPOStrategy: method get_trainer_class (line 11) | def get_trainer_class(cls): method get_training_args_class (line 15) | def get_training_args_class(cls): method set_training_args_kwargs (line 21) | def set_training_args_kwargs(cls, cfg): FILE: src/axolotl/core/trainers/dpo/args.py class AxolotlDPOConfig (line 13) | class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): FILE: src/axolotl/core/trainers/dpo/trainer.py class AxolotlDPOTrainer (line 23) | class AxolotlDPOTrainer( method __init__ (line 35) | def __init__(self, *args, dataset_tags=None, **kwargs): method push_to_hub (line 43) | def push_to_hub(self, *args, **kwargs) -> str: method tokenize_row (line 57) | def tokenize_row( method training_step (line 87) | def training_step( method concatenated_forward (line 98) | def concatenated_forward( FILE: src/axolotl/core/trainers/grpo/__init__.py class GRPOStrategy (line 26) | class GRPOStrategy: method get_trainer_class (line 30) | def get_trainer_class( method get_training_args_class (line 51) | def get_training_args_class( method set_training_args_kwargs (line 59) | def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: method set_trainer_args (line 205) | def set_trainer_args(cls, cfg: DictDefault) -> list[Any]: method set_trainer_kwargs (line 216) | def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: method get_collator (line 228) | def get_collator(cls, *args, **kwargs): method get_blocklist_args_kwargs (line 233) | def get_blocklist_args_kwargs(cls) -> list[str]: method get_reward_func (line 242) | def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: method get_rollout_func (line 286) | def get_rollout_func(cls, rollout_func_fqn: str): FILE: src/axolotl/core/trainers/grpo/args.py class AxolotlGRPOConfig (line 14) | class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): class AxolotlAsyncGRPOConfig (line 21) | class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig): FILE: src/axolotl/core/trainers/grpo/async_trainer.py function disable_gradient_checkpointing (line 68) | def disable_gradient_checkpointing(model, kwargs): class AsyncGRPOConfig (line 98) | class AsyncGRPOConfig(GRPOConfig): class ProducerConfig (line 184) | class ProducerConfig: method __post_init__ (line 211) | def __post_init__(self): class DataProducer (line 232) | class DataProducer(ABC): method produce (line 242) | def produce( class BaseDataProducer (line 256) | class BaseDataProducer(DataProducer): method __init__ (line 259) | def __init__(self, config: ProducerConfig | None = None): method on_rollout_begin (line 262) | def on_rollout_begin(self, global_step: int) -> None: method on_rollout_end (line 265) | def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: class AsyncDataProducer (line 269) | class AsyncDataProducer: method __init__ (line 282) | def __init__( method config (line 303) | def config(self) -> ProducerConfig: method produce (line 306) | def produce(self, model: Any, global_step: int, **kwargs) -> Dataset: method _broadcast_dataset (line 358) | def _broadcast_dataset(self, dataset) -> Dataset: method _locked_produce (line 391) | def _locked_produce(self, model: Any, global_step: int, **kwargs) -> D... method on_rollout_begin (line 396) | def on_rollout_begin(self, global_step: int) -> None: method on_rollout_end (line 400) | def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: method shutdown (line 404) | def shutdown(self) -> None: class DataProducerCallback (line 412) | class DataProducerCallback: class RolloutDataset (line 424) | class RolloutDataset(Dataset): method __init__ (line 434) | def __init__(self, data: dict[str, Any]): method __len__ (line 461) | def __len__(self) -> int: method __getitem__ (line 464) | def __getitem__(self, idx: int) -> dict[str, Any]: function make_rollout_collator (line 473) | def make_rollout_collator(shared_keys: set[str]): class GRPODataProducer (line 492) | class GRPODataProducer(BaseDataProducer): method __init__ (line 499) | def __init__( method set_trainer (line 523) | def set_trainer(self, trainer) -> None: method _init_prompt_dataloader (line 528) | def _init_prompt_dataloader(self) -> None: method produce (line 570) | def produce( class AsyncGRPOTrainer (line 623) | class AsyncGRPOTrainer(GRPOTrainer): method __init__ (line 630) | def __init__(self, *args, **kwargs): method _create_data_producer (line 709) | def _create_data_producer(self, args, train_dataset): method _setup_async (line 743) | def _setup_async(self): method _shutdown_async (line 777) | def _shutdown_async(self): method _submit_generation (line 782) | def _submit_generation(self): method _sync_peft_weights_no_merge (line 792) | def _sync_peft_weights_no_merge(self): method _sync_lora_adapter (line 892) | def _sync_lora_adapter(self): method _maybe_sync_vllm_weights (line 998) | def _maybe_sync_vllm_weights(self): method _zero_pad_embedding_for_fp8 (line 1045) | def _zero_pad_embedding_for_fp8(self): method _generate_single_turn (line 1091) | def _generate_single_turn(self, prompts, **kwargs): method _generate_rank0_only (line 1129) | def _generate_rank0_only(self, prompts): method _generate_only (line 1217) | def _generate_only(self, inputs, rank0_only=False): method _compute_rewards_for_batch (line 1414) | def _compute_rewards_for_batch( method _launch_reward_workers (line 1422) | def _launch_reward_workers(self, inputs, prompts, completions, complet... method _collect_reward_workers (line 1429) | def _collect_reward_workers( method _post_advantage_hook (line 1444) | def _post_advantage_hook( method _compute_deferred_scores (line 1463) | def _compute_deferred_scores(self, rollout: dict) -> dict: method _compute_streaming_group_scores (line 1802) | def _compute_streaming_group_scores( method _score_streaming (line 2153) | def _score_streaming(self, rollout: dict) -> list[dict]: method _prepare_inputs (line 2215) | def _prepare_inputs(self, generation_batch): method _prepare_inputs_data_producer (line 2230) | def _prepare_inputs_data_producer(self, generation_batch): method _prepare_inputs_legacy_async (line 2271) | def _prepare_inputs_legacy_async(self, generation_batch): method _get_per_token_logps_and_entropies (line 2302) | def _get_per_token_logps_and_entropies( method get_off_policy_mask (line 2437) | def get_off_policy_mask( method _compute_loss (line 2453) | def _compute_loss(self, model, inputs): FILE: src/axolotl/core/trainers/grpo/fast_async_trainer.py class FastAsyncGRPOConfig (line 51) | class FastAsyncGRPOConfig(AsyncGRPOConfig): class RerollDataProducer (line 117) | class RerollDataProducer(GRPODataProducer): method _pre_produce_hook (line 125) | def _pre_produce_hook(self, inputs: list, global_step: int) -> list: function _persistent_reward_worker (line 169) | def _persistent_reward_worker(conn): class FastAsyncGRPOTrainer (line 216) | class FastAsyncGRPOTrainer(AsyncGRPOTrainer): method __init__ (line 226) | def __init__(self, *args, **kwargs): method _create_data_producer (line 265) | def _create_data_producer(self, args, train_dataset): method _get_reward_workers (line 301) | def _get_reward_workers(self): method _shutdown_reward_workers (line 328) | def _shutdown_reward_workers(self): method _compute_rewards_for_batch (line 346) | def _compute_rewards_for_batch( method _launch_reward_workers (line 355) | def _launch_reward_workers(self, inputs, prompts, completions, complet... method _collect_reward_workers (line 410) | def _collect_reward_workers( method _post_advantage_hook (line 470) | def _post_advantage_hook( method compute_liger_loss (line 637) | def compute_liger_loss(self, unwrapped_model, inputs): method _compute_loss (line 753) | def _compute_loss(self, model, inputs): FILE: src/axolotl/core/trainers/grpo/replay_buffer.py class ReplayBuffer (line 8) | class ReplayBuffer: method __init__ (line 14) | def __init__(self, max_size: int): method __len__ (line 19) | def __len__(self): method add (line 22) | def add(self, score: float, data: dict): method sample (line 32) | def sample(self, num_samples: int) -> list[dict] | None: FILE: src/axolotl/core/trainers/grpo/sampler.py class SequenceParallelRepeatRandomSampler (line 13) | class SequenceParallelRepeatRandomSampler(Sampler): method __init__ (line 54) | def __init__( method __iter__ (line 109) | def __iter__(self) -> Iterator[int]: method __len__ (line 155) | def __len__(self) -> int: method set_epoch (line 166) | def set_epoch(self, epoch: int) -> None: FILE: src/axolotl/core/trainers/grpo/trainer.py class AxolotlGRPOTrainer (line 57) | class AxolotlGRPOTrainer( class AxolotlAsyncGRPOTrainer (line 70) | class AxolotlAsyncGRPOTrainer( class AxolotlGRPOSequenceParallelTrainer (line 83) | class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): method __init__ (line 86) | def __init__( method train (line 166) | def train(self, *args, **kwargs): method _get_train_sampler (line 176) | def _get_train_sampler(self) -> Sampler: method _create_dataloader_params (line 198) | def _create_dataloader_params(self, is_eval=False, custom_batch_size=N... method _prepare_dataloader (line 221) | def _prepare_dataloader( method get_train_dataloader (line 264) | def get_train_dataloader(self) -> DataLoader: method _generate_and_score_completions (line 291) | def _generate_and_score_completions( FILE: src/axolotl/core/trainers/mamba.py class AxolotlMambaTrainer (line 8) | class AxolotlMambaTrainer(AxolotlTrainer): method compute_loss (line 13) | def compute_loss( FILE: src/axolotl/core/trainers/mixins/activation_checkpointing.py class ActivationOffloadingMixin (line 25) | class ActivationOffloadingMixin(Trainer): method __init__ (line 30) | def __init__(self, *args, **kwargs): method training_step (line 44) | def training_step(self, *args, **kwargs): function ac_wrap_hf_model (line 49) | def ac_wrap_hf_model(model: nn.Module, **kwargs): function get_lora_act_offloading_ctx_manager (line 54) | def get_lora_act_offloading_ctx_manager( FILE: src/axolotl/core/trainers/mixins/checkpoints.py class CheckpointSaveMixin (line 10) | class CheckpointSaveMixin(Trainer): method _save_optimizer_and_scheduler (line 13) | def _save_optimizer_and_scheduler(self, output_dir): FILE: src/axolotl/core/trainers/mixins/distributed_parallel.py class DistributedParallelMixin (line 9) | class DistributedParallelMixin(Trainer): method _save (line 14) | def _save(self, output_dir: str | None = None, state_dict=None): method create_accelerator_and_postprocess (line 23) | def create_accelerator_and_postprocess(self): FILE: src/axolotl/core/trainers/mixins/optimizer.py class OptimizerMixin (line 17) | class OptimizerMixin(Trainer): method create_optimizer_grouped_parameters (line 22) | def create_optimizer_grouped_parameters( method create_optimizer (line 107) | def create_optimizer(self, model=None): class OptimizerInitMixin (line 201) | class OptimizerInitMixin: method __init__ (line 207) | def __init__(self, *args, **kwargs): FILE: src/axolotl/core/trainers/mixins/packing.py class PackingMixin (line 6) | class PackingMixin(Trainer): method _set_signature_columns_if_needed (line 11) | def _set_signature_columns_if_needed(self): FILE: src/axolotl/core/trainers/mixins/rng_state_loader.py class RngLoaderMixin (line 24) | class RngLoaderMixin(Trainer): method _load_rng_state (line 29) | def _load_rng_state(self, checkpoint): FILE: src/axolotl/core/trainers/mixins/scheduler.py class SchedulerMixin (line 20) | class SchedulerMixin(Trainer): method create_scheduler (line 27) | def create_scheduler( FILE: src/axolotl/core/trainers/trl.py class AxolotlORPOTrainer (line 14) | class AxolotlORPOTrainer( class AxolotlKTOTrainer (line 29) | class AxolotlKTOTrainer( class AxolotlCPOTrainer (line 44) | class AxolotlCPOTrainer( class AxolotlRewardTrainer (line 59) | class AxolotlRewardTrainer( class AxolotlPRMTrainer (line 74) | class AxolotlPRMTrainer( FILE: src/axolotl/core/trainers/utils.py function sanitize_kwargs_for_tagging (line 4) | def sanitize_kwargs_for_tagging(tag_names, kwargs=None): function sanitize_kwargs_for_ds_tagging (line 20) | def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): FILE: src/axolotl/core/training_args.py class AxolotlTrainingArguments (line 23) | class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlORPOConfig (line 33) | class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): class AxolotlKTOConfig (line 40) | class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): class AxolotlCPOConfig (line 47) | class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): class AxolotlRewardConfig (line 59) | class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig): class AxolotlPRMConfig (line 66) | class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig): FILE: src/axolotl/core/training_args_base.py class AxolotlTrainingMixins (line 12) | class AxolotlTrainingMixins: FILE: src/axolotl/datasets.py class TokenizedPromptDataset (line 18) | class TokenizedPromptDataset(Dataset): method __init__ (line 28) | def __init__( method process (line 44) | def process(self, dataset): function wrap_dataset_for_tokenized_prompt (line 72) | def wrap_dataset_for_tokenized_prompt( FILE: src/axolotl/evaluate.py function evaluate_dataset (line 30) | def evaluate_dataset( function evaluate (line 68) | def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dic... FILE: src/axolotl/integrations/base.py class BasePlugin (line 44) | class BasePlugin: method __init__ (line 76) | def __init__(self): method register (line 79) | def register(self, cfg: dict): method get_input_args (line 86) | def get_input_args(self) -> str | None: method get_training_args_mixin (line 89) | def get_training_args_mixin(self) -> str | None: method load_datasets (line 94) | def load_datasets( method pre_model_load (line 107) | def pre_model_load(self, cfg: DictDefault): method post_model_build (line 114) | def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): method pre_lora_load (line 121) | def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): method post_lora_load (line 129) | def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | Pe... method post_model_load (line 137) | def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | P... method get_trainer_cls (line 145) | def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: method post_trainer_create (line 155) | def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): method get_training_args (line 163) | def get_training_args(self, cfg: DictDefault): method get_collator_cls_and_kwargs (line 174) | def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool ... method create_optimizer (line 186) | def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Opti... method create_lr_scheduler (line 197) | def create_lr_scheduler( method add_callbacks_pre_trainer (line 216) | def add_callbacks_pre_trainer( method add_callbacks_post_trainer (line 230) | def add_callbacks_post_trainer( method post_train (line 245) | def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftMo... method post_train_unload (line 253) | def post_train_unload(self, cfg: DictDefault): function load_plugin (line 261) | def load_plugin(plugin_name: str) -> BasePlugin: class PluginManager (line 301) | class PluginManager: method __new__ (line 320) | def __new__(cls): method get_instance (line 330) | def get_instance() -> "PluginManager": method cfg (line 339) | def cfg(self): method cfg (line 343) | def cfg(self, cfg): method register (line 346) | def register(self, plugin_name: str): method get_input_args (line 366) | def get_input_args(self) -> list[str]: method get_training_args_mixin (line 379) | def get_training_args_mixin(self): method load_datasets (line 393) | def load_datasets( method pre_model_load (line 415) | def pre_model_load(self, cfg: DictDefault): method post_model_build (line 424) | def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): method pre_lora_load (line 435) | def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): method post_lora_load (line 445) | def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | Pe... method post_model_load (line 455) | def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | P... method get_trainer_cls (line 466) | def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: method get_training_args (line 482) | def get_training_args(self, cfg): method get_collator_cls_and_kwargs (line 500) | def get_collator_cls_and_kwargs(self, cfg, is_eval=False): method post_trainer_create (line 518) | def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): method create_optimizer (line 528) | def create_optimizer(self, trainer: Trainer) -> Optimizer | None: method create_lr_scheduler (line 544) | def create_lr_scheduler( method add_callbacks_pre_trainer (line 568) | def add_callbacks_pre_trainer( method add_callbacks_post_trainer (line 587) | def add_callbacks_post_trainer( method post_train (line 606) | def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftMo... method post_train_unload (line 616) | def post_train_unload(self, cfg: DictDefault): class BaseOptimizerFactory (line 626) | class BaseOptimizerFactory: method __call__ (line 629) | def __call__( method get_decay_parameter_names (line 635) | def get_decay_parameter_names(self, model) -> list[str]: FILE: src/axolotl/integrations/config.py function merge_input_args (line 27) | def merge_input_args(): function merge_training_args (line 60) | def merge_training_args() -> Type: FILE: src/axolotl/integrations/cut_cross_entropy/__init__.py class CutCrossEntropyPlugin (line 42) | class CutCrossEntropyPlugin(BasePlugin): method get_input_args (line 47) | def get_input_args(self): method _check_requirements (line 50) | def _check_requirements(self): method pre_model_load (line 86) | def pre_model_load(self, cfg): method patch_llama_like (line 105) | def patch_llama_like( FILE: src/axolotl/integrations/cut_cross_entropy/args.py class CutCrossEntropyArgs (line 28) | class CutCrossEntropyArgs(BaseModel): method check_dtype_is_half (line 37) | def check_dtype_is_half(cls, data): method check_chunked_cross_entropy_not_set (line 48) | def check_chunked_cross_entropy_not_set(cls, data): FILE: src/axolotl/integrations/densemixer/args.py class DenseMixerArgs (line 6) | class DenseMixerArgs(BaseModel): FILE: src/axolotl/integrations/densemixer/plugin.py class DenseMixerPlugin (line 11) | class DenseMixerPlugin(BasePlugin): method get_input_args (line 16) | def get_input_args(self) -> str | None: method pre_model_load (line 19) | def pre_model_load(self, cfg): FILE: src/axolotl/integrations/diffusion/args.py class DiffusionConfig (line 10) | class DiffusionConfig(BaseModel): method _validate_mask_ratios (line 83) | def _validate_mask_ratios(self) -> "DiffusionConfig": class DiffusionArgs (line 89) | class DiffusionArgs(BaseModel): FILE: src/axolotl/integrations/diffusion/callbacks.py class DiffusionGenerationCallback (line 23) | class DiffusionGenerationCallback(TrainerCallback): method __init__ (line 26) | def __init__(self, trainer): method on_step_end (line 29) | def on_step_end( method _log_samples (line 70) | def _log_samples(self, samples: list, step: int): FILE: src/axolotl/integrations/diffusion/generation.py function generate_samples (line 15) | def generate_samples( function _sample_sequences_from_dataloader (line 103) | def _sample_sequences_from_dataloader( function generate (line 196) | def generate( function _clean_masked_text (line 321) | def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: ... function _diffusion_step (line 339) | def _diffusion_step( FILE: src/axolotl/integrations/diffusion/plugin.py class DiffusionPlugin (line 15) | class DiffusionPlugin(BasePlugin): method __init__ (line 23) | def __init__(self): method get_input_args (line 27) | def get_input_args(self) -> str: method post_model_load (line 31) | def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | P... method get_trainer_cls (line 35) | def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] ... method post_trainer_create (line 39) | def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrai... FILE: src/axolotl/integrations/diffusion/trainer.py class DiffusionTrainer (line 19) | class DiffusionTrainer(AxolotlTrainer): method __init__ (line 22) | def __init__(self, *args, **kwargs): method set_config (line 27) | def set_config(self, config: DictDefault): method _resolve_mask_token_id (line 40) | def _resolve_mask_token_id(self) -> None: method compute_loss (line 59) | def compute_loss( method _cache_special_token_ids (line 82) | def _cache_special_token_ids(self): method _forward_process (line 100) | def _forward_process( method _compute_diffusion_loss (line 159) | def _compute_diffusion_loss( FILE: src/axolotl/integrations/diffusion/utils.py function resolve_mask_token_id (line 12) | def resolve_mask_token_id( function create_bidirectional_attention_mask (line 125) | def create_bidirectional_attention_mask( function shift_logits_to_input_positions (line 162) | def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor: FILE: src/axolotl/integrations/grokfast/__init__.py class GrokfastCallbackHandler (line 16) | class GrokfastCallbackHandler(TrainerCallback): method __init__ (line 21) | def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs): method on_train_begin (line 27) | def on_train_begin(self, *args_, **kwargs): method on_pre_optimizer_step (line 30) | def on_pre_optimizer_step(self, args_, state, control, **kwargs): class GrokfastPlugin (line 36) | class GrokfastPlugin(BasePlugin): method get_input_args (line 41) | def get_input_args(self): method add_callbacks_post_trainer (line 44) | def add_callbacks_post_trainer(self, cfg, trainer): FILE: src/axolotl/integrations/grokfast/args.py class GrokfastArgs (line 10) | class GrokfastArgs(BaseModel): FILE: src/axolotl/integrations/grokfast/optimizer.py function gradfilter_ma (line 11) | def gradfilter_ma( function gradfilter_ema (line 44) | def gradfilter_ema( FILE: src/axolotl/integrations/kd/__init__.py class KDPlugin (line 29) | class KDPlugin(BasePlugin): method get_input_args (line 34) | def get_input_args(self): method get_training_args_mixin (line 37) | def get_training_args_mixin(self): method get_trainer_cls (line 40) | def get_trainer_cls(self, cfg): method get_training_args (line 47) | def get_training_args(self, cfg): method get_collator_cls_and_kwargs (line 56) | def get_collator_cls_and_kwargs(self, cfg, is_eval=False): method pre_model_load (line 84) | def pre_model_load(self, cfg): method add_callbacks_post_trainer (line 89) | def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: FILE: src/axolotl/integrations/kd/args.py class InferenceServerType (line 25) | class InferenceServerType(str, Enum): class KDArgs (line 34) | class KDArgs(BaseModel): class KDTrainingArgsMixin (line 63) | class KDTrainingArgsMixin: FILE: src/axolotl/integrations/kd/callbacks.py class KDTemperatureSchedulerCallback (line 10) | class KDTemperatureSchedulerCallback(TrainerCallback): method __init__ (line 15) | def __init__(self, temperature_start, temperature_min, trainer): method on_step_end (line 22) | def on_step_end(self, args, state, control, **kwargs): FILE: src/axolotl/integrations/kd/chat_template.py class ChatTemplateStrategyWithKD (line 29) | class ChatTemplateStrategyWithKD(ChatTemplateStrategy): method __init__ (line 34) | def __init__( method supports_batched (line 66) | def supports_batched(self) -> bool: method transform_logprobs (line 70) | def transform_logprobs(self, sample): method _tokenize_single_prompt (line 178) | def _tokenize_single_prompt(self, prompt): method _prepare_kd_fields (line 189) | def _prepare_kd_fields(self, tokenized_prompt, original_prompt): class ChatTemplateStrategyWithKDv2 (line 196) | class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): method transform_logprobs (line 201) | def transform_logprobs(self, sample): method _prepare_kd_fields (line 295) | def _prepare_kd_fields(self, tokenized_prompt, original_prompt): class KDStrategyLoader (line 305) | class KDStrategyLoader(StrategyLoader): method _get_strategy_cls (line 310) | def _get_strategy_cls(self, cfg): method _get_strategy_params (line 313) | def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): class KDStrategyLoaderV2 (line 325) | class KDStrategyLoaderV2(KDStrategyLoader): method _get_strategy_cls (line 330) | def _get_strategy_cls(self, cfg): FILE: src/axolotl/integrations/kd/collator.py class DataCollatorForKD (line 32) | class DataCollatorForKD(DataCollatorForSeq2Seq): method __init__ (line 49) | def __init__(self, *args, **kwargs): method __call__ (line 53) | def __call__(self, features, return_tensors=None): class KDBatchSamplerDataCollatorForSeq2Seq (line 198) | class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): method __call__ (line 204) | def __call__(self, features, return_tensors=None): FILE: src/axolotl/integrations/kd/collator_online_teacher.py function hmac_sha_from_int_list (line 21) | def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256): class OnlineTeacherCollator (line 46) | class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): method __init__ (line 53) | def __init__( method _normalize_logprobs (line 85) | def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: method fetch_online_logprobs_sglang (line 98) | def fetch_online_logprobs_sglang( method fetch_online_logprobs_vllm (line 267) | def fetch_online_logprobs_vllm( method __call__ (line 467) | def __call__( FILE: src/axolotl/integrations/kd/kernels/liger.py class LigerFusedLinearKLTopKLogprobFunction (line 14) | class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillation... method distillation_loss_fn (line 20) | def distillation_loss_fn( method _compute_loss_kl_topk (line 120) | def _compute_loss_kl_topk( method forward (line 180) | def forward( method backward (line 352) | def backward(ctx, grad_output): class LigerFusedLinearKLTopKLogprobLoss (line 415) | class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): method __init__ (line 420) | def __init__( method forward (line 458) | def forward( FILE: src/axolotl/integrations/kd/kernels/models.py class TransformersKwargs (line 15) | class TransformersKwargs(FlashAttentionKwargs, LossKwargs): function kldiv_forward_llama_like (line 28) | def kldiv_forward_llama_like( function apply_kernel (line 98) | def apply_kernel(model_type): FILE: src/axolotl/integrations/kd/topk_logprob/forward_kl.py function loss (line 24) | def loss( class ChunkedTopKKDLoss (line 100) | class ChunkedTopKKDLoss(nn.Module): method __init__ (line 108) | def __init__(self, num_output_chunks: int = 8, kd_temperature: float =... method forward (line 113) | def forward( FILE: src/axolotl/integrations/kd/trainer.py class AxolotlKDTrainer (line 26) | class AxolotlKDTrainer(AxolotlTrainer): method __init__ (line 31) | def __init__(self, *args, **kwargs): method _set_signature_columns_if_needed (line 52) | def _set_signature_columns_if_needed(self): method compute_loss (line 66) | def compute_loss( FILE: src/axolotl/integrations/kd/utils.py function normalize_logprobs (line 11) | def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: function strided_chunk_views (line 46) | def strided_chunk_views( function chunk_overlap (line 94) | def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overl... FILE: src/axolotl/integrations/kernels/args.py class KernelsArgs (line 8) | class KernelsArgs(BaseModel): method check_mutually_exclusive (line 14) | def check_mutually_exclusive(cls, data): method check_use_kernels (line 24) | def check_use_kernels(cls, data): method check_experts_implementation (line 35) | def check_experts_implementation(cls, data): method disable_mlp_kernel (line 50) | def disable_mlp_kernel(cls, data): FILE: src/axolotl/integrations/kernels/autotune_callback.py function _get_gpu_info (line 19) | def _get_gpu_info() -> dict: function _get_smem_capacity (line 35) | def _get_smem_capacity() -> dict: class AutotuneReportCallback (line 53) | class AutotuneReportCallback(TrainerCallback): method __init__ (line 67) | def __init__(self): method on_step_end (line 71) | def on_step_end( FILE: src/axolotl/integrations/kernels/autotune_collector.py function _parse_key_tuple (line 28) | def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]: function _find_lora_ops_module (line 44) | def _find_lora_ops_module() -> ModuleType | None: function collect_autotune_configs (line 68) | def collect_autotune_configs() -> list[dict[str, Any]]: FILE: src/axolotl/integrations/kernels/constants.py function resolve_moe_block_classes (line 47) | def resolve_moe_block_classes(model_type: str): FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py function _next_power_of_2 (line 46) | def _next_power_of_2(n: int) -> int: function _block_r_for_rank (line 61) | def _block_r_for_rank(r: int) -> int: function round_expert_counts (line 71) | def round_expert_counts( function _get_smem_capacity (line 169) | def _get_smem_capacity() -> int: function _estimate_smem_usage (line 180) | def _estimate_smem_usage( function _estimate_register_pressure (line 198) | def _estimate_register_pressure( function _compute_expert_block_lora (line 234) | def _compute_expert_block_lora( function _scatter2scatter_lora_configs (line 356) | def _scatter2scatter_lora_configs(): function _prune_fwd_configs (line 389) | def _prune_fwd_configs(configs, named_args, **kwargs): function _scatter2scatter_lora (line 463) | def _scatter2scatter_lora( function _scatter2scatter_lora_split (line 590) | def _scatter2scatter_lora_split( function scatter2scatter_lora (line 673) | def scatter2scatter_lora( function _compute_expert_block_lora_dX (line 807) | def _compute_expert_block_lora_dX( function _scatter2scatter_lora_dX_configs (line 936) | def _scatter2scatter_lora_dX_configs(): function _prune_dX_configs (line 969) | def _prune_dX_configs(configs, named_args, **kwargs): function _scatter2scatter_lora_dX (line 1043) | def _scatter2scatter_lora_dX( function scatter2scatter_lora_dX (line 1166) | def scatter2scatter_lora_dX( function _group_bwd_lora_configs (line 1274) | def _group_bwd_lora_configs(): function _prune_bwd_lora_configs (line 1308) | def _prune_bwd_lora_configs(configs, named_args, **kwargs): function _group_bwd_lora (line 1377) | def _group_bwd_lora( function _group_bwd_split_configs (line 1546) | def _group_bwd_split_configs(): function _prune_split_configs (line 1565) | def _prune_split_configs(configs, named_args, **kwargs): function _group_bwd_lora_split (line 1617) | def _group_bwd_lora_split( function group_bwd_lora (line 1819) | def group_bwd_lora( function _group_bwd_lora_fused (line 1939) | def _group_bwd_lora_fused( function group_bwd_lora_fused (line 2145) | def group_bwd_lora_fused( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py function _compute_expert_block (line 18) | def _compute_expert_block( function _scatter2scatter_configs (line 62) | def _scatter2scatter_configs(): function _scatter2scatter (line 79) | def _scatter2scatter( function scatter2scatter (line 169) | def scatter2scatter( function scatter2scatter_compileable (line 206) | def scatter2scatter_compileable( function _config_XtY (line 264) | def _config_XtY(): function group_bwd_W (line 272) | def group_bwd_W(DY, X, expert_offsets, E, has_bias=False): function groupXtY_compileable (line 284) | def groupXtY_compileable( function _groupXtY (line 346) | def _groupXtY( function _xty_and_bias (line 469) | def _xty_and_bias( function _config_grouping (line 537) | def _config_grouping(): function group (line 545) | def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): function group_compileable (line 558) | def group_compileable( function _group (line 595) | def _group( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py function _single2scatter (line 13) | def _single2scatter( function single2scatter (line 70) | def single2scatter(X, W, expert_idxs): FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py function peft_lora_B_to_scattermoe (line 46) | def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): function peft_lora_to_scattermoe (line 62) | def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): function peft_down_proj_lora_to_scattermoe (line 109) | def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): function _unwrap_gate_lora (line 119) | def _unwrap_gate_lora(gate_module): function _convert_smoe_lora (line 155) | def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling): function _unwrap_experts_lora (line 161) | def _unwrap_experts_lora(experts_module): function _softmax_topk_route (line 228) | def _softmax_topk_route( function _sigmoid_topk_route (line 251) | def _sigmoid_topk_route( function _route (line 319) | def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_d... function _compute_shared_expert (line 343) | def _compute_shared_expert(moe_block, hidden_states_flat): class ScatterMoEGatedMLP (line 380) | class ScatterMoEGatedMLP(nn.Module): method forward (line 381) | def forward(self, layer_input): class HFScatterMoEGatedMLP (line 434) | class HFScatterMoEGatedMLP(nn.Module): method forward (line 451) | def forward(self: nn.Module, layer_input: torch.Tensor): FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py class ParallelExperts (line 20) | class ParallelExperts(nn.Module): method __init__ (line 29) | def __init__( method reset_parameters (line 50) | def reset_parameters(self) -> None: method extra_repr (line 55) | def extra_repr(self) -> str: method set_lora (line 62) | def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling... method clear_lora (line 68) | def clear_lora(self): method forward (line 74) | def forward( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py function compileable_bincount (line 16) | def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: function _ (line 21) | def _(x: torch.Tensor, minlength: int) -> torch.Tensor: function flatten_sort_count (line 26) | def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int): class ParallelLinear (line 37) | class ParallelLinear(torch.autograd.Function): method forward (line 39) | def forward( method backward (line 87) | def backward(ctx, grad_out: torch.Tensor): function parallel_linear (line 178) | def parallel_linear( class ParallelExperts (line 205) | class ParallelExperts(nn.Module): method __init__ (line 206) | def __init__(self, num_experts, input_size, output_size, bias=False) -... method extra_repr (line 220) | def extra_repr(self): method reset_parameters (line 225) | def reset_parameters(self) -> None: method forward (line 230) | def forward( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py class ScatterMoELoRA (line 39) | class ScatterMoELoRA(torch.autograd.Function): method forward (line 50) | def forward( method backward (line 121) | def backward(ctx, grad_out: torch.Tensor): function _compute_lora_input_grad (line 342) | def _compute_lora_input_grad( function get_lora_params_from_wrapper (line 390) | def get_lora_params_from_wrapper(module) -> tuple: function parallel_linear_lora (line 425) | def parallel_linear_lora( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py function get_active_experts (line 28) | def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torc... function remap_expert_indices (line 41) | def remap_expert_indices( function _selective_dequant_bnb4 (line 76) | def _selective_dequant_bnb4( function _selective_index_dense (line 178) | def _selective_index_dense( function selective_expert_weights (line 189) | def selective_expert_weights( function selective_lora_weights (line 255) | def selective_lora_weights( FILE: src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py function _selective_dequant_nf4_kernel (line 45) | def _selective_dequant_nf4_kernel( function selective_dequant_nf4_triton (line 118) | def selective_dequant_nf4_triton( FILE: src/axolotl/integrations/kernels/plugin.py function _check_sonicmoe_gpu_compat (line 13) | def _check_sonicmoe_gpu_compat(): class KernelsPlugin (line 59) | class KernelsPlugin(BasePlugin): method get_input_args (line 60) | def get_input_args(self): method pre_model_load (line 63) | def pre_model_load(self, cfg): method _register_kernels (line 95) | def _register_kernels(self): method add_callbacks_pre_trainer (line 122) | def add_callbacks_pre_trainer(self, cfg, model): method _kernelize_model (line 132) | def _kernelize_model(self, model_type: str): FILE: src/axolotl/integrations/kernels/sonicmoe/patch.py function patch_sonicmoe (line 34) | def patch_sonicmoe(model_type: str, torch_compile: bool = False): function _try_compile_routing (line 55) | def _try_compile_routing(routing_fn): function _patch_forward (line 69) | def _patch_forward(moe_cls, routing_fn, activation, router_attr): function _make_general_forward (line 99) | def _make_general_forward(moe_cls, routing_fn, activation): function _make_fused_forward (line 152) | def _make_fused_forward(moe_cls, activation, router_attr): function _compute_shared_expert (line 199) | def _compute_shared_expert(moe_block, hidden_states_flat): FILE: src/axolotl/integrations/kernels/sonicmoe/routing.py function get_model_moe_config (line 18) | def get_model_moe_config(model_type: str): function softmax_topk_routing (line 81) | def softmax_topk_routing( function softmax_group_topk_routing (line 132) | def softmax_group_topk_routing( function sigmoid_topk_routing (line 188) | def sigmoid_topk_routing( FILE: src/axolotl/integrations/kernels/sonicmoe/weight_converter.py function interleave_gate_up (line 23) | def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: function deinterleave_gate_up (line 28) | def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor: class ConcatenatedToInterleaved (line 33) | class ConcatenatedToInterleaved(ConversionOps): method __init__ (line 42) | def __init__(self, dim: int = 1): method convert (line 46) | def convert( method _get_target_pattern (line 63) | def _get_target_pattern( method reverse_op (line 79) | def reverse_op(self) -> ConversionOps: class InterleavedToConcatenated (line 83) | class InterleavedToConcatenated(ConversionOps): method __init__ (line 92) | def __init__(self, dim: int = 1): method convert (line 96) | def convert( method _get_target_pattern (line 113) | def _get_target_pattern( method reverse_op (line 128) | def reverse_op(self) -> ConversionOps: function register_sonicmoe_weight_converter (line 132) | def register_sonicmoe_weight_converter(model_type: str): FILE: src/axolotl/integrations/liger/args.py class LigerArgs (line 26) | class LigerArgs(BaseModel): method check_deprecated_swiglu (line 50) | def check_deprecated_swiglu(cls, data): method check_tiled_mlp_conflict (line 66) | def check_tiled_mlp_conflict(cls, data): method check_liger_rms_norm_tensor_parallel (line 79) | def check_liger_rms_norm_tensor_parallel(cls, data): method check_liger_use_token_scaling_flce (line 89) | def check_liger_use_token_scaling_flce(cls, data): method check_tensor_parallel_size_liger_fused_linear_cross_entropy (line 100) | def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self): FILE: src/axolotl/integrations/liger/models/base.py function lce_forward (line 18) | def lce_forward( function lce_maybe_trainable_lm_head (line 121) | def lce_maybe_trainable_lm_head( function _liger_for_causal_lm_loss (line 159) | def _liger_for_causal_lm_loss( function patch_lce_forward (line 172) | def patch_lce_forward( FILE: src/axolotl/integrations/liger/models/deepseekv2.py function lce_forward (line 15) | def lce_forward( FILE: src/axolotl/integrations/liger/models/jamba.py function lce_forward (line 19) | def lce_forward( FILE: src/axolotl/integrations/liger/models/llama4.py function lce_forward (line 14) | def lce_forward( function apply_liger_kernel_to_llama4 (line 122) | def apply_liger_kernel_to_llama4( FILE: src/axolotl/integrations/liger/models/qwen3.py function lce_forward (line 14) | def lce_forward( function apply_liger_kernel_to_qwen3 (line 109) | def apply_liger_kernel_to_qwen3( FILE: src/axolotl/integrations/liger/models/qwen3_moe.py function lce_forward (line 15) | def lce_forward( function apply_liger_kernel_to_qwen3_moe (line 131) | def apply_liger_kernel_to_qwen3_moe( FILE: src/axolotl/integrations/liger/plugin.py class LigerPlugin (line 14) | class LigerPlugin(BasePlugin): method get_input_args (line 19) | def get_input_args(self): method pre_model_load (line 22) | def pre_model_load(self, cfg): FILE: src/axolotl/integrations/liger/utils.py function patch_with_compile_disable (line 10) | def patch_with_compile_disable(module, function_name): FILE: src/axolotl/integrations/llm_compressor/args.py class CompressionArgs (line 11) | class CompressionArgs(BaseModel): class LLMCompressorArgs (line 32) | class LLMCompressorArgs(BaseModel): FILE: src/axolotl/integrations/llm_compressor/plugin.py class LLMCompressorCallbackHandler (line 26) | class LLMCompressorCallbackHandler(TrainerCallback): method __init__ (line 33) | def __init__(self, trainer: Trainer, recipe: Any): method on_train_begin (line 50) | def on_train_begin( method on_step_begin (line 75) | def on_step_begin( method on_step_end (line 88) | def on_step_end( method on_train_end (line 103) | def on_train_end( class LLMCompressorPlugin (line 118) | class LLMCompressorPlugin(BasePlugin): method get_input_args (line 123) | def get_input_args(self) -> str: method add_callbacks_post_trainer (line 132) | def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: function compute_loss_wrapper (line 151) | def compute_loss_wrapper( FILE: src/axolotl/integrations/llm_compressor/utils.py function save_compressed_model (line 11) | def save_compressed_model( FILE: src/axolotl/integrations/lm_eval/__init__.py class LMEvalPlugin (line 13) | class LMEvalPlugin(BasePlugin): method get_input_args (line 18) | def get_input_args(self): method post_train_unload (line 21) | def post_train_unload(self, cfg): FILE: src/axolotl/integrations/lm_eval/args.py class LMEvalArgs (line 10) | class LMEvalArgs(BaseModel): FILE: src/axolotl/integrations/lm_eval/cli.py function get_model_path (line 16) | def get_model_path(cfg: DictDefault) -> str | None: function build_lm_eval_command (line 31) | def build_lm_eval_command( function lm_eval (line 104) | def lm_eval(config: str, cloud: Optional[str] = None): FILE: src/axolotl/integrations/spectrum/__init__.py function _generate_unfrozen_params_yaml (line 31) | def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): class SpectrumPlugin (line 55) | class SpectrumPlugin(BasePlugin): method get_input_args (line 64) | def get_input_args(self): method pre_model_load (line 67) | def pre_model_load(self, cfg): FILE: src/axolotl/integrations/spectrum/args.py class SpectrumArgs (line 24) | class SpectrumArgs(BaseModel): method check_fsdp_use_orig_params (line 34) | def check_fsdp_use_orig_params(cls, data): FILE: src/axolotl/integrations/swanlab/args.py class SwanLabConfig (line 6) | class SwanLabConfig(BaseModel): method validate_swanlab_mode (line 96) | def validate_swanlab_mode(cls, v): method validate_swanlab_project (line 116) | def validate_swanlab_project(cls, v): method validate_swanlab_enabled_requires_project (line 128) | def validate_swanlab_enabled_requires_project(self): FILE: src/axolotl/integrations/swanlab/callbacks.py class SwanLabRLHFCompletionCallback (line 20) | class SwanLabRLHFCompletionCallback(TrainerCallback): method __init__ (line 41) | def __init__( method on_init_end (line 61) | def on_init_end( method on_log (line 88) | def on_log( method on_train_end (line 166) | def on_train_end( FILE: src/axolotl/integrations/swanlab/completion_logger.py class CompletionLogger (line 16) | class CompletionLogger: method __init__ (line 39) | def __init__(self, maxlen: int = 128): method add_dpo_completion (line 50) | def add_dpo_completion( method add_kto_completion (line 78) | def add_kto_completion( method add_orpo_completion (line 106) | def add_orpo_completion( method add_grpo_completion (line 134) | def add_grpo_completion( method log_to_swanlab (line 163) | def log_to_swanlab(self, table_name: str = "completions") -> bool: method clear (line 215) | def clear(self) -> None: method __len__ (line 219) | def __len__(self) -> int: method __repr__ (line 223) | def __repr__(self) -> str: FILE: src/axolotl/integrations/swanlab/plugins.py class SwanLabPlugin (line 18) | class SwanLabPlugin(BasePlugin): method __init__ (line 35) | def __init__(self): method get_input_args (line 40) | def get_input_args(self) -> str: method register (line 44) | def register(self, cfg: dict): method pre_model_load (line 147) | def pre_model_load(self, cfg: DictDefault): method add_callbacks_pre_trainer (line 225) | def add_callbacks_pre_trainer(self, cfg: DictDefault, model): method post_trainer_create (line 261) | def post_trainer_create(self, cfg: DictDefault, trainer): method _get_swanlab_init_kwargs (line 298) | def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict: method _prepare_config_for_logging (line 350) | def _prepare_config_for_logging(self, cfg: DictDefault) -> dict: method _register_lark_callback (line 435) | def _register_lark_callback(self, cfg: DictDefault): method _register_completion_callback (line 491) | def _register_completion_callback(self, cfg: DictDefault, trainer): FILE: src/axolotl/integrations/swanlab/profiling.py function swanlab_profiling_context (line 18) | def swanlab_profiling_context(trainer: Any, func_name: str): function swanlab_profile (line 61) | def swanlab_profile(func: Callable) -> Callable: class ProfilingConfig (line 88) | class ProfilingConfig: method __init__ (line 99) | def __init__( method should_log (line 117) | def should_log(self, func_name: str, duration_seconds: float) -> bool: function swanlab_profiling_context_advanced (line 152) | def swanlab_profiling_context_advanced( FILE: src/axolotl/kernels/geglu.py function _geglu_fwd_kernel (line 14) | def _geglu_fwd_kernel( function geglu_forward (line 45) | def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: function _geglu_bwd_kernel (line 71) | def _geglu_bwd_kernel( function geglu_backward (line 125) | def geglu_backward( FILE: src/axolotl/kernels/lora.py function get_lora_parameters (line 23) | def get_lora_parameters( function matmul_lora (line 84) | def matmul_lora( class LoRA_MLP (line 132) | class LoRA_MLP(torch.autograd.Function): method forward (line 137) | def forward( method backward (line 219) | def backward( function apply_lora_mlp_swiglu (line 393) | def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -... function apply_lora_mlp_geglu (line 436) | def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) ->... class LoRA_QKV (line 478) | class LoRA_QKV(torch.autograd.Function): method forward (line 488) | def forward( method backward (line 555) | def backward( function apply_lora_qkv (line 698) | def apply_lora_qkv( class LoRA_O (line 740) | class LoRA_O(torch.autograd.Function): method forward (line 745) | def forward( method backward (line 783) | def backward( function apply_lora_o (line 831) | def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: FILE: src/axolotl/kernels/quantize.py function dequantize_fp8 (line 18) | def dequantize_fp8( function dequantize (line 59) | def dequantize( FILE: src/axolotl/kernels/swiglu.py function _swiglu_fwd_kernel (line 15) | def _swiglu_fwd_kernel( function _swiglu_bwd_kernel (line 50) | def _swiglu_bwd_kernel( function swiglu_forward (line 102) | def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: function swiglu_backward (line 130) | def swiglu_backward( FILE: src/axolotl/loaders/adapter.py function setup_quantized_meta_for_peft (line 30) | def setup_quantized_meta_for_peft(model: torch.nn.Module): function setup_quantized_peft_meta_for_training (line 42) | def setup_quantized_peft_meta_for_training(model: torch.nn.Module): function find_all_linear_names (line 50) | def find_all_linear_names(model): function load_lora (line 70) | def load_lora( function load_adapter (line 194) | def load_adapter( function load_llama_adapter (line 214) | def load_llama_adapter( FILE: src/axolotl/loaders/model.py class ModelLoader (line 67) | class ModelLoader: method __init__ (line 98) | def __init__( method has_flash_attn (line 147) | def has_flash_attn(self) -> bool: method is_fsdp_enabled (line 152) | def is_fsdp_enabled(self): method is_qlora_and_fsdp_enabled (line 157) | def is_qlora_and_fsdp_enabled(self): method load (line 162) | def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftCo... method _apply_pre_model_load_setup (line 196) | def _apply_pre_model_load_setup(self): method _apply_post_model_load_setup (line 224) | def _apply_post_model_load_setup(self): method _configure_experts_implementation (line 241) | def _configure_experts_implementation(self): method _apply_activation_checkpointing (line 245) | def _apply_activation_checkpointing(self): method _resize_token_embeddings (line 254) | def _resize_token_embeddings(self): method _adjust_model_config (line 277) | def _adjust_model_config(self): method _configure_embedding_dtypes (line 306) | def _configure_embedding_dtypes(self): method _configure_qat (line 368) | def _configure_qat(self): method _load_adapters (line 381) | def _load_adapters(self) -> PeftConfig | None: method _apply_post_lora_load_setup (line 404) | def _apply_post_lora_load_setup(self, skip_move_to_device: bool): method _set_parallel_config (line 437) | def _set_parallel_config(self): method _set_auto_model_loader (line 444) | def _set_auto_model_loader(self): method _set_device_map_config (line 456) | def _set_device_map_config(self): method _set_quantization_config (line 539) | def _set_quantization_config(self): method _set_attention_config (line 623) | def _set_attention_config(self): method _check_model_requirements (line 650) | def _check_model_requirements(self): method _configure_zero3_memory_efficient_loading (line 660) | def _configure_zero3_memory_efficient_loading( method _load_model_from_config (line 699) | def _load_model_from_config(self, model_loader_class=None) -> PreTrain... method _load_model_from_pretrained (line 716) | def _load_model_from_pretrained(self, model_loader_class=None) -> PreT... method _build_model (line 726) | def _build_model(self) -> bool: method _set_z3_leaf_modules (line 845) | def _set_z3_leaf_modules(self): method _prepare_model_for_quantization (line 860) | def _prepare_model_for_quantization(self): method _convert_embedding_modules_dtype (line 897) | def _convert_embedding_modules_dtype( FILE: src/axolotl/loaders/patch_manager.py class PatchManager (line 27) | class PatchManager: method apply_pre_config_load_patches (line 31) | def apply_pre_config_load_patches(cfg: DictDefault): method apply_pre_tokenizer_load_patches (line 52) | def apply_pre_tokenizer_load_patches(cfg: DictDefault): method __init__ (line 72) | def __init__( method has_flash_attn (line 90) | def has_flash_attn(self) -> bool: method apply_pre_model_load_patches (line 94) | def apply_pre_model_load_patches(self): method apply_post_plugin_pre_model_load_patches (line 122) | def apply_post_plugin_pre_model_load_patches(self): method _apply_transformers_patches (line 127) | def _apply_transformers_patches(self): method apply_post_model_build_patches (line 143) | def apply_post_model_build_patches(self, model: PreTrainedModel): method apply_post_model_load_patches (line 147) | def apply_post_model_load_patches(self, model: PreTrainedModel): method _apply_flash_attention_patches (line 154) | def _apply_flash_attention_patches(self): method _apply_chunked_cross_entropy_patch (line 162) | def _apply_chunked_cross_entropy_patch(self): method _apply_fsdp_patches (line 171) | def _apply_fsdp_patches(self): method _apply_adapter_patches (line 210) | def _apply_adapter_patches(self): method _apply_flex_attention_patches (line 217) | def _apply_flex_attention_patches(self): method _apply_sageattn_patches (line 227) | def _apply_sageattn_patches(self): method _apply_flash_attn_4_patches (line 234) | def _apply_flash_attn_4_patches(self): method _apply_model_specific_patches (line 243) | def _apply_model_specific_patches(self): method _apply_fp8_patches (line 294) | def _apply_fp8_patches(self): method _apply_flash_attention_peft_patches (line 305) | def _apply_flash_attention_peft_patches(self): method _apply_gradient_checkpointing_patches (line 314) | def _apply_gradient_checkpointing_patches(self): method _apply_mistral_cross_entropy_patch (line 337) | def _apply_mistral_cross_entropy_patch(self): method _apply_self_attention_lora_patch (line 349) | def _apply_self_attention_lora_patch(self): method _apply_multipack_patches (line 367) | def _apply_multipack_patches(self): method _apply_fsdp2_bnb_patches (line 405) | def _apply_fsdp2_bnb_patches(self): method _deactivate_hf_async_load (line 425) | def _deactivate_hf_async_load(self): method _apply_moe_expert_quantization_patch (line 430) | def _apply_moe_expert_quantization_patch(self): method _finalize_moe_expert_quantization (line 448) | def _finalize_moe_expert_quantization(self, model: PreTrainedModel): method _apply_tiled_mlp (line 469) | def _apply_tiled_mlp(self, model_type: str): method _apply_voxtral_patches (line 481) | def _apply_voxtral_patches(self): method _patch_attention (line 490) | def _patch_attention(self): method _patch_loss_llama (line 516) | def _patch_loss_llama(self): method _patch_llama_flash_attention (line 546) | def _patch_llama_flash_attention(self): method _patch_llama_xformers_attention (line 565) | def _patch_llama_xformers_attention(self): method _patch_llama_derived_model (line 574) | def _patch_llama_derived_model(self): method _apply_llama_flash_attn_patches (line 590) | def _apply_llama_flash_attn_patches(self, model): method _apply_unsloth_patches (line 610) | def _apply_unsloth_patches(self, model): method _apply_lora_kernel_patch (line 627) | def _apply_lora_kernel_patch(self, model): method _apply_patch_deepspeed_zero3 (line 638) | def _apply_patch_deepspeed_zero3(self): method _apply_apertus_patches (line 652) | def _apply_apertus_patches(self): method _apply_trl_vllm_patches (line 661) | def _apply_trl_vllm_patches(self): method _apply_trl_trainer_utils_patches (line 672) | def _apply_trl_trainer_utils_patches(self): method _apply_scaling_softmax_patch (line 705) | def _apply_scaling_softmax_patch(self, model: PreTrainedModel): FILE: src/axolotl/loaders/processor.py function load_processor (line 17) | def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): FILE: src/axolotl/loaders/tokenizer.py function modify_tokenizer_files (line 30) | def modify_tokenizer_files( function load_tokenizer (line 130) | def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: FILE: src/axolotl/loaders/utils.py function get_module_class_from_name (line 17) | def get_module_class_from_name( function check_model_config (line 45) | def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): function load_model_config (line 151) | def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict: function ensure_dtype (line 210) | def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bflo... function get_linear_embedding_layers (line 231) | def get_linear_embedding_layers(model_type: str) -> list[str]: FILE: src/axolotl/logging_config.py class AxolotlOrWarnErrorFilter (line 15) | class AxolotlOrWarnErrorFilter(logging.Filter): method __init__ (line 22) | def __init__(self, **kwargs): method filter (line 40) | def filter(self, record: LogRecord) -> bool: class AxolotlLogger (line 51) | class AxolotlLogger(Logger): method __init__ (line 54) | def __init__(self, name: str, level: int = logging.NOTSET): class ColorfulFormatter (line 60) | class ColorfulFormatter(Formatter): method format (line 71) | def format(self, record): function configure_logging (line 142) | def configure_logging(): FILE: src/axolotl/models/mamba/__init__.py function check_mamba_ssm_installed (line 8) | def check_mamba_ssm_installed(): function fix_mamba_attn_for_loss (line 16) | def fix_mamba_attn_for_loss(): FILE: src/axolotl/models/mamba/configuration_mamba.py class MambaConfig (line 8) | class MambaConfig(PretrainedConfig): method __init__ (line 15) | def __init__( FILE: src/axolotl/models/mamba/modeling_mamba.py class MambaLMHeadModel (line 16) | class MambaLMHeadModel(nn.Module, GenerationMixin): method __init__ (line 17) | def __init__( method tie_weights (line 60) | def tie_weights(self): method allocate_inference_cache (line 63) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 68) | def forward( method save_pretrained (line 110) | def save_pretrained( method from_pretrained (line 121) | def from_pretrained(cls, pretrained_model_name, device=None, dtype=Non... FILE: src/axolotl/monkeypatch/accelerate/fsdp2.py function fsdp2_load_full_state_dict (line 20) | def fsdp2_load_full_state_dict( function get_state_dict (line 93) | def get_state_dict(self, model, unwrap=True): function patch_peft_param_wrapper_for_fsdp2 (line 189) | def patch_peft_param_wrapper_for_fsdp2(): function _process_lora_module_for_fsdp (line 228) | def _process_lora_module_for_fsdp(module, fsdp2_kwargs): function fsdp2_prepare_model (line 272) | def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn... function patch_tied_keys_for_meta_device (line 445) | def patch_tied_keys_for_meta_device(): function patch_initialize_missing_keys_for_fsdp (line 482) | def patch_initialize_missing_keys_for_fsdp(): function patch_accelerate_fsdp2 (line 522) | def patch_accelerate_fsdp2(): FILE: src/axolotl/monkeypatch/accelerate/parallelism_config.py function _validate_accelerator (line 11) | def _validate_accelerator(self, accelerator): function patched_is_fsdp2 (line 61) | def patched_is_fsdp2(self) -> bool: function patch_parallelism_config (line 73) | def patch_parallelism_config(): function patch_prepare_cp (line 80) | def patch_prepare_cp(): FILE: src/axolotl/monkeypatch/attention/__init__.py function patch_xformers_attn_over_fa2 (line 8) | def patch_xformers_attn_over_fa2(): function unpatch_xformers_attn_over_fa2 (line 16) | def unpatch_xformers_attn_over_fa2(): FILE: src/axolotl/monkeypatch/attention/flash_attn_4.py function _get_head_dims (line 10) | def _get_head_dims(model_config): function patch_flash_attn_4 (line 36) | def patch_flash_attn_4(model_config=None): FILE: src/axolotl/monkeypatch/attention/flex_attn.py function patch_flex_wrapper (line 15) | def patch_flex_wrapper(**flex_attn_compile_kwargs): FILE: src/axolotl/monkeypatch/attention/sage_attn.py function _is_sageattn_available (line 18) | def _is_sageattn_available(): function _check_sageattn_imported (line 33) | def _check_sageattn_imported(): function sage_attention_forward (line 42) | def sage_attention_forward( function patch_sageattn (line 193) | def patch_sageattn(): FILE: src/axolotl/monkeypatch/attention/xformers.py function xformers_attention_forward (line 19) | def xformers_attention_forward( FILE: src/axolotl/monkeypatch/btlm_attn_hijack_flash.py function replace_btlm_attn_with_flash_attn (line 18) | def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-ba... function flashattn_attn (line 31) | def flashattn_attn( FILE: src/axolotl/monkeypatch/data/batch_dataset_fetcher.py class _MapDatasetFetcher (line 12) | class _MapDatasetFetcher(_BaseDatasetFetcher): method fetch (line 18) | def fetch(self, possibly_batched_index): function patch_fetchers (line 45) | def patch_fetchers(): function patched_worker_loop (line 51) | def patched_worker_loop(*args, **kwargs): function apply_multipack_dataloader_patch (line 57) | def apply_multipack_dataloader_patch(): function remove_multipack_dataloader_patch (line 79) | def remove_multipack_dataloader_patch(): FILE: src/axolotl/monkeypatch/deepspeed_utils.py function patch_checkpoint_wrapper_setattr (line 9) | def patch_checkpoint_wrapper_setattr(): function apply_deepspeed_patches (line 60) | def apply_deepspeed_patches(): FILE: src/axolotl/monkeypatch/fsdp2_qlora.py function apply_init_sharded_param_patch (line 19) | def apply_init_sharded_param_patch(): function apply_init_unsharded_param_patch (line 96) | def apply_init_unsharded_param_patch(): function apply_linear8bitlt_save_patch (line 172) | def apply_linear8bitlt_save_patch(): function apply_init_dtype_attrs_patch (line 205) | def apply_init_dtype_attrs_patch(): FILE: src/axolotl/monkeypatch/gradient_checkpointing/__init__.py function uses_gc_layers (line 19) | def uses_gc_layers(decoder_layer): function uses_gc_layers (line 24) | def uses_gc_layers(_): function hf_grad_checkpoint_offload_wrapper (line 28) | def hf_grad_checkpoint_offload_wrapper(decoder_layer, *args, use_reentra... function hf_grad_checkpoint_disk_offload_wrapper (line 45) | def hf_grad_checkpoint_disk_offload_wrapper(decoder_layer, *args, use_re... FILE: src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py class CPU_Offloaded_Gradient_Checkpointer (line 38) | class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function): method forward (line 46) | def forward(ctx, forward_function, hidden_states, *args): method backward (line 57) | def backward(ctx, dY): FILE: src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py class DiskOffloadManager (line 43) | class DiskOffloadManager: method __init__ (line 49) | def __init__( method _save_worker (line 106) | def _save_worker(self): method _save_tensor_to_disk (line 129) | def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str): method _prefetch_worker (line 155) | def _prefetch_worker(self): method save_tensor (line 223) | def save_tensor(self, tensor: torch.Tensor): method wait_for_save (line 245) | def wait_for_save(self, file_path, timeout=None) -> None: method load_tensor (line 267) | def load_tensor(self, file_path, target_device="cuda"): method _safe_delete_file (line 308) | def _safe_delete_file(self, file_path): method trigger_prefetch (line 335) | def trigger_prefetch(self, n=None): method cleanup_tensor (line 356) | def cleanup_tensor(self, file_path: str): method cleanup (line 376) | def cleanup(self): class Disco (line 425) | class Disco(torch.autograd.Function): method get_instance (line 435) | def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4): method forward (line 447) | def forward( method backward (line 481) | def backward(ctx, *grad_outputs): FILE: src/axolotl/monkeypatch/llama_attn_hijack_flash.py function is_xformers_available (line 35) | def is_xformers_available() -> bool: function is_xformers_swiglu_available (line 39) | def is_xformers_swiglu_available() -> bool: function replace_llama_mlp_with_swiglu (line 54) | def replace_llama_mlp_with_swiglu(model): function patch_fa_llama_cross_entropy (line 68) | def patch_fa_llama_cross_entropy(): function patch_llama_rms_norm (line 96) | def patch_llama_rms_norm(): function replace_llama_attn_with_flash_attn (line 114) | def replace_llama_attn_with_flash_attn( function _prepare_decoder_attention_mask (line 136) | def _prepare_decoder_attention_mask( function flashattn_forward_with_s2attn (line 150) | def flashattn_forward_with_s2attn( FILE: src/axolotl/monkeypatch/llama_attn_hijack_xformers.py function hijack_llama_attention (line 23) | def hijack_llama_attention(): function xformers_forward (line 27) | def xformers_forward( FILE: src/axolotl/monkeypatch/lora_kernels.py function original_apply_qkv (line 94) | def original_apply_qkv( function original_apply_o (line 115) | def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> to... function get_attention_cls_from_config (line 131) | def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: function patch_self_attn_lora (line 201) | def patch_self_attn_lora(cfg: DictDefault): function find_self_attn_in_layer (line 265) | def find_self_attn_in_layer( function find_mlp_in_layer (line 277) | def find_mlp_in_layer( function get_layers (line 309) | def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]: function apply_lora_kernel_patches (line 334) | def apply_lora_kernel_patches( class FakeMLP (line 476) | class FakeMLP(nn.Module): method __init__ (line 485) | def __init__(self, gate_proj, up_proj, down_proj): FILE: src/axolotl/monkeypatch/loss/chunked.py class CEWithChunkedOutputLoss (line 12) | class CEWithChunkedOutputLoss(torch.nn.Module): method __init__ (line 19) | def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): method compute_cross_entropy (line 24) | def compute_cross_entropy( method forward (line 37) | def forward( function _build_chunked_ce_loss_fn (line 74) | def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: ... function get_causal_lm_loss (line 82) | def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -... function patch_chunked_ce_loss_fn (line 127) | def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: i... FILE: src/axolotl/monkeypatch/loss/eaft.py function eaft_loss (line 12) | def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20): FILE: src/axolotl/monkeypatch/mistral_attn_hijack_flash.py function patch_mistral_cross_entropy (line 12) | def patch_mistral_cross_entropy(): FILE: src/axolotl/monkeypatch/mixtral/__init__.py function patch_mixtral_moe_forward_zero3 (line 8) | def patch_mixtral_moe_forward_zero3() -> None: FILE: src/axolotl/monkeypatch/models/apertus/activation.py function patch_apertus_xielu_activation (line 6) | def patch_apertus_xielu_activation(): FILE: src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py class KimiLinearConfig (line 13) | class KimiLinearConfig(PretrainedConfig): method __init__ (line 17) | def __init__( method is_mla (line 119) | def is_mla(self): method is_moe (line 130) | def is_moe(self): method is_linear_attn (line 134) | def is_linear_attn(self) -> bool: method is_kda_layer (line 144) | def is_kda_layer(self, layer_idx: int): FILE: src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py function load_balancing_loss_func (line 57) | def load_balancing_loss_func( class KimiDynamicCache (line 83) | class KimiDynamicCache: method __init__ (line 91) | def __init__(self, config: KimiLinearConfig): method __len__ (line 123) | def __len__(self): method update (line 126) | def update( method reorder_cache (line 146) | def reorder_cache(self, beam_idx: torch.LongTensor): method get_seq_length (line 172) | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: method get_mask_sizes (line 184) | def get_mask_sizes( method has_previous_state (line 199) | def has_previous_state(self): class KimiRMSNorm (line 206) | class KimiRMSNorm(nn.Module): method __init__ (line 207) | def __init__(self, hidden_size, eps=1e-6): method forward (line 215) | def forward(self, hidden_states): class KimiBlockSparseMLP (line 226) | class KimiBlockSparseMLP(nn.Module): method __init__ (line 227) | def __init__( method forward (line 243) | def forward(self, hidden_states): class KimiMLP (line 251) | class KimiMLP(nn.Module): method __init__ (line 252) | def __init__( method forward (line 266) | def forward(self, x): function repeat_kv (line 271) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: function eager_attention_forward (line 285) | def eager_attention_forward( class KimiMLAAttention (line 315) | class KimiMLAAttention(nn.Module): method __init__ (line 320) | def __init__(self, config: KimiLinearConfig, layer_idx: int): method forward (line 372) | def forward( class KimiDeltaAttention (line 451) | class KimiDeltaAttention(nn.Module): method __init__ (line 452) | def __init__(self, config: KimiLinearConfig, layer_idx: int): method forward (line 510) | def forward( class KimiMoEGate (line 619) | class KimiMoEGate(nn.Module): method __init__ (line 625) | def __init__(self, config: KimiLinearConfig): method reset_parameters (line 635) | def reset_parameters(self) -> None: method forward (line 640) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class KimiSparseMoeBlock (line 797) | class KimiSparseMoeBlock(nn.Module): method __init__ (line 803) | def __init__(self, config: KimiLinearConfig): method route_tokens_to_experts (line 830) | def route_tokens_to_experts( method forward (line 887) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: method _training_forward (line 922) | def _training_forward( method _inference_forward (line 994) | def _inference_forward( class KimiDecoderLayer (line 1037) | class KimiDecoderLayer(nn.Module): method __init__ (line 1038) | def __init__(self, config: KimiLinearConfig, layer_idx: int): method forward (line 1063) | def forward( class KimiPreTrainedModel (line 1127) | class KimiPreTrainedModel(PreTrainedModel): method _init_weights (line 1141) | def _init_weights(self, module): class KimiLinearModel (line 1153) | class KimiLinearModel(KimiPreTrainedModel): method __init__ (line 1154) | def __init__(self, config: KimiLinearConfig): method _update_linear_attn_mask (line 1185) | def _update_linear_attn_mask(self, attention_mask, cache_position): method forward (line 1199) | def forward( class KimiLinearForCausalLM (line 1272) | class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin): method __init__ (line 1275) | def __init__(self, config): method forward (line 1285) | def forward( FILE: src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py function get_patch_file_path (line 13) | def get_patch_file_path(package_dot_path: str, filename: str) -> Path: function _load_local_module (line 23) | def _load_local_module(module_name: str, filename: str): function _patch_get_class_in_module (line 38) | def _patch_get_class_in_module(): function patch_kimi (line 73) | def patch_kimi(): FILE: src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py class TikTokenTokenizer (line 33) | class TikTokenTokenizer(PreTrainedTokenizer): method __init__ (line 78) | def __init__( method encode (line 167) | def encode( method decode (line 232) | def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: method _split_whitespaces_or_nonwhitespaces (line 253) | def _split_whitespaces_or_nonwhitespaces( method pre_tokenizer_process (line 278) | def pre_tokenizer_process(self, text: str) -> List[str]: method vocab_size (line 288) | def vocab_size(self) -> int: method get_vocab (line 291) | def get_vocab(self) -> Dict[str, int]: method _tokenize (line 294) | def _tokenize(self, text: str, **kwargs) -> List[str]: method _convert_token_to_id (line 297) | def _convert_token_to_id(self, token: str) -> int: method _convert_id_to_token (line 300) | def _convert_id_to_token(self, index: int) -> str: method clean_up_tokenization (line 304) | def clean_up_tokenization(out_string: str) -> str: method convert_tokens_to_string (line 307) | def convert_tokens_to_string(self, tokens: List[str]) -> str: method save_vocabulary (line 314) | def save_vocabulary( method apply_chat_template (line 334) | def apply_chat_template( function deep_sort_dict (line 352) | def deep_sort_dict(obj: Any) -> Any: FILE: src/axolotl/monkeypatch/models/llama4/modeling.py class Llama4TextExperts (line 13) | class Llama4TextExperts(nn.Module): method __init__ (line 18) | def __init__(self, config: Llama4Config): method forward (line 50) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: function patch_llama4_linearized_modeling (line 90) | def patch_llama4_linearized_modeling(): FILE: src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py function apply_mistral_tokenizer_image_patch (line 14) | def apply_mistral_tokenizer_image_patch(): FILE: src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py function apply_patch_is_packed_sequence (line 6) | def apply_patch_is_packed_sequence(): FILE: src/axolotl/monkeypatch/models/qwen3_5/modeling.py function get_cu_seqlens (line 24) | def get_cu_seqlens(position_ids): function _inject_fla_kernels (line 49) | def _inject_fla_kernels(module) -> None: function _patched_decoder_forward (line 68) | def _patched_decoder_forward( function _make_qwen3_5_gated_delta_forward (line 113) | def _make_qwen3_5_gated_delta_forward(apply_mask_fn): function _apply_packing_patches (line 240) | def _apply_packing_patches(model_type: str, cls_prefix: str, forward_fac... function patch_qwen3_5_modeling_packing (line 260) | def patch_qwen3_5_modeling_packing(): function patch_qwen3_5_moe_modeling_packing (line 264) | def patch_qwen3_5_moe_modeling_packing(): function patch_qwen3_5_vlm_flash_attention (line 270) | def patch_qwen3_5_vlm_flash_attention(): FILE: src/axolotl/monkeypatch/models/qwen3_next/modeling.py function get_cu_seqlens (line 18) | def get_cu_seqlens(position_ids): function patch_qwen3_next_decoder_layer (line 39) | def patch_qwen3_next_decoder_layer(): function patch_qwen3_next_gateddelta_layer (line 110) | def patch_qwen3_next_gateddelta_layer(): function patch_qwen3_next_imports (line 283) | def patch_qwen3_next_imports(): function patch_qwen3_next_modeling_packing (line 336) | def patch_qwen3_next_modeling_packing(): FILE: src/axolotl/monkeypatch/models/voxtral/modeling.py function patch_voxtral_conditional_generation_forward (line 10) | def patch_voxtral_conditional_generation_forward(): FILE: src/axolotl/monkeypatch/moe_quant.py class Bnb8bitParametrization (line 23) | class Bnb8bitParametrization(torch.nn.Module): method __init__ (line 26) | def __init__(self, row_stats: torch.Tensor): method forward (line 31) | def forward(self, quantized_param: torch.Tensor) -> torch.Tensor: function _enable_parametrization_cache (line 40) | def _enable_parametrization_cache(module, inputs): function _disable_parametrization_cache (line 44) | def _disable_parametrization_cache(module, inputs, output): function replace_parameter_8bit (line 50) | def replace_parameter_8bit(module, param_name): function patch_moe_quantization_on_load (line 71) | def patch_moe_quantization_on_load(cfg): function get_moe_quantized_count (line 146) | def get_moe_quantized_count(): function patch_peft_target_parameters_matching (line 151) | def patch_peft_target_parameters_matching(): FILE: src/axolotl/monkeypatch/multipack.py function patch_for_multipack (line 66) | def patch_for_multipack(model_type, model_name=None, has_remote_code=Fal... function patch_remote (line 80) | def patch_remote(model_name): FILE: src/axolotl/monkeypatch/peft/utils.py function get_peft_prep_code (line 32) | def get_peft_prep_code() -> str: function check_peft_prep_code_is_patchable (line 37) | def check_peft_prep_code_is_patchable() -> bool: function patch_peft_prep_code (line 43) | def patch_peft_prep_code(): FILE: src/axolotl/monkeypatch/relora.py function magnitude_pruning_ (line 33) | def magnitude_pruning_(tensor, prune_ratio): function reset_optimizer (line 43) | def reset_optimizer( class ReLoRACallback (line 81) | class ReLoRACallback(TrainerCallback): method __init__ (line 84) | def __init__(self, cfg: DictDefault): method on_train_begin (line 101) | def on_train_begin( method on_step_begin (line 120) | def on_step_begin( method on_save (line 182) | def on_save( method on_log (line 220) | def on_log( method on_train_end (line 231) | def on_train_end( function sharded_paths (line 255) | def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]: function lora_delta_weight (line 270) | def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torc... function find_lora_modules (line 289) | def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lo... function update_weights (line 305) | def update_weights( function merge_and_save (line 330) | def merge_and_save( function load_weight_checkpoint (line 420) | def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str): FILE: src/axolotl/monkeypatch/ring_attn/adapters/batch.py function create_flash_attn_forward_varlen_llama3 (line 42) | def create_flash_attn_forward_varlen_llama3( function substitute_hf_flash_attn (line 156) | def substitute_hf_flash_attn( FILE: src/axolotl/monkeypatch/ring_attn/patch.py function get_ring_attn_group (line 37) | def get_ring_attn_group() -> dist.ProcessGroup: function set_ring_attn_group (line 44) | def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): function create_ring_flash_attention_forward (line 50) | def create_ring_flash_attention_forward( function register_ring_attn_from_device_mesh (line 135) | def register_ring_attn_from_device_mesh( function update_ring_attn_params (line 214) | def update_ring_attn_params(position_ids: torch.Tensor | None): FILE: src/axolotl/monkeypatch/scaled_softmax_attn.py function patch_scaled_softmax_attention (line 29) | def patch_scaled_softmax_attention( function ssmax_flex_attention_forward (line 53) | def ssmax_flex_attention_forward( function unpatch_scaled_softmax_attention (line 133) | def unpatch_scaled_softmax_attention(): FILE: src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py function replace_stablelm_attn_with_flash_attn (line 42) | def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stable... function rotate_half (line 57) | def rotate_half(x: torch.Tensor): function apply_rotary_pos_emb (line 64) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): function repeat_kv (line 76) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: function flashattn_attn (line 90) | def flashattn_attn( function decoder_layer_forward (line 200) | def decoder_layer_forward( function stablelm_model_forward (line 247) | def stablelm_model_forward( FILE: src/axolotl/monkeypatch/tiled_mlp/base.py class DeepSpeedTiledMLPMoE (line 11) | class DeepSpeedTiledMLPMoE(torch.autograd.Function): method forward (line 13) | def forward( method backward (line 47) | def backward(ctx, *grads) -> torch.Tensor: class TiledMLP (line 99) | class TiledMLP(torch.autograd.Function): method forward (line 105) | def forward( method backward (line 138) | def backward(ctx, *grads) -> torch.Tensor: class GradientAccumulator (line 191) | class GradientAccumulator: method __init__ (line 197) | def __init__( method install_hooks (line 223) | def install_hooks(self, is_last_shard: bool): method cleanup (line 251) | def cleanup(self): FILE: src/axolotl/monkeypatch/tiled_mlp/patch.py function patch_tiled_mlp (line 15) | def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=No... FILE: src/axolotl/monkeypatch/trainer/lr.py function _get_learning_rate (line 13) | def _get_learning_rate(self): function patch_trainer_get_lr (line 39) | def patch_trainer_get_lr(): FILE: src/axolotl/monkeypatch/trainer/trl.py function prepare_fsdp (line 4) | def prepare_fsdp(model, accelerator): function patch_trl_prepare_fsdp2 (line 10) | def patch_trl_prepare_fsdp2(): FILE: src/axolotl/monkeypatch/trainer/trl_vllm.py function _batch_update_named_params (line 20) | def _batch_update_named_params( function _update_model_params (line 65) | def _update_model_params(self, model: nn.Module, chunk_size: int | None ... function _patched_extract_logprobs (line 71) | def _patched_extract_logprobs(all_outputs): function _patched_split_tensor_dict (line 97) | def _patched_split_tensor_dict(tensor_dict, num_chunks): function _patched_shuffle_sequence_dict (line 121) | def _patched_shuffle_sequence_dict(seq_dict): function _patch_sync_weights_batched (line 146) | def _patch_sync_weights_batched(original_init): function _make_batched_sync_weights (line 157) | def _make_batched_sync_weights(original_sync_weights): function patch_trl_vllm (line 214) | def patch_trl_vllm(): FILE: src/axolotl/monkeypatch/trainer/utils.py function _entropy_online_kernel (line 22) | def _entropy_online_kernel( function _entropy_online_kernel_strided (line 62) | def _entropy_online_kernel_strided( function entropy_from_logits (line 108) | def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> ... function selective_log_softmax_original (line 174) | def selective_log_softmax_original(logits, index) -> torch.Tensor: function _selective_logsoftmax_fwd_kernel (line 199) | def _selective_logsoftmax_fwd_kernel( function _selective_logsoftmax_bwd_kernel (line 255) | def _selective_logsoftmax_bwd_kernel( class _SelectiveLogSoftmaxTriton (line 320) | class _SelectiveLogSoftmaxTriton(torch.autograd.Function): method forward (line 322) | def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_... method backward (line 352) | def backward(ctx, grad_output): function selective_log_softmax (line 390) | def selective_log_softmax(logits, index) -> torch.Tensor: FILE: src/axolotl/monkeypatch/trainer_accelerator_args.py function get_create_accelerate_code (line 30) | def get_create_accelerate_code() -> str: function check_create_accelerate_code_is_patchable (line 35) | def check_create_accelerate_code_is_patchable() -> bool: function patch_create_accelerate_code_for_fp8 (line 41) | def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: ... FILE: src/axolotl/monkeypatch/trainer_fsdp_optim.py function get_training_loop_code (line 25) | def get_training_loop_code() -> str: function check_training_loop_is_patchable (line 30) | def check_training_loop_is_patchable() -> bool: function patch_training_loop_for_fsdp (line 36) | def patch_training_loop_for_fsdp(): FILE: src/axolotl/monkeypatch/transformers/trainer_context_parallel.py function patch_prepare_context_parallel_inputs (line 19) | def patch_prepare_context_parallel_inputs() -> None: FILE: src/axolotl/monkeypatch/transformers/trainer_loss_calc.py function check_evaluation_loop_is_patchable (line 39) | def check_evaluation_loop_is_patchable() -> bool: function patch_evaluation_loop (line 44) | def patch_evaluation_loop(): function check_maybe_log_save_evaluate_is_patchable (line 95) | def check_maybe_log_save_evaluate_is_patchable() -> bool: function patch_maybe_log_save_evaluate (line 100) | def patch_maybe_log_save_evaluate(): FILE: src/axolotl/monkeypatch/transformers_fa_utils.py function fixed_fa_peft_integration_check (line 15) | def fixed_fa_peft_integration_check( function patch_fa_peft_integration (line 63) | def patch_fa_peft_integration(): FILE: src/axolotl/monkeypatch/unsloth_.py function original_apply_qkv (line 35) | def original_apply_qkv(self, hidden_states): function original_apply_o (line 42) | def original_apply_o(self, hidden_states): function get_self_attn_code (line 47) | def get_self_attn_code() -> str: function check_self_attn_is_patchable (line 52) | def check_self_attn_is_patchable() -> bool: function integrate_cross_entropy_loss_patch (line 58) | def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: function patch_self_attn_lora (line 91) | def patch_self_attn_lora(): function integrate_rope_embeddings (line 130) | def integrate_rope_embeddings(): function integrate_lora_mlp_patch (line 148) | def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): function integrate_lora_patch (line 183) | def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): function patch_unsloth_layernorm (line 228) | def patch_unsloth_layernorm(): FILE: src/axolotl/monkeypatch/utils.py function get_max_seqlen_in_batch (line 13) | def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: function get_unpad_data (line 26) | def get_unpad_data(attention_mask: torch.Tensor): function get_cu_seqlens (line 43) | def get_cu_seqlens(attn_mask): function get_cu_seqlens_from_pos_ids (line 94) | def get_cu_seqlens_from_pos_ids( function set_module_name (line 155) | def set_module_name(model, name, value): function detab_code (line 168) | def detab_code(code: str) -> Tuple[str, str]: FILE: src/axolotl/monkeypatch/xformers_/__init__.py class FusedMLP (line 12) | class FusedMLP(torch.nn.Module): method __init__ (line 17) | def __init__( method _post_training (line 38) | def _post_training(self, model, name): method forward (line 51) | def forward(self, x: torch.Tensor) -> torch.Tensor: FILE: src/axolotl/processing_strategies.py class ProcessingStrategy (line 21) | class ProcessingStrategy: method __init__ (line 24) | def __init__( method __call__ (line 47) | def __call__(self, examples: list[dict]) -> list[dict]: method _mask_non_assistant (line 223) | def _mask_non_assistant(self, labels: Tensor) -> Tensor: method process_labels (line 230) | def process_labels(self, input_ids: Tensor) -> Tensor: class Qwen2VLProcessingStrategy (line 244) | class Qwen2VLProcessingStrategy(ProcessingStrategy): method __init__ (line 247) | def __init__( class Qwen3_5ProcessingStrategy (line 261) | class Qwen3_5ProcessingStrategy(ProcessingStrategy): method __init__ (line 264) | def __init__( method process_labels (line 281) | def process_labels(self, input_ids): class Gemma3ProcessingStrategy (line 287) | class Gemma3ProcessingStrategy(ProcessingStrategy): method __init__ (line 290) | def __init__( method process_labels (line 303) | def process_labels(self, input_ids): class Gemma3nProcessingStrategy (line 314) | class Gemma3nProcessingStrategy(ProcessingStrategy): method _mask_non_assistant (line 317) | def _mask_non_assistant(self, labels: Tensor) -> Tensor: method process_labels (line 389) | def process_labels(self, input_ids): class VoxtralProcessingStrategy (line 407) | class VoxtralProcessingStrategy(ProcessingStrategy): method __init__ (line 410) | def __init__( method process_labels (line 425) | def process_labels(self, input_ids): class SmolVLM2ProcessingStrategy (line 435) | class SmolVLM2ProcessingStrategy(ProcessingStrategy): method __init__ (line 438) | def __init__( class Mistral3ProcessingStrategy (line 453) | class Mistral3ProcessingStrategy(ProcessingStrategy): method __init__ (line 456) | def __init__( method process_labels (line 472) | def process_labels(self, input_ids): class InternVLProcessingStrategy (line 483) | class InternVLProcessingStrategy(ProcessingStrategy): method __init__ (line 486) | def __init__( method process_labels (line 500) | def process_labels(self, input_ids): class Glm4vProcessingStrategy (line 514) | class Glm4vProcessingStrategy(ProcessingStrategy): method __init__ (line 517) | def __init__( method process_labels (line 550) | def process_labels(self, input_ids): function get_processing_strategy (line 566) | def get_processing_strategy( FILE: src/axolotl/prompt_strategies/__init__.py function load (line 12) | def load(strategy, tokenizer, cfg, ds_cfg, processor=None): FILE: src/axolotl/prompt_strategies/alpaca_chat.py function load (line 12) | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): class AlpacaConcisePrompter (line 25) | class AlpacaConcisePrompter(AlpacaPrompter): class AlpacaChatPrompter (line 34) | class AlpacaChatPrompter(AlpacaPrompter): method __init__ (line 42) | def __init__(self): class NoSystemPrompter (line 47) | class NoSystemPrompter(AlpacaPrompter): method __init__ (line 57) | def __init__(self): class AlpacaQAPromptTokenizingStrategy (line 61) | class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrate... method parse_instruction_fields (line 66) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class CamelAIPromptTokenizingStrategy (line 74) | class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): method parse_instruction_fields (line 79) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: function load_concise (line 87) | def load_concise(tokenizer, cfg): function load_qa (line 96) | def load_qa(tokenizer, cfg): function load_camel_ai (line 105) | def load_camel_ai(tokenizer, cfg): function load_no_prompt (line 114) | def load_no_prompt(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/alpaca_instruct.py function load (line 7) | def load(tokenizer, cfg): function load_no_prompt (line 16) | def load_no_prompt(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/alpaca_w_system.py class InstructionWSystemPromptTokenizingStrategy (line 11) | class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): method parse_instruction_fields (line 16) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: method tokenize_prompt (line 24) | def tokenize_prompt(self, prompt): class SystemDataPrompter (line 55) | class SystemDataPrompter(AlpacaPrompter): method build_prompt_w_system (line 62) | def build_prompt_w_system( class OpenOrcaSystemDataPrompter (line 89) | class OpenOrcaSystemDataPrompter(SystemDataPrompter): method match_prompt_style (line 94) | def match_prompt_style(self): class OpenOrcaPromptTokenizingStrategy (line 111) | class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizin... method parse_instruction_fields (line 116) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: function load (line 125) | def load(tokenizer, cfg): function load_instruct (line 129) | def load_instruct(tokenizer, cfg): function load_chat (line 138) | def load_chat(tokenizer, cfg): function load_open_orca (line 147) | def load_open_orca(tokenizer, cfg): function load_open_orca_chatml (line 156) | def load_open_orca_chatml(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/base.py function load (line 12) | def load(strategy, cfg, module_base=None, **kwargs): FILE: src/axolotl/prompt_strategies/bradley_terry/__init__.py function load (line 12) | def load(strategy, tokenizer, cfg, ds_cfg): FILE: src/axolotl/prompt_strategies/bradley_terry/chat_template.py class BTChatTemplateStrategy (line 19) | class BTChatTemplateStrategy(ChatTemplateStrategy): method supports_batched (line 25) | def supports_batched(self) -> bool: method _tokenize_single_prompt (line 28) | def _tokenize_single_prompt(self, prompt): function load (line 83) | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): FILE: src/axolotl/prompt_strategies/bradley_terry/llama3.py function icr (line 6) | def icr( FILE: src/axolotl/prompt_strategies/chat_template.py class ChatTemplatePrompter (line 28) | class ChatTemplatePrompter(Prompter): method __init__ (line 31) | def __init__( method chat_template_msg_variables (line 89) | def chat_template_msg_variables(self) -> Set[str]: method build_prompt (line 92) | def build_prompt( method get_offsets_for_train_detail (line 158) | def get_offsets_for_train_detail( method adjust_train_details (line 202) | def adjust_train_details( method get_chat_template_msg_variables (line 260) | def get_chat_template_msg_variables( class ChatTemplateStrategy (line 267) | class ChatTemplateStrategy(PromptTokenizingStrategy): method __init__ (line 272) | def __init__( method _validate_eot_and_eos_tokens (line 318) | def _validate_eot_and_eos_tokens(self): method supports_batched (line 384) | def supports_batched(self) -> bool: method is_prompt_batched (line 388) | def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: method tokenize_prompt (line 396) | def tokenize_prompt(self, prompt: dict[str, Any]): method _tokenize_single_prompt (line 423) | def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]: method find_first_eos_token (line 587) | def find_first_eos_token(self, input_ids, start_idx): method find_first_eot_token (line 594) | def find_first_eot_token(self, input_ids, start_idx): method find_turn (line 613) | def find_turn( method get_conversation_thread (line 700) | def get_conversation_thread(self, prompt): method transform_message (line 733) | def transform_message(self, message: dict) -> dict: method _get_images (line 828) | def _get_images(self, prompt): method _get_tools (line 831) | def _get_tools(self, prompt) -> list[dict] | None: method _get_messages (line 862) | def _get_messages(self, prompt): class MistralStrategy (line 876) | class MistralStrategy(ChatTemplateStrategy): method __init__ (line 881) | def __init__( method find_first_eot_token (line 931) | def find_first_eot_token(self, input_ids, start_idx): class MistralPrompter (line 937) | class MistralPrompter(ChatTemplatePrompter): method __init__ (line 942) | def __init__(self, *args, **kwargs): class StrategyLoader (line 948) | class StrategyLoader: method _get_strategy_cls (line 953) | def _get_strategy_cls(self, cfg): method _get_prompter_cls (line 959) | def _get_prompter_cls(self, cfg): method _get_strategy_params (line 965) | def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): method __call__ (line 976) | def __call__( FILE: src/axolotl/prompt_strategies/completion.py class CompletionPromptTokenizingStrategy (line 11) | class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStra... method __init__ (line 18) | def __init__(self, *args, max_length=None, **kwargs): method supports_batched (line 24) | def supports_batched(self): method field (line 28) | def field(self) -> str: method field (line 32) | def field(self, new_field: str): method parse_instruction_fields (line 35) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: method tokenize_prompt (line 42) | def tokenize_prompt(self, prompt): method _build_full_prompt (line 62) | def _build_full_prompt(self, instruction, input, response): class CompletionPrompter (line 66) | class CompletionPrompter: method build_prompt (line 71) | def build_prompt( function load (line 80) | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): FILE: src/axolotl/prompt_strategies/context_qa.py function load_404 (line 10) | def load_404(tokenizer, cfg): function load (line 19) | def load(tokenizer, cfg): function load_v2 (line 28) | def load_v2(tokenizer, cfg): class AlpacaContextPrompter (line 37) | class AlpacaContextPrompter(AlpacaPrompter): class AlpacaContextPromptTokenizingStrategy (line 50) | class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingS... method parse_instruction_fields (line 55) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class ContextQaV2PromptTokenizingStrategy (line 63) | class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStr... method parse_instruction_fields (line 68) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class ContextV2Prompter (line 80) | class ContextV2Prompter(AlpacaPrompter): method match_prompt_style (line 88) | def match_prompt_style(self): class AlpacaMissingInfoContextPromptTokenizingStrategy (line 94) | class AlpacaMissingInfoContextPromptTokenizingStrategy( method parse_instruction_fields (line 102) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: FILE: src/axolotl/prompt_strategies/creative_acr.py class CreativeAnsweringPromptTokenizingStrategy (line 10) | class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokeniz... method parse_instruction_fields (line 15) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class CreativeCritiquePromptTokenizingStrategy (line 27) | class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizi... method parse_instruction_fields (line 63) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class CreativeRevisePromptTokenizingStrategy (line 84) | class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizing... method parse_instruction_fields (line 103) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class CreativePrompterBase (line 126) | class CreativePrompterBase: method build_prompt (line 134) | def build_prompt( class CreativeAnswerPrompter (line 149) | class CreativeAnswerPrompter(CreativePrompterBase): class CreativeCritiquePrompter (line 157) | class CreativeCritiquePrompter(CreativePrompterBase): class CreativeRevisePrompter (line 165) | class CreativeRevisePrompter(CreativePrompterBase): function load_answer (line 173) | def load_answer(tokenizer, cfg): function load_critique (line 182) | def load_critique(tokenizer, cfg): function load_revise (line 191) | def load_revise(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/dpo/chat_template.py function default (line 9) | def default(cfg, dataset_idx=0, **kwargs): function argilla_chat (line 125) | def argilla_chat(cfg, dataset_idx=0, **kwargs): FILE: src/axolotl/prompt_strategies/dpo/chatml.py function default (line 6) | def default( function argilla_chat (line 46) | def argilla_chat( function icr (line 65) | def icr( function intel (line 91) | def intel(cfg, **kwargs): function prompt_pairs (line 113) | def prompt_pairs(cfg, **kwargs): function ultra (line 131) | def ultra(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/dpo/llama3.py function default (line 6) | def default( function argilla_chat (line 46) | def argilla_chat( function icr (line 65) | def icr( function intel (line 91) | def intel(cfg, **kwargs): function prompt_pairs (line 113) | def prompt_pairs(cfg, **kwargs): function ultra (line 131) | def ultra(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/dpo/passthrough.py function default (line 6) | def default(cfg, dataset_idx=0, **kwargs): FILE: src/axolotl/prompt_strategies/dpo/user_defined.py function default (line 6) | def default(cfg, dataset_idx=0, **kwargs): FILE: src/axolotl/prompt_strategies/dpo/zephyr.py function nectar (line 6) | def nectar(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/input_output.py class RawInputOutputStrategy (line 9) | class RawInputOutputStrategy(PromptTokenizingStrategy): method __init__ (line 12) | def __init__(self, *args, eos_token=None, **kwargs): method tokenize_prompt (line 18) | def tokenize_prompt(self, prompt): class RawInputOutputPrompter (line 40) | class RawInputOutputPrompter(Prompter): method build_prompt (line 43) | def build_prompt(self, source) -> Generator[Tuple[bool, str], None, No... function load (line 48) | def load(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/jinja_template_analyzer.py class JinjaTemplateAnalysis (line 9) | class JinjaTemplateAnalysis(TypedDict): class GenerationTagIgnore (line 31) | class GenerationTagIgnore(Extension): method parse (line 38) | def parse(self, parser): class JinjaTemplateAnalyzer (line 43) | class JinjaTemplateAnalyzer: method __init__ (line 72) | def __init__(self, template: str): method _visit_node (line 83) | def _visit_node(self, node) -> None: method _get_target_name (line 132) | def _get_target_name(self, node) -> Optional[str]: method _get_target_names (line 146) | def _get_target_names(self, node) -> list[str]: method _get_base_name (line 167) | def _get_base_name(self, node) -> Optional[str]: method get_template_variables (line 180) | def get_template_variables(self) -> Dict[str, Set[str]]: method analyze_template (line 212) | def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]: method get_downstream_properties (line 269) | def get_downstream_properties(self, start_var: str) -> Dict[str, Set[s... method get_message_vars (line 318) | def get_message_vars(self, field_messages: str = "messages") -> Set[str]: FILE: src/axolotl/prompt_strategies/kto/chatml.py function argilla (line 6) | def argilla( function argilla_chat (line 26) | def argilla_chat( function intel (line 44) | def intel(cfg, **kwargs): function prompt_pairs (line 66) | def prompt_pairs(cfg, **kwargs): function ultra (line 83) | def ultra(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/kto/llama3.py function argilla (line 6) | def argilla( function argilla_chat (line 26) | def argilla_chat( function intel (line 44) | def intel(cfg, **kwargs): function prompt_pairs (line 66) | def prompt_pairs(cfg, **kwargs): function ultra (line 83) | def ultra(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/kto/user_defined.py function default (line 6) | def default(cfg, dataset_idx=0, **kwargs): FILE: src/axolotl/prompt_strategies/llama2_chat.py class Llama2ChatConversation (line 38) | class Llama2ChatConversation: method get_prompt (line 58) | def get_prompt(self) -> str: method append_message (line 73) | def append_message(self, role: str, message: str): class LLama2ChatTokenizingStrategy (line 78) | class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): method __init__ (line 84) | def __init__(self, *args, **kwargs): method tokenize_prompt (line 91) | def tokenize_prompt(self, prompt): class Llama2ChatPrompter (line 156) | class Llama2ChatPrompter: method build_prompt (line 169) | def build_prompt(self, source) -> Generator[Llama2ChatConversation, No... function load (line 202) | def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy: FILE: src/axolotl/prompt_strategies/messages/__init__.py function load (line 11) | def load(tokenizer, cfg, ds_cfg, processor=None): FILE: src/axolotl/prompt_strategies/messages/chat.py class ChatMessageDatasetWrappingStrategy (line 12) | class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): method __init__ (line 17) | def __init__( method wrap_dataset (line 33) | def wrap_dataset( function load (line 51) | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): FILE: src/axolotl/prompt_strategies/metharme.py class MetharmePromptTokenizingStrategy (line 14) | class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrate... method parse_instruction_fields (line 19) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: method _tokenize (line 22) | def _tokenize( class MetharmePrompter (line 56) | class MetharmePrompter(AlpacaPrompter): method __init__ (line 67) | def __init__(self, *args, **kwargs): function load (line 71) | def load(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/orcamini.py class OrcaMiniPrompter (line 19) | class OrcaMiniPrompter(AlpacaPrompter): method match_prompt_style (line 22) | def match_prompt_style(self): method build_prompt_w_system (line 27) | def build_prompt_w_system( function load (line 41) | def load(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/orpo/chat_template.py class Message (line 12) | class Message(BaseModel): class MessageList (line 20) | class MessageList(BaseModel): function load (line 26) | def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwar... class ORPODatasetParsingStrategy (line 44) | class ORPODatasetParsingStrategy: method get_chosen_conversation_thread (line 47) | def get_chosen_conversation_thread(self, prompt) -> MessageList: method get_rejected_conversation_thread (line 63) | def get_rejected_conversation_thread(self, prompt) -> MessageList: method get_prompt (line 79) | def get_prompt(self, prompt) -> MessageList: method get_chosen (line 112) | def get_chosen(self, prompt) -> MessageList: method get_rejected (line 121) | def get_rejected(self, prompt) -> MessageList: class ORPOTokenizingStrategy (line 131) | class ORPOTokenizingStrategy(PromptTokenizingStrategy): method __init__ (line 141) | def __init__( method tokenize_prompt (line 150) | def tokenize_prompt(self, prompt): class ORPOPrompter (line 205) | class ORPOPrompter(Prompter): method __init__ (line 208) | def __init__(self, chat_template, tokenizer): method build_prompt (line 212) | def build_prompt( function argilla (line 251) | def argilla(cfg, **kwargs): FILE: src/axolotl/prompt_strategies/pretrain.py class PretrainTokenizer (line 10) | class PretrainTokenizer: method build_prompt (line 13) | def build_prompt(self, prompt) -> Generator[str, None, None]: class PretrainTokenizationStrategy (line 17) | class PretrainTokenizationStrategy(PromptTokenizingStrategy): method supports_batched (line 21) | def supports_batched(self): method __init__ (line 24) | def __init__(self, *args, max_length=None, text_column="text", **kwargs): method _tokenize (line 30) | def _tokenize( method tokenize_prompt (line 48) | def tokenize_prompt(self, prompt): function load (line 52) | def load(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/pygmalion.py class PygmalionPromptTokenizingStrategy (line 19) | class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): method __init__ (line 26) | def __init__(self, prompter, tokenizer, *args, **kwargs): method tokenize_prompt (line 31) | def tokenize_prompt(self, prompt): class PygmalionPrompter (line 82) | class PygmalionPrompter: method __init__ (line 87) | def __init__(self, *args, **kwargs): method build_prompt (line 90) | def build_prompt( function load (line 100) | def load(tokenizer, cfg): FILE: src/axolotl/prompt_strategies/stepwise_supervised.py class StepwiseSupervisedPromptTokenizingStrategy (line 15) | class StepwiseSupervisedPromptTokenizingStrategy: method __init__ (line 24) | def __init__( method tokenize_prompt (line 38) | def tokenize_prompt( method supports_batched (line 101) | def supports_batched(self): function load (line 105) | def load( FILE: src/axolotl/prompt_strategies/user_defined.py class UserDefinedDatasetConfig (line 16) | class UserDefinedDatasetConfig: method __getitem__ (line 30) | def __getitem__(self, item): class UserDefinedPromptTokenizationStrategy (line 34) | class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptToke... function load (line 40) | def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = No... FILE: src/axolotl/prompt_tokenizers.py class InvalidDataException (line 21) | class InvalidDataException(Exception): class DatasetWrappingStrategy (line 27) | class DatasetWrappingStrategy(abc.ABC): method wrap_dataset (line 33) | def wrap_dataset( class PromptTokenizingStrategy (line 43) | class PromptTokenizingStrategy(abc.ABC): method __init__ (line 50) | def __init__( method tokenize_prompt (line 66) | def tokenize_prompt(self, prompt): method supports_batched (line 70) | def supports_batched(self): method _tokenize (line 73) | def _tokenize( class InstructionPromptTokenizingStrategy (line 108) | class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): method parse_instruction_fields (line 113) | def parse_instruction_fields( method tokenize_prompt (line 118) | def tokenize_prompt(self, prompt): method _build_full_prompt (line 146) | def _build_full_prompt( class AlpacaPromptTokenizingStrategy (line 163) | class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): method parse_instruction_fields (line 168) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class AlpacaMultipleChoicePromptTokenizingStrategy (line 176) | class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptToke... method parse_instruction_fields (line 181) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class JeopardyPromptTokenizingStrategy (line 189) | class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrate... method parse_instruction_fields (line 194) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class OpenAssistantPromptTokenizingStrategy (line 202) | class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingS... method parse_instruction_fields (line 207) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class SummarizeTLDRPromptTokenizingStrategy (line 215) | class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingS... method parse_instruction_fields (line 220) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class GPTeacherPromptTokenizingStrategy (line 228) | class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrat... method parse_instruction_fields (line 233) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class NomicGPT4AllPromptTokenizingStrategy (line 241) | class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingSt... method parse_instruction_fields (line 246) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: class ReflectionPromptTokenizingStrategy (line 254) | class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): method parse_instruction_fields (line 259) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str... method tokenize_prompt (line 262) | def tokenize_prompt(self, prompt): method _build_full_prompt (line 292) | def _build_full_prompt(self, instruction, input, output, reflection, c... method _tokenize (line 305) | def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): class AlpacaReflectionPTStrategy (line 325) | class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): method parse_instruction_fields (line 330) | def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str... function tokenize_prompt_default (line 340) | def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: function parse_tokenized_to_result (line 354) | def parse_tokenized_to_result( FILE: src/axolotl/prompters.py class PromptStyle (line 15) | class PromptStyle(Enum): class Prompter (line 26) | class Prompter: class AlpacaPrompter (line 32) | class AlpacaPrompter(Prompter): method __init__ (line 44) | def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.... method match_prompt_style (line 48) | def match_prompt_style(self): method _build_result (line 72) | def _build_result(self, instruction, input_text, output): method build_prompt (line 92) | def build_prompt( method __repr__ (line 100) | def __repr__(self) -> str: class UnpromptedPrompter (line 106) | class UnpromptedPrompter(AlpacaPrompter): class JeopardyPrompter (line 115) | class JeopardyPrompter(AlpacaPrompter): class MultipleChoiceExplainPrompter (line 123) | class MultipleChoiceExplainPrompter(AlpacaPrompter): class MultipleChoiceConcisePrompter (line 136) | class MultipleChoiceConcisePrompter(AlpacaPrompter): method match_prompt_style (line 144) | def match_prompt_style(self): class SummarizeTLDRPrompter (line 149) | class SummarizeTLDRPrompter(AlpacaPrompter): method match_prompt_style (line 157) | def match_prompt_style(self): class GPTeacherPrompter (line 162) | class GPTeacherPrompter(AlpacaPrompter): class NomicGPT4AllPrompter (line 168) | class NomicGPT4AllPrompter(AlpacaPrompter): class ReflectAlpacaPrompter (line 174) | class ReflectAlpacaPrompter(Prompter): method __init__ (line 189) | def __init__(self, prompt_style="instruct"): method match_prompt_style (line 193) | def match_prompt_style(self): method _build_result (line 217) | def _build_result( method build_prompt (line 241) | def build_prompt( method __repr__ (line 257) | def __repr__(self) -> str: class UnsupportedPrompter (line 268) | class UnsupportedPrompter(Prompter): method __init__ (line 273) | def __init__(self) -> None: method __repr__ (line 276) | def __repr__(self): FILE: src/axolotl/scripts/vllm_serve_lora.py class LoRAScriptArguments (line 38) | class LoRAScriptArguments(ScriptArguments): function llm_worker (line 59) | def llm_worker( function main (line 125) | def main(script_args: ScriptArguments): FILE: src/axolotl/scripts/vllm_worker_ext.py class BatchWeightSyncWorkerExtension (line 33) | class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension): method init_communicator (line 36) | def init_communicator(self, host, port, world_size, client_device_uuid): method _direct_set_weight (line 42) | def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None: method update_named_param (line 112) | def update_named_param(self, name, dtype, shape): method batch_update_named_params (line 129) | def batch_update_named_params(self, params_list: list[tuple[str, str, ... FILE: src/axolotl/telemetry/callbacks.py class TelemetryCallback (line 21) | class TelemetryCallback(TrainerCallback): method __init__ (line 31) | def __init__(self): method on_train_begin (line 41) | def on_train_begin( method on_train_end (line 52) | def on_train_end( method on_epoch_begin (line 68) | def on_epoch_begin( method on_epoch_end (line 80) | def on_epoch_end( method on_step_end (line 91) | def on_step_end( method _extract_last_metrics (line 155) | def _extract_last_metrics(self, state: TrainerState) -> dict: FILE: src/axolotl/telemetry/errors.py function sanitize_stack_trace (line 18) | def sanitize_stack_trace(stack_trace: str) -> str: function send_errors (line 104) | def send_errors(func: Callable) -> Callable: FILE: src/axolotl/telemetry/manager.py function is_main_process (line 61) | def is_main_process() -> bool: class TelemetryManager (line 98) | class TelemetryManager: method __new__ (line 104) | def __new__(cls): method __init__ (line 115) | def __init__(self): method get_instance (line 140) | def get_instance(cls) -> "TelemetryManager": method _check_telemetry_enabled (line 146) | def _check_telemetry_enabled(self) -> bool: method _load_whitelist (line 187) | def _load_whitelist(self) -> dict: method _is_whitelisted (line 199) | def _is_whitelisted(self, value: str) -> bool: method _init_posthog (line 220) | def _init_posthog(self): method _redact_paths (line 226) | def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]: method _get_system_info (line 267) | def _get_system_info(self) -> dict[str, Any]: method send_event (line 359) | def send_event(self, event_type: str, properties: dict[str, Any] | Non... method send_system_info (line 387) | def send_system_info(self): method shutdown (line 392) | def shutdown(self): FILE: src/axolotl/telemetry/runtime_metrics.py class RuntimeMetrics (line 17) | class RuntimeMetrics: method __post_init__ (line 34) | def __post_init__(self): method elapsed_time (line 41) | def elapsed_time(self) -> float: method epoch_time (line 45) | def epoch_time(self, epoch: int) -> float | None: method average_epoch_time (line 52) | def average_epoch_time(self) -> float | None: method steps_per_second (line 68) | def steps_per_second(self) -> float | None: method to_dict (line 75) | def to_dict(self) -> dict[str, Any]: class RuntimeMetricsTracker (line 112) | class RuntimeMetricsTracker: method __init__ (line 117) | def __init__(self): method start_epoch (line 123) | def start_epoch(self, epoch: int): method end_epoch (line 129) | def end_epoch(self, epoch: int): method update_step (line 133) | def update_step(self, step: int): method _get_allocated_memory (line 142) | def _get_allocated_memory(self) -> dict[int, int]: method update_memory_metrics (line 182) | def update_memory_metrics(self): method get_memory_metrics (line 195) | def get_memory_metrics(self) -> dict[str, Any]: FILE: src/axolotl/train.py function setup_model_and_tokenizer (line 54) | def setup_model_and_tokenizer( function setup_reference_model (line 120) | def setup_reference_model( function setup_signal_handler (line 149) | def setup_signal_handler(cfg: DictDefault, model: PreTrainedModel): function execute_training (line 175) | def execute_training( function save_trained_model (line 224) | def save_trained_model( function create_model_card (line 354) | def create_model_card(cfg: DictDefault, trainer: Trainer): function save_initial_configs (line 390) | def save_initial_configs( function setup_model_card (line 430) | def setup_model_card(cfg: DictDefault): function handle_untrained_tokens_fix (line 447) | def handle_untrained_tokens_fix( function setup_model_and_trainer (line 487) | def setup_model_and_trainer( function train (line 558) | def train( FILE: src/axolotl/utils/__init__.py function is_mlflow_available (line 12) | def is_mlflow_available(): function is_comet_available (line 16) | def is_comet_available(): function is_opentelemetry_available (line 20) | def is_opentelemetry_available(): function is_trackio_available (line 27) | def is_trackio_available(): function get_pytorch_version (line 31) | def get_pytorch_version() -> tuple[int, int, int]: function set_pytorch_cuda_alloc_conf (line 47) | def set_pytorch_cuda_alloc_conf(): function set_misc_env (line 66) | def set_misc_env(): function get_not_null (line 71) | def get_not_null(value, default=None): FILE: src/axolotl/utils/bench.py function check_cuda_device (line 25) | def check_cuda_device(default_value): function gpu_memory_usage (line 54) | def gpu_memory_usage(device=0): function gpu_memory_usage_all (line 59) | def gpu_memory_usage_all(device=0): function mps_memory_usage_all (line 67) | def mps_memory_usage_all(): function npu_memory_usage_all (line 73) | def npu_memory_usage_all(device=0): function gpu_memory_usage_smi (line 80) | def gpu_memory_usage_smi(device=0): function get_gpu_memory_usage (line 96) | def get_gpu_memory_usage(device: int | torch.device = 0): function log_gpu_memory_usage (line 110) | def log_gpu_memory_usage( FILE: src/axolotl/utils/callbacks/__init__.py class LossWatchDogCallback (line 57) | class LossWatchDogCallback(TrainerCallback): method __init__ (line 60) | def __init__(self, cfg): method on_step_end (line 66) | def on_step_end( class SaveModelOnFirstStepCallback (line 86) | class SaveModelOnFirstStepCallback(TrainerCallback): method on_step_end (line 89) | def on_step_end( function bench_eval_callback_factory (line 101) | def bench_eval_callback_factory(trainer, tokenizer): function causal_lm_bench_eval_callback_factory (line 303) | def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): function log_prediction_callback_factory (line 512) | def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger:... class SaveAxolotlConfigtoWandBCallback (line 731) | class SaveAxolotlConfigtoWandBCallback(TrainerCallback): method __init__ (line 734) | def __init__(self, axolotl_config_path): method on_train_begin (line 737) | def on_train_begin( class GCCallback (line 829) | class GCCallback(TrainerCallback): method __init__ (line 832) | def __init__(self, gc_steps: int | None = -1): method _gc (line 836) | def _gc(self): method on_train_begin (line 840) | def on_train_begin( method on_step_begin (line 849) | def on_step_begin( method on_step_end (line 859) | def on_step_end( method on_epoch_end (line 885) | def on_epoch_end( function colab_inference_post_train_callback (line 895) | def colab_inference_post_train_callback(trainer: Trainer): FILE: src/axolotl/utils/callbacks/comet_.py class SaveAxolotlConfigtoCometCallback (line 17) | class SaveAxolotlConfigtoCometCallback(TrainerCallback): method __init__ (line 20) | def __init__(self, axolotl_config_path): method on_train_begin (line 23) | def on_train_begin( FILE: src/axolotl/utils/callbacks/dynamic_checkpoint.py class DynamicCheckpointCallback (line 22) | class DynamicCheckpointCallback(TrainerCallback): method _get_config_value (line 34) | def _get_config_value(self, config, key, default=None): method __init__ (line 40) | def __init__(self, cfg): method on_step_end (line 64) | def on_step_end( FILE: src/axolotl/utils/callbacks/generation.py class SFTGenerationCallback (line 12) | class SFTGenerationCallback(TrainerCallback): method __init__ (line 15) | def __init__(self, trainer): method on_evaluate (line 18) | def on_evaluate( method _log_samples (line 62) | def _log_samples(self, samples: list, step: int): FILE: src/axolotl/utils/callbacks/lisa.py function lisa_callback_factory (line 23) | def lisa_callback_factory(trainer: "AxolotlTrainer"): FILE: src/axolotl/utils/callbacks/mlflow_.py function should_log_artifacts (line 20) | def should_log_artifacts() -> bool: class SaveAxolotlConfigtoMlflowCallback (line 25) | class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): method __init__ (line 28) | def __init__(self, axolotl_config_path): method on_train_begin (line 31) | def on_train_begin( FILE: src/axolotl/utils/callbacks/models.py function get_causal_lm_model_cls_prefix (line 8) | def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]: FILE: src/axolotl/utils/callbacks/opentelemetry.py class OpenTelemetryMetricsCallback (line 30) | class OpenTelemetryMetricsCallback(TrainerCallback): method __init__ (line 45) | def __init__(self, cfg): method _create_metrics (line 76) | def _create_metrics(self): method _start_metrics_server (line 120) | def _start_metrics_server(self): method on_train_begin (line 135) | def on_train_begin( method on_log (line 149) | def on_log( method on_step_end (line 178) | def on_step_end( method on_evaluate (line 194) | def on_evaluate( method on_train_end (line 224) | def on_train_end( FILE: src/axolotl/utils/callbacks/perplexity.py class Perplexity (line 19) | class Perplexity: method __init__ (line 25) | def __init__( method _feature_names (line 36) | def _feature_names(self) -> List[str]: method compute (line 39) | def compute( FILE: src/axolotl/utils/callbacks/profiler.py class PytorchProfilerCallback (line 17) | class PytorchProfilerCallback(TrainerCallback): method __init__ (line 24) | def __init__(self, steps_to_profile: int = 5, profiler_steps_start: in... method on_step_begin (line 36) | def on_step_begin( method on_step_end (line 60) | def on_step_end( method on_train_end (line 82) | def on_train_end( FILE: src/axolotl/utils/callbacks/qat.py function toggle_fake_quant (line 16) | def toggle_fake_quant(mod: nn.Module, enable: bool): class QATCallback (line 33) | class QATCallback(TrainerCallback): method __init__ (line 38) | def __init__(self, cfg: QATConfig): method on_step_begin (line 41) | def on_step_begin(self, args, state, control, model, **kwargs): FILE: src/axolotl/utils/callbacks/swanlab.py class CustomSwanLabCallback (line 26) | class CustomSwanLabCallback(TrainerCallback): method __init__ (line 34) | def __init__(self): method setup (line 38) | def setup(self): method on_train_begin (line 58) | def on_train_begin( method on_log (line 95) | def on_log( method on_train_end (line 129) | def on_train_end( class SaveAxolotlConfigtoSwanLabCallback (line 146) | class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback): method __init__ (line 149) | def __init__(self, axolotl_config_path): method on_train_begin (line 152) | def on_train_begin( FILE: src/axolotl/utils/callbacks/tokens_per_second.py class TokensPerSecondCallback (line 22) | class TokensPerSecondCallback(TrainerCallback): method __init__ (line 28) | def __init__( method on_train_begin (line 41) | def on_train_begin( method on_step_begin (line 61) | def on_step_begin( method on_step_end (line 73) | def on_step_end( method on_log (line 95) | def on_log( FILE: src/axolotl/utils/callbacks/trackio_.py class SaveAxolotlConfigtoTrackioCallback (line 18) | class SaveAxolotlConfigtoTrackioCallback(TrainerCallback): method __init__ (line 21) | def __init__(self, axolotl_config_path): method on_train_begin (line 24) | def on_train_begin( FILE: src/axolotl/utils/chat_templates/base.py function get_chat_template (line 26) | def get_chat_template( function extract_chat_template_args (line 88) | def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None): function get_chat_template_from_config (line 98) | def get_chat_template_from_config( function register_chat_template (line 113) | def register_chat_template(template_name: str, chat_template: str): FILE: src/axolotl/utils/collators/batching.py class DataCollatorForSeq2Seq (line 12) | class DataCollatorForSeq2Seq: method __call__ (line 55) | def __call__(self, features, return_tensors=None): class BatchSamplerDataCollatorForSeq2Seq (line 129) | class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): method __call__ (line 134) | def __call__(self, features, return_tensors=None): class V2BatchSamplerDataCollatorForSeq2Seq (line 159) | class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): method __call__ (line 166) | def __call__(self, features, return_tensors=None): class PretrainingBatchSamplerDataCollatorForSeq2Seq (line 200) | class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2S... method __init__ (line 205) | def __init__(self, *args, multipack_attn=True, **kwargs): method __call__ (line 209) | def __call__(self, features, return_tensors=None): FILE: src/axolotl/utils/collators/mamba.py class MambaDataCollator (line 15) | class MambaDataCollator: method __call__ (line 22) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: FILE: src/axolotl/utils/collators/mm_chat.py class MultiModalChatDataCollator (line 17) | class MultiModalChatDataCollator(DataCollatorMixin): method __post_init__ (line 29) | def __post_init__(self): method torch_call (line 33) | def torch_call(self, examples: list[dict]) -> dict[str, Any]: method process_rows (line 36) | def process_rows( FILE: src/axolotl/utils/comet_.py function python_value_to_environ_value (line 45) | def python_value_to_environ_value(python_value): function setup_comet_env_vars (line 61) | def setup_comet_env_vars(cfg: DictDefault): FILE: src/axolotl/utils/config/__init__.py function choose_device (line 30) | def choose_device(cfg): function resolve_dtype (line 64) | def resolve_dtype(cfg): function normalize_config (line 107) | def normalize_config(cfg): function normalize_cfg_datasets (line 269) | def normalize_cfg_datasets(cfg): function validate_config (line 288) | def validate_config( function prepare_plugins (line 337) | def prepare_plugins(cfg): FILE: src/axolotl/utils/ctx_managers/sequence_parallel.py function apply_sequence_parallelism (line 24) | def apply_sequence_parallelism( class SequenceParallelContextManager (line 170) | class SequenceParallelContextManager: method __init__ (line 189) | def __init__( method __enter__ (line 233) | def __enter__(self): method __exit__ (line 238) | def __exit__(self, exc_type, exc_val, exc_tb): method _register_ring_attn (line 246) | def _register_ring_attn(self): method _register_model_hooks (line 255) | def _register_model_hooks(self): method _gather_outputs (line 359) | def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMO... class AllGatherWithGrad (line 368) | class AllGatherWithGrad(torch.autograd.Function): method forward (line 372) | def forward( method backward (line 419) | def backward( FILE: src/axolotl/utils/data/lock.py class FileLockLoader (line 17) | class FileLockLoader: method __init__ (line 24) | def __init__(self, cfg: DictDefault): method load (line 33) | def load(self, load_fn: Callable[[], Any]) -> Any: method _increment_counter (line 46) | def _increment_counter(self): method cleanup (line 55) | def cleanup(self): FILE: src/axolotl/utils/data/rl.py function prepare_preference_datasets (line 37) | def prepare_preference_datasets( function _map_dataset (line 86) | def _map_dataset( function _drop_long_sequences (line 125) | def _drop_long_sequences( function _load_split (line 182) | def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Da... function _load_or_create_dataset_split (line 262) | def _load_or_create_dataset_split( FILE: src/axolotl/utils/data/sft.py function prepare_datasets (line 48) | def prepare_datasets( function _prepare_standard_dataset (line 68) | def _prepare_standard_dataset( function _prepare_streaming_dataset (line 125) | def _prepare_streaming_dataset( function _extract_pretraining_config (line 176) | def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: function _load_streaming_dataset (line 205) | def _load_streaming_dataset( function _create_placeholder_dataset (line 251) | def _create_placeholder_dataset() -> IterableDataset: function _load_tokenized_prepared_datasets (line 260) | def _load_tokenized_prepared_datasets( function _load_raw_datasets (line 311) | def _load_raw_datasets( function _load_and_process_single_dataset (line 369) | def _load_and_process_single_dataset( function _parse_dataset_type (line 420) | def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]: function _handle_train_dataset_split (line 432) | def _handle_train_dataset_split( function _apply_dataset_sharding (line 451) | def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: function _load_and_prepare_datasets (line 472) | def _load_and_prepare_datasets( FILE: src/axolotl/utils/data/shared.py function get_dataset_type (line 48) | def get_dataset_type(dataset_config: DictDefault) -> str: function datasets_with_name_generator (line 60) | def datasets_with_name_generator( function load_dataset_with_config (line 93) | def load_dataset_with_config( function _check_if_hub_dataset (line 151) | def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: b... function _get_remote_filesystem (line 173) | def _get_remote_filesystem( function _load_from_local_path (line 222) | def _load_from_local_path( function _load_from_hub (line 258) | def _load_from_hub( function _load_from_cloud (line 271) | def _load_from_cloud( function _load_from_url (line 298) | def _load_from_url( function _load_from_data_files (line 310) | def _load_from_data_files( function generate_split_fingerprints (line 339) | def generate_split_fingerprints( function get_prepared_dataset_path (line 354) | def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path: function create_train_validation_split (line 368) | def create_train_validation_split( function _generate_from_iterable_dataset (line 400) | def _generate_from_iterable_dataset( function save_preprocessed_dataset (line 409) | def save_preprocessed_dataset( function load_preprocessed_dataset (line 457) | def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Da... function try_load_from_hub (line 486) | def try_load_from_hub( function generate_dataset_hash_from_config (line 506) | def generate_dataset_hash_from_config( function merge_datasets (line 529) | def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: FILE: src/axolotl/utils/data/streaming.py function encode_streaming (line 20) | def encode_streaming( function wrap_streaming_dataset (line 179) | def wrap_streaming_dataset( function encode_packed_streaming (line 254) | def encode_packed_streaming( FILE: src/axolotl/utils/data/utils.py class RetryStrategy (line 23) | class RetryStrategy(Enum): function retry_on_request_exceptions (line 31) | def retry_on_request_exceptions( function md5 (line 73) | def md5(to_hash: str, encoding: str = "utf-8") -> str: function sha256 (line 81) | def sha256(to_hash: str, encoding: str = "utf-8") -> str: function _deduplicate_dataset (line 86) | def _deduplicate_dataset( function deduplicate_and_log_datasets (line 112) | def deduplicate_and_log_datasets( function keep_min_len (line 151) | def keep_min_len(sample, min_sequence_len=2): function truncate_long_seq (line 167) | def truncate_long_seq(sample, sequence_len=2048): function _should_skip_processing (line 188) | def _should_skip_processing(dataset: Dataset) -> bool: function _log_dataset_stats (line 208) | def _log_dataset_stats(dataset: Dataset) -> None: function _build_filter_kwargs (line 216) | def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict: function _filter_short_sequences (line 225) | def _filter_short_sequences( function _truncate_long_sequences (line 251) | def _truncate_long_sequences( function _drop_outside_range (line 269) | def _drop_outside_range( function handle_long_seq_in_dataset (line 314) | def handle_long_seq_in_dataset( FILE: src/axolotl/utils/data/wrappers.py function handle_unknown_dataset_strategy (line 45) | def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoRe... function get_dataset_wrapper (line 57) | def get_dataset_wrapper( function _is_dataset_already_tokenized (line 134) | def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) ->... function _handle_custom_dataset_type (line 144) | def _handle_custom_dataset_type( function _handle_bradley_terry_dataset (line 165) | def _handle_bradley_terry_dataset( function _handle_stepwise_supervised_dataset (line 189) | def _handle_stepwise_supervised_dataset( function _handle_loaded_strategy (line 213) | def _handle_loaded_strategy( function _handle_alpaca_dataset (line 231) | def _handle_alpaca_dataset( function _handle_explainchoice_dataset (line 254) | def _handle_explainchoice_dataset( function _handle_concisechoice_dataset (line 277) | def _handle_concisechoice_dataset( function _handle_summarizetldr_dataset (line 300) | def _handle_summarizetldr_dataset( function _handle_jeopardy_dataset (line 323) | def _handle_jeopardy_dataset( function _handle_oasst_dataset (line 346) | def _handle_oasst_dataset( function _handle_gpteacher_dataset (line 369) | def _handle_gpteacher_dataset( function _handle_reflection_dataset (line 392) | def _handle_reflection_dataset( FILE: src/axolotl/utils/datasets.py function get_default_process_count (line 10) | def get_default_process_count(): FILE: src/axolotl/utils/dict.py class DictDefault (line 6) | class DictDefault(Dict): method __missing__ (line 11) | def __missing__(self, key): method __or__ (line 14) | def __or__(self, other): method __setitem__ (line 17) | def __setitem__(self, name, value): function remove_none_values (line 41) | def remove_none_values(obj): FILE: src/axolotl/utils/distributed.py function get_device_type (line 21) | def get_device_type() -> torch.device: function get_device_count (line 32) | def get_device_count() -> int: function get_current_device (line 41) | def get_current_device() -> int: function init_distributed_state (line 50) | def init_distributed_state(): function get_distributed_state (line 60) | def get_distributed_state() -> PartialState | None: function is_distributed (line 64) | def is_distributed() -> bool: function barrier (line 74) | def barrier(): function is_main_process (line 83) | def is_main_process() -> bool: function is_local_main_process (line 101) | def is_local_main_process() -> bool: function get_world_size (line 107) | def get_world_size() -> int: function cleanup_distributed (line 111) | def cleanup_distributed(): function zero_first (line 129) | def zero_first(is_main: bool): function gather_scalar_from_all_ranks (line 140) | def gather_scalar_from_all_ranks(fn, world_size=1): function broadcast_dict (line 176) | def broadcast_dict(vals: dict): function compute_and_broadcast (line 204) | def compute_and_broadcast(fn): function gather_from_all_ranks (line 237) | def gather_from_all_ranks(fn, world_size=1): function reduce_and_broadcast (line 274) | def reduce_and_broadcast(fn1, fn2): function build_parallelism_config (line 299) | def build_parallelism_config(cfg): function _get_parallel_config_kwargs (line 319) | def _get_parallel_config_kwargs( FILE: src/axolotl/utils/environment.py function check_cuda_p2p_ib_support (line 19) | def check_cuda_p2p_ib_support(): function check_cuda_p2p_support (line 27) | def check_cuda_p2p_support() -> bool: function get_package_version (line 49) | def get_package_version(package: str) -> Version: function is_package_version_ge (line 54) | def is_package_version_ge(package: str, version_: str) -> bool: FILE: src/axolotl/utils/freeze.py function freeze_layers_except (line 14) | def freeze_layers_except(model, regex_patterns): function _invert_ranges (line 72) | def _invert_ranges( function _merge_ranges (line 102) | def _merge_ranges( function _create_freeze_parameters_hook (line 145) | def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int... class LayerNamePattern (line 173) | class LayerNamePattern: method __init__ (line 178) | def __init__(self, pattern: str): method match (line 189) | def match(self, name: str) -> bool: method _parse_pattern (line 201) | def _parse_pattern( FILE: src/axolotl/utils/generation/sft.py function generate_samples (line 14) | def generate_samples( function format_generation_for_logging (line 143) | def format_generation_for_logging( FILE: src/axolotl/utils/import_helper.py function get_cls_from_module_str (line 8) | def get_cls_from_module_str(module_str: str): FILE: src/axolotl/utils/logging.py class MultiProcessAdapter (line 20) | class MultiProcessAdapter(logging.LoggerAdapter): method _should_log (line 26) | def _should_log(main_process_only: bool): method log (line 29) | def log(self, level, msg, *args, **kwargs): method warning_once (line 38) | def warning_once(self, *args, **kwargs): function get_logger (line 49) | def get_logger(name: str, log_level: str | None = None) -> MultiProcessA... FILE: src/axolotl/utils/lora.py function get_lora_merged_state_dict (line 24) | def get_lora_merged_state_dict( FILE: src/axolotl/utils/mistral/mistral3_processor.py class Mistral3ProcessorKwargs (line 14) | class Mistral3ProcessorKwargs(ProcessingKwargs): class Mistral3Processor (line 27) | class Mistral3Processor(ProcessorMixin): method __init__ (line 33) | def __init__(self, tokenizer: HFMistralTokenizer): method audio_tokenizer (line 37) | def audio_tokenizer(self) -> None: method _merge_kwargs (line 41) | def _merge_kwargs( method apply_chat_template (line 64) | def apply_chat_template( method __call__ (line 140) | def __call__( FILE: src/axolotl/utils/mistral/mistral_tokenizer.py class HFMistralTokenizer (line 14) | class HFMistralTokenizer(MistralCommonBackend): method __init__ (line 20) | def __init__(self, name_or_path: str, **kwargs): method name_or_path (line 37) | def name_or_path(self) -> str: method name_or_path (line 41) | def name_or_path(self, name_or_path: str) -> None: method chat_template (line 45) | def chat_template(self) -> str | None: method chat_template (line 50) | def chat_template(self, chat_template: str | None) -> None: method _set_mode (line 53) | def _set_mode(self, mode: ValidationMode): method apply_chat_template (line 82) | def apply_chat_template( # type: ignore method decode (line 107) | def decode( # type: ignore method from_pretrained (line 124) | def from_pretrained( method save_pretrained (line 233) | def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]: FILE: src/axolotl/utils/mlflow_.py function setup_mlflow_env_vars (line 8) | def setup_mlflow_env_vars(cfg: DictDefault): FILE: src/axolotl/utils/model_shard_quant.py function _replace_linear (line 21) | def _replace_linear( function load_and_quantize (line 62) | def load_and_quantize( function n_loading_workers (line 128) | def n_loading_workers(quant_method: str, param_count: float): function load_sharded_model (line 140) | def load_sharded_model( function load_sharded_model_quant (line 167) | def load_sharded_model_quant( FILE: src/axolotl/utils/optimizers/adopt.py class ADOPT (line 39) | class ADOPT(Optimizer): method __init__ (line 40) | def __init__( method __setstate__ (line 104) | def __setstate__(self, state): method _init_group (line 126) | def _init_group( method step (line 192) | def step(self, closure=None): function _single_tensor_adopt (line 249) | def _single_tensor_adopt( function _multi_tensor_adopt (line 331) | def _multi_tensor_adopt( function adopt (line 460) | def adopt( FILE: src/axolotl/utils/quantization.py function get_quantization_config (line 52) | def get_quantization_config( function quantize_model (line 138) | def quantize_model( function prepare_model_for_qat (line 176) | def prepare_model_for_qat( function convert_qat_model (line 231) | def convert_qat_model( FILE: src/axolotl/utils/samplers/multipack.py function ffd_check (line 25) | def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins:... function pack_group (line 61) | def pack_group( function _process_group (line 115) | def _process_group( function pack_parallel (line 125) | def pack_parallel( function allocate_sequentially (line 194) | def allocate_sequentially( class MultipackBatchSampler (line 244) | class MultipackBatchSampler(BatchSampler): method __init__ (line 257) | def __init__( method set_epoch (line 305) | def set_epoch(self, epoch: int): method generate_batches (line 310) | def generate_batches(self, set_stats: bool = False) -> list[list[list[... method __iter__ (line 383) | def __iter__(self) -> Iterator[list[list[int]]]: method efficiency (line 395) | def efficiency(self) -> float: method gather_efficiency (line 406) | def gather_efficiency(self) -> float: method gather_len_batches (line 432) | def gather_len_batches(self, num: int) -> int: method __len__ (line 445) | def __len__(self) -> int: FILE: src/axolotl/utils/samplers/utils.py function get_dataset_lengths (line 8) | def get_dataset_lengths(dataset, from_arrow=False): FILE: src/axolotl/utils/schedulers.py class RexLR (line 12) | class RexLR(LRScheduler): method __init__ (line 29) | def __init__( method last_step (line 57) | def last_step(self): method last_step (line 61) | def last_step(self, value): method get_lr (line 64) | def get_lr(self): class InterpolatingLogScheduler (line 88) | class InterpolatingLogScheduler(LRScheduler): method __init__ (line 93) | def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1): method get_lr (line 113) | def get_lr(self): function _get_cosine_schedule_with_quadratic_warmup_lr_lambda (line 127) | def _get_cosine_schedule_with_quadratic_warmup_lr_lambda( function get_cosine_schedule_with_quadratic_warmup (line 144) | def get_cosine_schedule_with_quadratic_warmup( function _get_cosine_schedule_with_min_lr_lambda (line 182) | def _get_cosine_schedule_with_min_lr_lambda( function get_cosine_schedule_with_min_lr (line 201) | def get_cosine_schedule_with_min_lr( function _get_cosine_schedule_with_warmup_decay_constant_lr_lambda (line 222) | def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda( function get_cosine_schedule_with_warmup_decay_constant (line 252) | def get_cosine_schedule_with_warmup_decay_constant( class JaggedLRRestartScheduler (line 299) | class JaggedLRRestartScheduler(LRScheduler): method __init__ (line 302) | def __init__( method get_lr (line 318) | def get_lr(self) -> float | Sequence[float]: FILE: src/axolotl/utils/schemas/config.py class AxolotlInputConfig (line 57) | class AxolotlInputConfig( method datasets_serializer (line 1150) | def datasets_serializer( method warn_peft_trainable_token_to_fix_untrained (line 1159) | def warn_peft_trainable_token_to_fix_untrained(cls, data): method check_sageattn_wo_sample_packing (line 1179) | def check_sageattn_wo_sample_packing(cls, data): method check_sageattn_fft (line 1190) | def check_sageattn_fft(cls, data): class AxolotlConfigWCapabilities (line 1199) | class AxolotlConfigWCapabilities(AxolotlInputConfig): method check_bf16 (line 1206) | def check_bf16(self): method check_tf32 (line 1224) | def check_tf32(self): method check_fp8 (line 1230) | def check_fp8(self): method check_sample_packing_w_sdpa_bf16 (line 1241) | def check_sample_packing_w_sdpa_bf16(cls, data): method check_compute_capability_w_sageattn (line 1262) | def check_compute_capability_w_sageattn(cls, data): method check_multigpu_unsloth (line 1277) | def check_multigpu_unsloth(cls, data): method check_multigpu_lora_kernels (line 1292) | def check_multigpu_lora_kernels(cls, data): method check_quantize_moe_experts (line 1311) | def check_quantize_moe_experts(cls, data): method check_auto_enable_lora_kernels (line 1336) | def check_auto_enable_lora_kernels(cls, data): method check_adopt_torch_version (line 1392) | def check_adopt_torch_version(cls, data): method check_flex_torch_version (line 1410) | def check_flex_torch_version(cls, data): method check_torch_compile_auto (line 1428) | def check_torch_compile_auto(cls, data): method check_beta_and_trl_beta_match (line 1447) | def check_beta_and_trl_beta_match(cls, data): method check_min_torch_version (line 1454) | def check_min_torch_version(self): method check_qat_config (line 1466) | def check_qat_config(cls, data): method check_fsdp_torch_version (line 1495) | def check_fsdp_torch_version(cls, data): method default_dataloader_opts (line 1512) | def default_dataloader_opts(cls, data): method default_dataset_num_proc (line 1526) | def default_dataset_num_proc(cls, data): method check_deduplication_with_streaming (line 1546) | def check_deduplication_with_streaming(cls, data): method check_deduplication_with_skip_prepare (line 1557) | def check_deduplication_with_skip_prepare(cls, data): FILE: src/axolotl/utils/schemas/datasets.py class UserDefinedPrompterType (line 11) | class UserDefinedPrompterType(BaseModel): class SFTDataset (line 39) | class SFTDataset(BaseModel): method handle_legacy_message_fields (line 200) | def handle_legacy_message_fields(cls, data): method check_chat_template_config (line 206) | def check_chat_template_config(cls, data): class PretrainingDataset (line 229) | class PretrainingDataset(BaseModel): class UserDefinedDPOType (line 242) | class UserDefinedDPOType(BaseModel): class DPODataset (line 254) | class DPODataset(BaseModel): class StepwiseSupervisedDataset (line 265) | class StepwiseSupervisedDataset(BaseModel): class UserDefinedKTOType (line 277) | class UserDefinedKTOType(BaseModel): class KTODataset (line 288) | class KTODataset(BaseModel): FILE: src/axolotl/utils/schemas/deprecated.py class DeprecatedParameters (line 12) | class DeprecatedParameters(BaseModel): method validate_max_packed_sequence_len (line 27) | def validate_max_packed_sequence_len(cls, max_packed_sequence_len): method validate_rope_scaling (line 34) | def validate_rope_scaling(cls, rope_scaling): method validate_noisy_embedding_alpha (line 43) | def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): method validate_dpo_beta (line 50) | def validate_dpo_beta(cls, dpo_beta): method validate_evaluation_strategy (line 57) | def validate_evaluation_strategy(cls, evaluation_strategy): method validate_eval_table_size (line 64) | def validate_eval_table_size(cls, eval_table_size): method validate_eval_max_new_tokens (line 75) | def validate_eval_max_new_tokens(cls, eval_max_new_tokens): method validate_dpo_use_logits_to_keep (line 85) | def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep): method validate_dpo_generate_during_eval (line 95) | def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval): class RemappedParameters (line 104) | class RemappedParameters(BaseModel): FILE: src/axolotl/utils/schemas/dynamic_checkpoint.py class DynamicCheckpointConfig (line 6) | class DynamicCheckpointConfig(BaseModel): FILE: src/axolotl/utils/schemas/enums.py class TorchAOQuantDType (line 8) | class TorchAOQuantDType(Enum): method from_string (line 15) | def from_string(str): class RLType (line 28) | class RLType(str, Enum): class ChatTemplate (line 40) | class ChatTemplate(str, Enum): class CustomSupportedOptimizers (line 79) | class CustomSupportedOptimizers(str, Enum): class RingAttnFunc (line 97) | class RingAttnFunc(str, Enum): FILE: src/axolotl/utils/schemas/fsdp.py class FSDPConfig (line 10) | class FSDPConfig(BaseModel): FILE: src/axolotl/utils/schemas/integrations.py class MLFlowConfig (line 12) | class MLFlowConfig(BaseModel): class LISAConfig (line 33) | class LISAConfig(BaseModel): class WandbConfig (line 50) | class WandbConfig(BaseModel): method check_wandb_run (line 84) | def check_wandb_run(cls, data): class CometConfig (line 95) | class CometConfig(BaseModel): class GradioConfig (line 146) | class GradioConfig(BaseModel): class RayConfig (line 157) | class RayConfig(BaseModel): class OpenTelemetryConfig (line 181) | class OpenTelemetryConfig(BaseModel): class TrackioConfig (line 205) | class TrackioConfig(BaseModel): FILE: src/axolotl/utils/schemas/internal/__init__.py class GPUCapabilities (line 8) | class GPUCapabilities(BaseModel): class EnvCapabilities (line 19) | class EnvCapabilities(BaseModel): FILE: src/axolotl/utils/schemas/model.py class ModelInputConfig (line 12) | class ModelInputConfig(BaseModel): method hint_trust_remote_code (line 101) | def hint_trust_remote_code(cls, trust_remote_code): class ModelOutputConfig (line 109) | class ModelOutputConfig(BaseModel): method validate_save_safetensors (line 138) | def validate_save_safetensors(cls, v): class SpecialTokensConfig (line 149) | class SpecialTokensConfig(BaseModel): FILE: src/axolotl/utils/schemas/multimodal.py class MultiModalConfig (line 9) | class MultiModalConfig(BaseModel): method convert_image_resize_algorithm (line 32) | def convert_image_resize_algorithm(cls, image_resize_algorithm): FILE: src/axolotl/utils/schemas/peft.py class LoftQConfig (line 8) | class LoftQConfig(BaseModel): class PeftConfig (line 17) | class PeftConfig(BaseModel): class LoraConfig (line 28) | class LoraConfig(BaseModel): method validate_adapter (line 161) | def validate_adapter(cls, data): method validate_qlora (line 174) | def validate_qlora(self): method convert_loraplus_lr_embedding (line 200) | def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding): method validate_lora_dropout (line 207) | def validate_lora_dropout(cls, data): method validate_lora_target_parameters_dropout (line 213) | def validate_lora_target_parameters_dropout(self): class ReLoRAConfig (line 226) | class ReLoRAConfig(BaseModel): FILE: src/axolotl/utils/schemas/quantization.py function validate_ao_dtype (line 12) | def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None: class QATConfig (line 31) | class QATConfig(BaseModel): method validate_dtype (line 57) | def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: class PTQConfig (line 61) | class PTQConfig(BaseModel): method validate_dtype (line 84) | def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: FILE: src/axolotl/utils/schemas/training.py class LrGroup (line 15) | class LrGroup(BaseModel): class HyperparametersConfig (line 23) | class HyperparametersConfig(BaseModel): method hint_batch_size_set (line 168) | def hint_batch_size_set(cls, batch_size): method convert_learning_rate (line 179) | def convert_learning_rate(cls, learning_rate): class JaggedLRConfig (line 185) | class JaggedLRConfig(BaseModel): FILE: src/axolotl/utils/schemas/trl.py class TRLConfig (line 8) | class TRLConfig(BaseModel): FILE: src/axolotl/utils/schemas/utils.py function handle_legacy_message_fields_logic (line 8) | def handle_legacy_message_fields_logic(data: dict) -> dict: FILE: src/axolotl/utils/schemas/validation.py class DatasetValidationMixin (line 22) | class DatasetValidationMixin: method set_default_seed (line 27) | def set_default_seed(cls, seed): method deprecate_sharegpt_datasets (line 35) | def deprecate_sharegpt_datasets(cls, datasets): method check_dataset_or_pretraining_dataset (line 57) | def check_dataset_or_pretraining_dataset(cls, data): method check_pretraining_streaming_deprecation (line 64) | def check_pretraining_streaming_deprecation(cls, data): method check_push_ds_auth (line 78) | def check_push_ds_auth(cls, data): method check_val_w_test_datasets (line 90) | def check_val_w_test_datasets(cls, data): method check_test_datasets_bench (line 99) | def check_test_datasets_bench(cls, data): method check_eval_packing (line 113) | def check_eval_packing(cls, data): method check_mm_prepare (line 148) | def check_mm_prepare(cls, data): class AttentionValidationMixin (line 159) | class AttentionValidationMixin: method check_attention_fields (line 164) | def check_attention_fields(cls, data): method check_sample_packing_without_attention (line 181) | def check_sample_packing_without_attention(cls, data): method check_sample_packing_with_s2attn (line 197) | def check_sample_packing_with_s2attn(cls, data): method check_scaling_softmax_requires_flex (line 207) | def check_scaling_softmax_requires_flex(cls, data): class TrainingValidationMixin (line 216) | class TrainingValidationMixin: method check_batch_size_fields (line 221) | def check_batch_size_fields(cls, data): method hint_sample_packing_padding (line 231) | def hint_sample_packing_padding(cls, data): method hint_reward_model_pad (line 247) | def hint_reward_model_pad(cls, data): method set_reward_model_defaults (line 258) | def set_reward_model_defaults(cls, data): method check_gas_bsz (line 275) | def check_gas_bsz(cls, data): method hint_eval_train_mbsz (line 284) | def hint_eval_train_mbsz(cls, data): method check_warmup (line 297) | def check_warmup(cls, data): method check_saves (line 304) | def check_saves(cls, data): method check_push_save (line 321) | def check_push_save(cls, data): method check_evals (line 332) | def check_evals(cls, data): method check_neftune (line 374) | def check_neftune(cls, data): method check_multipack_buffer_size (line 386) | def check_multipack_buffer_size(cls, data): method check_fft_possible_bad_config (line 409) | def check_fft_possible_bad_config(self): method check_fp8_config (line 427) | def check_fp8_config(cls, data): method check_use_reentrant_mismatch (line 457) | def check_use_reentrant_mismatch(cls, data): method check_eval_strategy (line 472) | def check_eval_strategy(cls, data): method check_causal_lm_evals (line 485) | def check_causal_lm_evals(cls, data): method check_tokenizer_use_mistral_common (line 503) | def check_tokenizer_use_mistral_common(cls, data): method check_mistral_common_import (line 522) | def check_mistral_common_import(cls, tokenizer_use_mistral_common): method check_mistral_common_incompatible_options (line 535) | def check_mistral_common_incompatible_options(cls, data): method pretrain_with_tps (line 565) | def pretrain_with_tps(cls, data): class LoRAValidationMixin (line 578) | class LoRAValidationMixin: method check_lr_groups (line 583) | def check_lr_groups(cls, data): method check_frozen (line 590) | def check_frozen(cls, data): method check_peft_layers_pattern (line 603) | def check_peft_layers_pattern(cls, data): method check_qlora_unsloth (line 612) | def check_qlora_unsloth(cls, data): method check_lora_axolotl_unsloth (line 626) | def check_lora_axolotl_unsloth(cls, data): method check_fused_lora (line 641) | def check_fused_lora(self): method warn_qlora_zero3_w_use_reentrant (line 648) | def warn_qlora_zero3_w_use_reentrant(cls, data): method check_lora_kernels_8bit (line 668) | def check_lora_kernels_8bit(cls, data): method check_lora_kernels_dora (line 683) | def check_lora_kernels_dora(cls, data): method check_lora_kernels_trust_remote_code (line 697) | def check_lora_kernels_trust_remote_code(cls, data): class RLValidationMixin (line 711) | class RLValidationMixin: method check_sample_packing_w_rl (line 716) | def check_sample_packing_w_rl(cls, data): method check_kto_config (line 723) | def check_kto_config(cls, data): method check_grpo_liger_sequence_parallel (line 734) | def check_grpo_liger_sequence_parallel(cls, data): method check_rl_config_gradient_checkpointing (line 746) | def check_rl_config_gradient_checkpointing(cls, data): method check_gdpo (line 770) | def check_gdpo(cls, data): class OptimizationValidationMixin (line 782) | class OptimizationValidationMixin: method check_adamw_optimizer_params (line 786) | def check_adamw_optimizer_params(self): method _resolve_fsdp_version (line 794) | def _resolve_fsdp_version(data): method check_muon_deepspeed_fsdp (line 803) | def check_muon_deepspeed_fsdp(cls, data): method check_flashoptim_deepspeed_fsdp (line 819) | def check_flashoptim_deepspeed_fsdp(cls, data): method check_batch_flattening_fa (line 838) | def check_batch_flattening_fa(cls, data): method check_xentropy_patch_conflicts (line 862) | def check_xentropy_patch_conflicts(cls, data): method check_cross_entropy_conflicts (line 873) | def check_cross_entropy_conflicts(cls, data): method check_fsdp_version (line 903) | def check_fsdp_version(cls, data): method check_fsdp2_cpu_offload_pin_memory (line 916) | def check_fsdp2_cpu_offload_pin_memory(cls, data): method check_fsdp2_base_model_quant_rl (line 933) | def check_fsdp2_base_model_quant_rl(cls, data): method check_fsdp_config_kwargs_prefix (line 949) | def check_fsdp_config_kwargs_prefix(cls, data): method check_fsdp_version_in_fsdp_config (line 971) | def check_fsdp_version_in_fsdp_config(cls, data): method check_fsdp_offload_w_8bit_optimizer (line 985) | def check_fsdp_offload_w_8bit_optimizer(self): method check_fsdp2_w_8bit_optimizer (line 1000) | def check_fsdp2_w_8bit_optimizer(self): method check_tensor_parallel_size_update_ds_json (line 1018) | def check_tensor_parallel_size_update_ds_json(cls, data): method check_deepcompile (line 1050) | def check_deepcompile(cls, data): class SystemValidationMixin (line 1069) | class SystemValidationMixin: method check_mem_mismatch (line 1074) | def check_mem_mismatch(cls, data): method check_fsdp_deepspeed (line 1086) | def check_fsdp_deepspeed(cls, data): method check_model_quantization_config_vs_bnb (line 1093) | def check_model_quantization_config_vs_bnb(cls, data): method check_npu_config (line 1103) | def check_npu_config(cls, data): class ChatTemplateValidationMixin (line 1136) | class ChatTemplateValidationMixin: method check_chat_template_config (line 1141) | def check_chat_template_config(cls, data): class PretrainingValidationMixin (line 1157) | class PretrainingValidationMixin: method check_pretraining_w_max_steps (line 1162) | def check_pretraining_w_max_steps(cls, data): method check_pretraining_w_group_by_length (line 1171) | def check_pretraining_w_group_by_length(cls, data): method check_pretraining_split_batches_accelerate (line 1180) | def check_pretraining_split_batches_accelerate(cls, data): method check_pretraining_w_val_set_size (line 1198) | def check_pretraining_w_val_set_size(cls, data): method check_streaming_w_val_set_size (line 1208) | def check_streaming_w_val_set_size(cls, data): method check_streaming_w_max_steps (line 1218) | def check_streaming_w_max_steps(cls, data): method check_streaming_w_multiple_datasets (line 1228) | def check_streaming_w_multiple_datasets(cls, data): class ModelCompatibilityValidationMixin (line 1241) | class ModelCompatibilityValidationMixin: method check_falcon_fsdp (line 1245) | def check_falcon_fsdp(self): method check_mpt_checkpointing (line 1251) | def check_mpt_checkpointing(self): method check_gradient_checkpointing_w_offload (line 1259) | def check_gradient_checkpointing_w_offload(self): method check_activation_offloading_wo_gc (line 1278) | def check_activation_offloading_wo_gc(self): method check_better_transformers (line 1284) | def check_better_transformers(self): method check_gptq_w_revision (line 1301) | def check_gptq_w_revision(cls, data): method check_gpt_oss_fsdp_loading (line 1312) | def check_gpt_oss_fsdp_loading(cls, data): class ComplexValidationMixin (line 1322) | class ComplexValidationMixin: method validate_neftune_noise_alpha (line 1327) | def validate_neftune_noise_alpha(cls, neftune_noise_alpha): method check_rl_beta (line 1333) | def check_rl_beta(self): method check_simpo_warmup (line 1340) | def check_simpo_warmup(self): method check_relora (line 1348) | def check_relora(self): method check_early_stopping (line 1371) | def check_early_stopping(self): method check_tensor_parallel_size (line 1384) | def check_tensor_parallel_size(self): method check_context_parallel_size (line 1390) | def check_context_parallel_size(self): method validate_ring_attn_func (line 1446) | def validate_ring_attn_func(self): method hint_gradient_checkpointing_dpo_lora_ddp (line 1463) | def hint_gradient_checkpointing_dpo_lora_ddp(self): class DistributedValidationMixin (line 1479) | class DistributedValidationMixin: method check_tensor_parallel_optimizer (line 1483) | def check_tensor_parallel_optimizer(self): class GRPOVllmValidationMixin (line 1493) | class GRPOVllmValidationMixin: method check_vllm_mode_set (line 1497) | def check_vllm_mode_set(self): class ValidationMixin (line 1506) | class ValidationMixin( FILE: src/axolotl/utils/schemas/vllm.py class VllmConfig (line 8) | class VllmConfig(BaseModel): FILE: src/axolotl/utils/tee.py class _FileOnlyWriter (line 23) | class _FileOnlyWriter(io.TextIOBase): method write (line 29) | def write(self, s: str) -> int: # type: ignore[override] method flush (line 36) | def flush(self) -> None: # type: ignore[override] class _StreamTee (line 48) | class _StreamTee(io.TextIOBase): method __init__ (line 54) | def __init__(self, stream: io.TextIOBase): method write (line 57) | def write(self, s: str) -> int: # type: ignore[override] method flush (line 64) | def flush(self) -> None: # type: ignore[override] method encoding (line 74) | def encoding(self): # type: ignore[override] method errors (line 78) | def errors(self): # type: ignore[override] method isatty (line 81) | def isatty(self): # type: ignore[override] method fileno (line 84) | def fileno(self): # type: ignore[override] function prepare_debug_log (line 90) | def prepare_debug_log(cfg, filename: str = "debug.log") -> str: function close_debug_log (line 140) | def close_debug_log() -> None: FILE: src/axolotl/utils/tokenization.py function check_dataset_labels (line 10) | def check_dataset_labels( function check_example_labels (line 25) | def check_example_labels(example, tokenizer, text_only=False): function color_token_for_rl_debug (line 57) | def color_token_for_rl_debug(decoded_token, encoded_token, color, text_o... function process_tokens_for_rl_debug (line 67) | def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): function check_rl_example_labels (line 76) | def check_rl_example_labels(example, tokenizer, text_only=False): FILE: src/axolotl/utils/trackio_.py function setup_trackio_env_vars (line 8) | def setup_trackio_env_vars(cfg: DictDefault): FILE: src/axolotl/utils/train.py function determine_last_checkpoint (line 11) | def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> ... FILE: src/axolotl/utils/trainer.py function weighted_cross_entropy (line 29) | def weighted_cross_entropy( function create_weighted_mask (line 47) | def create_weighted_mask(labels: torch.Tensor): function trainer_weighted_loss (line 78) | def trainer_weighted_loss(model_output, labels, shift_labels=True): function disable_datasets_caching (line 91) | def disable_datasets_caching(): function add_position_ids (line 99) | def add_position_ids(sample): function add_pose_position_ids (line 138) | def add_pose_position_ids( function add_length (line 203) | def add_length(sample): function filter_sequences_by_length (line 208) | def filter_sequences_by_length( function process_datasets_for_packing (line 252) | def process_datasets_for_packing(cfg, train_dataset, eval_dataset): function process_pretraining_datasets_for_packing (line 386) | def process_pretraining_datasets_for_packing( function calculate_total_num_steps (line 408) | def calculate_total_num_steps(cfg, train_dataset, update=True): function setup_torch_compile_env (line 525) | def setup_torch_compile_env(cfg): function setup_deepspeed_env (line 533) | def setup_deepspeed_env(cfg, stage=None): function setup_fsdp_envs (line 589) | def setup_fsdp_envs(cfg): function setup_parallelism_envs (line 621) | def setup_parallelism_envs(cfg): function prepare_optim_env (line 643) | def prepare_optim_env(cfg): function setup_trainer (line 679) | def setup_trainer( FILE: src/axolotl/utils/wandb_.py function setup_wandb_env_vars (line 8) | def setup_wandb_env_vars(cfg: DictDefault): FILE: src/setuptools_axolotl_dynamic_dependencies.py function parse_requirements (line 12) | def parse_requirements(): class BuildPyCommand (line 94) | class BuildPyCommand(_build_py): method finalize_options (line 99) | def finalize_options(self): FILE: tests/cli/conftest.py function cli_runner (line 22) | def cli_runner(): function valid_test_config (line 27) | def valid_test_config(): function config_path (line 32) | def config_path(tmp_path): FILE: tests/cli/test_cli_base.py class BaseCliTest (line 9) | class BaseCliTest: method _test_cli_validation (line 12) | def _test_cli_validation(self, cli_runner, command: str): method _test_basic_execution (line 30) | def _test_basic_execution( method _test_cli_overrides (line 78) | def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str): FILE: tests/cli/test_cli_evaluate.py class TestEvaluateCommand (line 10) | class TestEvaluateCommand(BaseCliTest): method test_evaluate_cli_validation (line 15) | def test_evaluate_cli_validation(self, cli_runner): method test_evaluate_basic_execution (line 19) | def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_te... method test_evaluate_basic_execution_no_accelerate (line 25) | def test_evaluate_basic_execution_no_accelerate( method test_evaluate_cli_overrides (line 47) | def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test... method test_evaluate_with_launcher_args_torchrun (line 73) | def test_evaluate_with_launcher_args_torchrun( method test_evaluate_with_launcher_args_accelerate (line 106) | def test_evaluate_with_launcher_args_accelerate( method test_evaluate_backward_compatibility_no_launcher_args (line 140) | def test_evaluate_backward_compatibility_no_launcher_args( FILE: tests/cli/test_cli_fetch.py function test_fetch_cli_examples (line 8) | def test_fetch_cli_examples(cli_runner): function test_fetch_cli_deepspeed (line 17) | def test_fetch_cli_deepspeed(cli_runner): function test_fetch_cli_with_dest (line 26) | def test_fetch_cli_with_dest(cli_runner, tmp_path): function test_fetch_cli_invalid_directory (line 36) | def test_fetch_cli_invalid_directory(cli_runner): FILE: tests/cli/test_cli_inference.py function test_inference_basic (line 8) | def test_inference_basic(cli_runner, config_path): function test_inference_gradio (line 21) | def test_inference_gradio(cli_runner, config_path): function test_inference_with_launcher_args_torchrun (line 34) | def test_inference_with_launcher_args_torchrun(cli_runner, config_path): function test_inference_with_launcher_args_accelerate (line 63) | def test_inference_with_launcher_args_accelerate(cli_runner, config_path): function test_inference_gradio_with_launcher_args (line 93) | def test_inference_gradio_with_launcher_args(cli_runner, config_path): function test_inference_backward_compatibility_no_launcher_args (line 123) | def test_inference_backward_compatibility_no_launcher_args(cli_runner, c... FILE: tests/cli/test_cli_interface.py function test_build_command (line 6) | def test_build_command(): function test_invalid_command_options (line 28) | def test_invalid_command_options(cli_runner): function test_required_config_argument (line 43) | def test_required_config_argument(cli_runner): FILE: tests/cli/test_cli_merge_lora.py function test_merge_lora_basic (line 8) | def test_merge_lora_basic(cli_runner, config_path): function test_merge_lora_with_dirs (line 18) | def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path): function test_merge_lora_nonexistent_config (line 44) | def test_merge_lora_nonexistent_config(cli_runner, tmp_path): function test_merge_lora_nonexistent_lora_dir (line 51) | def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_pa... FILE: tests/cli/test_cli_merge_sharded_fsdp_weights.py function test_merge_sharded_fsdp_weights_no_accelerate (line 8) | def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): function test_merge_sharded_fsdp_weights_with_launcher_args_torchrun (line 21) | def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun( function test_merge_sharded_fsdp_weights_with_launcher_args_accelerate (line 52) | def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate( function test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args (line 84) | def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_a... FILE: tests/cli/test_cli_preprocess.py function cleanup_last_run_prepared (line 13) | def cleanup_last_run_prepared(): function test_preprocess_config_not_found (line 20) | def test_preprocess_config_not_found(cli_runner): function test_preprocess_basic (line 26) | def test_preprocess_basic(cli_runner, config_path): function test_preprocess_without_download (line 40) | def test_preprocess_without_download(cli_runner, config_path): function test_preprocess_custom_path (line 53) | def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config): FILE: tests/cli/test_cli_sweeps.py function test_generate_sweep_configs_no_pairs (line 8) | def test_generate_sweep_configs_no_pairs(): function test_generate_sweep_configs_with_pairs (line 33) | def test_generate_sweep_configs_with_pairs(): FILE: tests/cli/test_cli_train.py class TestTrainCommand (line 10) | class TestTrainCommand(BaseCliTest): method test_train_cli_validation (line 15) | def test_train_cli_validation(self, cli_runner): method test_train_basic_execution (line 19) | def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_... method test_train_basic_execution_no_accelerate (line 25) | def test_train_basic_execution_no_accelerate( method test_train_cli_overrides (line 51) | def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_co... method test_train_with_launcher_args_torchrun (line 79) | def test_train_with_launcher_args_torchrun( method test_train_with_launcher_args_accelerate (line 112) | def test_train_with_launcher_args_accelerate( method test_train_backward_compatibility_no_launcher_args (line 147) | def test_train_backward_compatibility_no_launcher_args( method test_train_mixed_args_with_launcher_args (line 182) | def test_train_mixed_args_with_launcher_args( method test_train_cloud_with_launcher_args (line 218) | def test_train_cloud_with_launcher_args( FILE: tests/cli/test_cli_version.py function test_print_version (line 6) | def test_print_version(cli_runner): FILE: tests/cli/test_nested_options.py class InnerConfig (line 10) | class InnerConfig(BaseModel): class OuterConfig (line 27) | class OuterConfig(BaseModel): class TestAddOptionsFromConfigNested (line 44) | class TestAddOptionsFromConfigNested: method setup_method (line 47) | def setup_method(self): method test_nested_dot_notation_options_are_registered (line 50) | def test_nested_dot_notation_options_are_registered(self): method test_nested_bool_option (line 65) | def test_nested_bool_option(self): method test_flat_and_nested_options_together (line 79) | def test_flat_and_nested_options_together(self): method test_no_nested_options_passed (line 97) | def test_no_nested_options_passed(self): class TestLoadCfgNestedKwargs (line 111) | class TestLoadCfgNestedKwargs: method _apply_nested_kwargs (line 115) | def _apply_nested_kwargs(cfg, kwargs): method test_nested_kwargs_applied_to_cfg (line 145) | def test_nested_kwargs_applied_to_cfg(self, tmp_path): method test_nested_kwargs_creates_parent_if_none (line 165) | def test_nested_kwargs_creates_parent_if_none(self): method test_nested_kwargs_overwrites_string_parent (line 176) | def test_nested_kwargs_overwrites_string_parent(self): class TestCoerceValue (line 186) | class TestCoerceValue: method test_coerce_with_existing_float (line 189) | def test_coerce_with_existing_float(self): method test_coerce_with_existing_int (line 195) | def test_coerce_with_existing_int(self): method test_coerce_with_existing_bool (line 201) | def test_coerce_with_existing_bool(self): method test_coerce_yaml_inference_no_existing (line 209) | def test_coerce_yaml_inference_no_existing(self): method test_coerce_non_string_passthrough (line 222) | def test_coerce_non_string_passthrough(self): FILE: tests/cli/test_utils.py function mock_responses (line 23) | def mock_responses(): function test_fetch_from_github_new_files (line 37) | def test_fetch_from_github_new_files(tmp_path, mock_responses): function test_fetch_from_github_unchanged_files (line 48) | def test_fetch_from_github_unchanged_files(tmp_path, mock_responses): function test_fetch_from_github_invalid_prefix (line 61) | def test_fetch_from_github_invalid_prefix(mock_responses): function test_fetch_from_github_network_error (line 68) | def test_fetch_from_github_network_error(): function assert_launcher_args_in_command (line 75) | def assert_launcher_args_in_command( function assert_no_launcher_args_contamination (line 111) | def assert_no_launcher_args_contamination(mock_subprocess_call, launcher... function common_launcher_args (line 141) | def common_launcher_args(): function test_add_default_rdzv_args_with_endpoint (line 149) | def test_add_default_rdzv_args_with_endpoint(): function test_add_default_rdzv_args_with_existing_backend (line 165) | def test_add_default_rdzv_args_with_existing_backend(): function test_add_default_rdzv_args_with_existing_id (line 182) | def test_add_default_rdzv_args_with_existing_id(): function test_add_default_rdzv_args_without_endpoint (line 203) | def test_add_default_rdzv_args_without_endpoint(): function test_add_default_rdzv_args_with_all_existing (line 215) | def test_add_default_rdzv_args_with_all_existing(): FILE: tests/conftest.py function retry_on_request_exceptions (line 33) | def retry_on_request_exceptions(max_retries=3, delay=1): function snapshot_download_w_retry (line 57) | def snapshot_download_w_retry(*args, **kwargs): function download_ds_fixture_bundle (line 73) | def download_ds_fixture_bundle(): function download_smollm2_135m_model (line 81) | def download_smollm2_135m_model(): function download_smollm2_135m_instruct_model (line 87) | def download_smollm2_135m_instruct_model(): function download_smollm2_135m_gptq_model (line 93) | def download_smollm2_135m_gptq_model(): function download_qwen_2_5_half_billion_model (line 99) | def download_qwen_2_5_half_billion_model(): function download_qwen3_half_billion_model (line 105) | def download_qwen3_half_billion_model(): function download_tatsu_lab_alpaca_dataset (line 111) | def download_tatsu_lab_alpaca_dataset(): function download_mhenrichsen_alpaca_2k_dataset (line 117) | def download_mhenrichsen_alpaca_2k_dataset(): function download_mhenrichsen_alpaca_2k_w_revision_dataset (line 123) | def download_mhenrichsen_alpaca_2k_w_revision_dataset(): function download_mlabonne_finetome_100k_dataset (line 131) | def download_mlabonne_finetome_100k_dataset(): function download_argilla_distilabel_capybara_dpo_7k_binarized_dataset (line 137) | def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset(): function download_argilla_distilabel_intel_orca_dpo_dataset (line 145) | def download_argilla_distilabel_intel_orca_dpo_dataset(): function download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset (line 153) | def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset(): function download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dataset (line 161) | def download_argilla_ultrafeedback_binarized_preferences_cleaned_kto_dat... function download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset (line 200) | def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset(): function download_argilla_dpo_pairs_dataset (line 208) | def download_argilla_dpo_pairs_dataset(): function download_tiny_shakespeare_dataset (line 216) | def download_tiny_shakespeare_dataset(): function download_evolkit_kd_sample_dataset (line 222) | def download_evolkit_kd_sample_dataset(): function download_deepseek_model_fixture (line 230) | def download_deepseek_model_fixture(): function download_huggyllama_model_fixture (line 235) | def download_huggyllama_model_fixture(): function download_llama33_70b_model_fixture (line 245) | def download_llama33_70b_model_fixture(): function download_llama_1b_model_fixture (line 255) | def download_llama_1b_model_fixture(): function download_llama3_8b_model_fixture (line 265) | def download_llama3_8b_model_fixture(): function download_llama3_8b_instruct_model_fixture (line 275) | def download_llama3_8b_instruct_model_fixture(): function download_phi_35_mini_model_fixture (line 285) | def download_phi_35_mini_model_fixture(): function download_phi_4_reasoning_model_fixture (line 295) | def download_phi_4_reasoning_model_fixture(): function download_phi_3_medium_model_fixture (line 305) | def download_phi_3_medium_model_fixture(): function download_mistral_7b_model_fixture (line 315) | def download_mistral_7b_model_fixture(): function download_gemma3_4b_model_fixture (line 325) | def download_gemma3_4b_model_fixture(): function download_gemma_2b_model_fixture (line 335) | def download_gemma_2b_model_fixture(): function download_gemma2_9b_model_fixture (line 346) | def download_gemma2_9b_model_fixture(): function download_mlx_mistral_7b_model_fixture (line 356) | def download_mlx_mistral_7b_model_fixture(): function download_llama2_model_fixture (line 366) | def download_llama2_model_fixture(): function download_llama32_1b_model_fixture (line 376) | def download_llama32_1b_model_fixture(): function tokenizer_huggyllama (line 385) | def tokenizer_huggyllama( function tokenizer_huggyllama_w_special_tokens (line 396) | def tokenizer_huggyllama_w_special_tokens( function tokenizer_llama2_7b (line 412) | def tokenizer_llama2_7b( function tokenizer_mistral_7b_instruct (line 422) | def tokenizer_mistral_7b_instruct( function tokenizer_mistral_7b_instruct_chatml (line 429) | def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct): function temp_dir (line 446) | def temp_dir() -> Generator[str, None, None]: function torch_manual_seed (line 455) | def torch_manual_seed(): function cleanup_monkeypatches (line 460) | def cleanup_monkeypatches(): function dataset_winglian_tiny_shakespeare (line 512) | def dataset_winglian_tiny_shakespeare( function dataset_tatsu_lab_alpaca (line 520) | def dataset_tatsu_lab_alpaca( function dataset_mhenrichsen_alpaca_2k_test (line 528) | def dataset_mhenrichsen_alpaca_2k_test( function dataset_argilla_ultrafeedback_binarized_preferences_cleaned (line 536) | def dataset_argilla_ultrafeedback_binarized_preferences_cleaned( function dataset_fozziethebeat_alpaca_messages_2k_dpo_test (line 547) | def dataset_fozziethebeat_alpaca_messages_2k_dpo_test( function dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff (line 555) | def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff( function fixture_min_base_cfg (line 566) | def fixture_min_base_cfg(): function test_load_fixtures (line 586) | def test_load_fixtures( function disable_telemetry (line 617) | def disable_telemetry(monkeypatch): FILE: tests/core/chat/test_messages.py function llama_tokenizer_fixture (line 18) | def llama_tokenizer_fixture(): function llama_tokenizer_w_chatml (line 23) | def llama_tokenizer_w_chatml(llama_tokenizer): function chat_msgs_fixture (line 41) | def chat_msgs_fixture(): class TestMessagesCase (line 118) | class TestMessagesCase: method test_tool_call_stringify (line 123) | def test_tool_call_stringify(self, chat_msgs): method test_chatml_formatted_wrapper (line 129) | def test_chatml_formatted_wrapper(self, chat_msgs): method test_chatml_formatting_tool_call (line 156) | def test_chatml_formatting_tool_call(self, chat_msgs): method test_train_labels (line 163) | def test_train_labels(self, chatml_tokenizer, chat_msgs): method test_train_labels_2 (line 180) | def test_train_labels_2(self, chatml_tokenizer, chat_msgs): FILE: tests/core/test_async_grpo.py class TestReplayBuffer (line 9) | class TestReplayBuffer(unittest.TestCase): method test_add_noop_when_max_size_zero (line 12) | def test_add_noop_when_max_size_zero(self): method test_add_noop_when_max_size_negative (line 19) | def test_add_noop_when_max_size_negative(self): method test_sample_returns_none_when_max_size_zero (line 26) | def test_sample_returns_none_when_max_size_zero(self): method test_sample_returns_none_when_empty (line 32) | def test_sample_returns_none_when_empty(self): method test_normal_add_and_sample (line 38) | def test_normal_add_and_sample(self): method test_replaces_lowest_when_full (line 50) | def test_replaces_lowest_when_full(self): class TestGRPOStrategyConflict (line 62) | class TestGRPOStrategyConflict(unittest.TestCase): method test_raises_on_both_enabled (line 65) | def test_raises_on_both_enabled(self): method test_sequence_parallel_only (line 73) | def test_sequence_parallel_only(self): method test_async_only (line 82) | def test_async_only(self): method test_neither (line 89) | def test_neither(self): class TestDequantizeFP8TailBlocks (line 97) | class TestDequantizeFP8TailBlocks(unittest.TestCase): method test_exact_divisible_shape (line 100) | def test_exact_divisible_shape(self): method test_non_divisible_rows (line 109) | def test_non_divisible_rows(self): method test_non_divisible_cols (line 121) | def test_non_divisible_cols(self): method test_scalar_scale (line 129) | def test_scalar_scale(self): class TestLoraFP8Guard (line 138) | class TestLoraFP8Guard(unittest.TestCase): method test_non_fp8_weight_skips_scale_inv (line 141) | def test_non_fp8_weight_skips_scale_inv(self): method test_fp8_weight_uses_scale_inv (line 160) | def test_fp8_weight_uses_scale_inv(self): class TestValidateQuantPatchRestore (line 181) | class TestValidateQuantPatchRestore(unittest.TestCase): method test_patch_restored_on_success (line 184) | def test_patch_restored_on_success(self): method test_patch_restored_on_error (line 201) | def test_patch_restored_on_error(self): FILE: tests/core/test_builders.py function fixture_base_cfg (line 21) | def fixture_base_cfg(): function fixture_dpo_cfg (line 91) | def fixture_dpo_cfg(base_cfg): function fixture_orpo_cfg (line 105) | def fixture_orpo_cfg(base_cfg): function fixture_kto_cfg (line 118) | def fixture_kto_cfg(base_cfg): function fixture_grpo_cfg (line 132) | def fixture_grpo_cfg(base_cfg): function fixture_ipo_cfg (line 163) | def fixture_ipo_cfg(base_cfg): function fixture_simpo_cfg (line 176) | def fixture_simpo_cfg(base_cfg): function fixture_sft_cfg (line 190) | def fixture_sft_cfg(base_cfg): function fixture_rm_cfg (line 204) | def fixture_rm_cfg(sft_cfg): function fixture_prm_cfg (line 224) | def fixture_prm_cfg(sft_cfg): function fixture_tokenizer (line 244) | def fixture_tokenizer(base_cfg): function fixture_model (line 249) | def fixture_model(base_cfg, tokenizer): class TestHFRLTrainerBuilder (line 254) | class TestHFRLTrainerBuilder: method _test_common_training_arguments (line 259) | def _test_common_training_arguments(self, training_arguments, rl: str): method test_dpo_training_arguments (line 291) | def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer): method test_orpo_training_arguments (line 302) | def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer): method test_kto_training_arguments (line 310) | def test_kto_training_arguments(self, kto_cfg, model, tokenizer): method _write_rewards_file (line 319) | def _write_rewards_file(self, rewards_dir: Path): method test_grpo_training_arguments (line 333) | def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp... method test_ipo_training_arguments (line 366) | def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer): method test_simpo_training_arguments (line 376) | def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer): method test_custom_optimizer_cls_and_kwargs (line 409) | def test_custom_optimizer_cls_and_kwargs( class TestHFCausalTrainerBuilder (line 501) | class TestHFCausalTrainerBuilder: method test_training_arguments (line 506) | def test_training_arguments(self, sft_cfg, model, tokenizer): method test_builder_w_rm_trainers (line 543) | def test_builder_w_rm_trainers(self, request, cfg_string, model, token... class TestTrainerClsPlugin (line 574) | class TestTrainerClsPlugin: method test_trainer_cls_is_not_none_with_plugin (line 579) | def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tok... FILE: tests/e2e/integrations/test_cut_cross_entropy.py function min_cfg (line 17) | def min_cfg(temp_dir): class TestCutCrossEntropyIntegration (line 48) | class TestCutCrossEntropyIntegration: method test_llama_w_cce (line 53) | def test_llama_w_cce(self, min_cfg, temp_dir): method test_qwen2_w_cce (line 68) | def test_qwen2_w_cce(self, temp_dir): method test_llama_w_cce_and_attention (line 120) | def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_... FILE: tests/e2e/integrations/test_fp8.py class FP8IntegrationTestCase (line 13) | class FP8IntegrationTestCase: method test_fp8_single_gpu_smoke (line 19) | def test_fp8_single_gpu_smoke(self, temp_dir): FILE: tests/e2e/integrations/test_hooks.py class LogHooksPlugin (line 17) | class LogHooksPlugin(BasePlugin): method __init__ (line 24) | def __init__(self): method post_trainer_create (line 31) | def post_trainer_create(self, cfg, trainer): method pre_model_load (line 37) | def pre_model_load(self, cfg): method post_model_build (line 43) | def post_model_build(self, cfg, model): method pre_lora_load (line 49) | def pre_lora_load(self, cfg, model): method post_lora_load (line 55) | def post_lora_load(self, cfg, model): method post_model_load (line 61) | def post_model_load(self, cfg, model): method create_optimizer (line 67) | def create_optimizer(self, cfg, trainer): method get_trainer_cls (line 73) | def get_trainer_cls(self, cfg): method create_lr_scheduler (line 79) | def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_st... method add_callbacks_pre_trainer (line 85) | def add_callbacks_pre_trainer(self, cfg, model): method add_callbacks_post_trainer (line 92) | def add_callbacks_post_trainer(self, cfg, trainer): method post_train (line 99) | def post_train(self, cfg, model): method post_train_unload (line 105) | def post_train_unload(self, cfg): class TestPluginHooks (line 112) | class TestPluginHooks: method test_plugin_hooks (line 117) | def test_plugin_hooks(self, temp_dir): FILE: tests/e2e/integrations/test_kd.py function min_cfg (line 17) | def min_cfg(temp_dir): class TestKnowledgeDistillation (line 73) | class TestKnowledgeDistillation: method test_llama_kd (line 81) | def test_llama_kd(self, temp_dir, kd_min_cfg): method test_llama_lora_kd (line 110) | def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit): FILE: tests/e2e/integrations/test_liger.py class LigerIntegrationTestCase (line 15) | class LigerIntegrationTestCase: method test_llama_wo_flce (line 21) | def test_llama_wo_flce(self, temp_dir): method test_llama_w_flce (line 70) | def test_llama_w_flce(self, temp_dir, liger_use_token_scaling): FILE: tests/e2e/integrations/test_llm_compressor.py class TestLLMCompressorIntegration (line 32) | class TestLLMCompressorIntegration: method test_llmcompressor_plugin (line 39) | def test_llmcompressor_plugin( function _check_llmcompressor_model_outputs (line 100) | def _check_llmcompressor_model_outputs(temp_dir, save_compressed): FILE: tests/e2e/integrations/test_scattermoe_lora_kernels.py function flatten_sort_count_ref (line 42) | def flatten_sort_count_ref(expert_idxs: torch.Tensor, num_experts: int): function reference_parallel_linear_lora (line 52) | def reference_parallel_linear_lora( function reference_lora_backward (line 111) | def reference_lora_backward( function make_test_data (line 166) | def make_test_data( class TestForwardPass (line 212) | class TestForwardPass: method _run_forward_test (line 215) | def _run_forward_test( method test_basic (line 250) | def test_basic(self): method test_topk2 (line 254) | def test_topk2(self): method test_larger_rank (line 258) | def test_larger_rank(self): method test_small_rank (line 262) | def test_small_rank(self): method test_many_experts (line 266) | def test_many_experts(self): method test_non_power_of_2_dims (line 270) | def test_non_power_of_2_dims(self): method test_single_token (line 274) | def test_single_token(self): method test_bf16 (line 278) | def test_bf16(self): method test_fp16 (line 284) | def test_fp16(self): class TestForwardGrouped (line 291) | class TestForwardGrouped: method _make_grouped_data (line 294) | def _make_grouped_data(self, M=32, K=64, N=128, E=4, R=8, k=2, dtype=t... method test_x_grouped (line 306) | def test_x_grouped(self): method test_y_grouped (line 340) | def test_y_grouped(self): class TestLoRAGradients (line 380) | class TestLoRAGradients: method _run_lora_grad_test (line 383) | def _run_lora_grad_test(self, M, K, N, E, R, k, atol=1e-2, rtol=1e-2): method test_basic_lora_grads (line 431) | def test_basic_lora_grads(self): method test_small_rank (line 434) | def test_small_rank(self): method test_larger_rank (line 437) | def test_larger_rank(self): method test_many_experts (line 442) | def test_many_experts(self): method test_single_token_per_expert (line 445) | def test_single_token_per_expert(self): class TestAutograd (line 455) | class TestAutograd: method test_lora_receives_gradients (line 458) | def test_lora_receives_gradients(self): method test_input_gradient_matches_reference (line 501) | def test_input_gradient_matches_reference(self): method test_lora_gradient_matches_reference (line 561) | def test_lora_gradient_matches_reference(self): class TestBaseEquivalence (line 623) | class TestBaseEquivalence: method test_zero_scaling_matches_base (line 626) | def test_zero_scaling_matches_base(self): method test_zero_lora_weights_matches_base (line 656) | def test_zero_lora_weights_matches_base(self): class TestLoRAAdditivity (line 695) | class TestLoRAAdditivity: method test_lora_additivity (line 698) | def test_lora_additivity(self): class TestParallelExpertsModule (line 752) | class TestParallelExpertsModule: method test_set_and_clear_lora (line 755) | def test_set_and_clear_lora(self): method test_forward_with_lora (line 775) | def test_forward_with_lora(self): class TestEdgeCases (line 819) | class TestEdgeCases: method test_all_tokens_one_expert (line 822) | def test_all_tokens_one_expert(self): method test_empty_experts (line 866) | def test_empty_experts(self): class TestFusedDX (line 916) | class TestFusedDX: method _run_fused_dX_test (line 919) | def _run_fused_dX_test( method test_basic (line 980) | def test_basic(self): method test_large (line 983) | def test_large(self): method test_single_expert (line 986) | def test_single_expert(self): method test_k1 (line 989) | def test_k1(self): method test_bf16 (line 992) | def test_bf16(self): method test_grouped_output (line 1005) | def test_grouped_output(self): method test_autograd_with_fused_dX (line 1058) | def test_autograd_with_fused_dX(self): class TestFusedGatherBackward (line 1125) | class TestFusedGatherBackward: method _run_fused_gather_test (line 1128) | def _run_fused_gather_test( method test_basic (line 1174) | def test_basic(self): method test_large (line 1177) | def test_large(self): method test_single_expert (line 1180) | def test_single_expert(self): method test_k1 (line 1183) | def test_k1(self): method test_many_experts (line 1186) | def test_many_experts(self): method test_bf16 (line 1189) | def test_bf16(self): method test_autograd_with_fused_gather (line 1202) | def test_autograd_with_fused_gather(self): class TestTokenRounding (line 1272) | class TestTokenRounding: method test_round_expert_counts_basic (line 1275) | def test_round_expert_counts_basic(self): method test_round_with_fused_gather (line 1318) | def test_round_with_fused_gather(self): method test_empty_experts_with_rounding (line 1373) | def test_empty_experts_with_rounding(self): class TestCombinedOptimizations (line 1417) | class TestCombinedOptimizations: method test_fused_dX_and_fused_gather (line 1420) | def test_fused_dX_and_fused_gather(self): function _reference_moe_forward (line 1489) | def _reference_moe_forward( function _make_mock_sigmoid_moe_block (line 1540) | def _make_mock_sigmoid_moe_block( class TestHFScatterMoESigmoidRouting (line 1586) | class TestHFScatterMoESigmoidRouting: method test_forward_matches_reference_bias_on_gate (line 1589) | def test_forward_matches_reference_bias_on_gate(self): method test_forward_matches_reference_bias_on_block (line 1632) | def test_forward_matches_reference_bias_on_block(self): method test_softmax_routing_still_works (line 1672) | def test_softmax_routing_still_works(self): class TestHFScatterMoESigmoidWithSharedExperts (line 1727) | class TestHFScatterMoESigmoidWithSharedExperts: method test_shared_experts_plural (line 1730) | def test_shared_experts_plural(self): method test_shared_expert_with_gate (line 1786) | def test_shared_expert_with_gate(self): FILE: tests/e2e/integrations/test_scattermoe_lora_olmoe.py function peft_lora_B_to_scattermoe (line 47) | def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): function peft_lora_to_scattermoe (line 56) | def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): function _unwrap_experts_lora (line 77) | def _unwrap_experts_lora(experts_module): function _unwrap_gate_lora (line 80) | def _unwrap_gate_lora(gate_module): function make_olmoe_config (line 125) | def make_olmoe_config(use_full=False): function scattermoe_lora_B_to_peft (line 136) | def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank): function peft_gate_up_lora_to_scattermoe (line 147) | def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): function _init_expert_weights (line 161) | def _init_expert_weights(moe_block): class MinimalOLMoEModel (line 172) | class MinimalOLMoEModel(nn.Module): method __init__ (line 175) | def __init__(self, config): method forward (line 180) | def forward(self, x): function _get_routing (line 184) | def _get_routing(moe_block, hidden_states): function _reference_moe_forward (line 193) | def _reference_moe_forward( function _reference_moe_forward_with_lora (line 223) | def _reference_moe_forward_with_lora( function _compute_delta_from_scattermoe_lora (line 248) | def _compute_delta_from_scattermoe_lora(lora_A, lora_B, scaling, E, r, p... class TestLoRABLayoutConversion (line 266) | class TestLoRABLayoutConversion: method test_roundtrip (line 269) | def test_roundtrip(self): method test_per_expert_slices (line 276) | def test_per_expert_slices(self): method test_lora_A_already_compatible (line 290) | def test_lora_A_already_compatible(self): method test_delta_weight_equivalence (line 301) | def test_delta_weight_equivalence(self): method test_down_proj_conversion (line 319) | def test_down_proj_conversion(self): method test_gate_up_proj_conversion (line 341) | def test_gate_up_proj_conversion(self): class TestPeftLoRAWeightExtraction (line 392) | class TestPeftLoRAWeightExtraction: method test_peft_creates_correct_shapes (line 395) | def test_peft_creates_correct_shapes(self): method test_peft_forward_runs (line 443) | def test_peft_forward_runs(self): method test_unwrap_experts_lora (line 466) | def test_unwrap_experts_lora(self): method test_unwrap_no_lora (line 515) | def test_unwrap_no_lora(self): method test_unwrap_gate_lora (line 524) | def test_unwrap_gate_lora(self): method test_unwrap_gate_no_lora (line 560) | def test_unwrap_gate_no_lora(self): method test_gate_lora_delta_matches_peft (line 569) | def test_gate_lora_delta_matches_peft(self): class TestOLMoEReferenceVsScatterMoE (line 609) | class TestOLMoEReferenceVsScatterMoE: method test_small (line 612) | def test_small(self): method test_full (line 616) | def test_full(self): method _run (line 619) | def _run(self, use_full, M): class TestOLMoEPeftLoRAForward (line 684) | class TestOLMoEPeftLoRAForward: method test_small (line 687) | def test_small(self): method test_full (line 691) | def test_full(self): method _run (line 694) | def _run(self, use_full, M, r): class TestOLMoEPeftLoRABackward (line 793) | class TestOLMoEPeftLoRABackward: method test_small (line 796) | def test_small(self): method _run (line 799) | def _run(self, use_full, M, r): class TestKernelizeIntegration (line 907) | class TestKernelizeIntegration: method _get_kernelize_imports (line 911) | def _get_kernelize_imports(): method _get_repo_path (line 933) | def _get_repo_path(): method _setup_kernels (line 945) | def _setup_kernels( method test_base_forward_via_kernelize (line 974) | def test_base_forward_via_kernelize(self): method test_lora_forward_via_kernelize (line 1023) | def test_lora_forward_via_kernelize(self): method test_gate_lora_forward_via_kernelize (line 1072) | def test_gate_lora_forward_via_kernelize(self): class TestSharedExpertHandling (line 1131) | class TestSharedExpertHandling: method _make_shared_expert_block (line 1135) | def _make_shared_expert_block(config): method test_shared_expert_is_used (line 1160) | def test_shared_expert_is_used(self): method test_shared_expert_forward_via_kernelize (line 1180) | def test_shared_expert_forward_via_kernelize(self): FILE: tests/e2e/integrations/test_sonicmoe.py function _create_tiny_qwen3_config (line 31) | def _create_tiny_qwen3_config(): function _interleave_gate_up_weights (line 52) | def _interleave_gate_up_weights(model): function _unpatch_sonicmoe (line 64) | def _unpatch_sonicmoe(): class TestSonicMoEForwardCorrectness (line 74) | class TestSonicMoEForwardCorrectness: method teardown_method (line 77) | def teardown_method(self): method test_forward_output_matches (line 80) | def test_forward_output_matches(self): class TestSonicMoEGradientCorrectness (line 110) | class TestSonicMoEGradientCorrectness: method teardown_method (line 113) | def teardown_method(self): method test_gradients_match (line 116) | def test_gradients_match(self): method test_router_weights_receive_gradients (line 190) | def test_router_weights_receive_gradients(self): class TestSonicMoETrainingConvergence (line 216) | class TestSonicMoETrainingConvergence: method teardown_method (line 219) | def teardown_method(self): method test_loss_decreases (line 222) | def test_loss_decreases(self): method test_expert_weights_update (line 253) | def test_expert_weights_update(self): FILE: tests/e2e/kernels/test_geglu.py function test_geglu_forward_shape (line 10) | def test_geglu_forward_shape(): function test_geglu_forward_values (line 27) | def test_geglu_forward_values(torch_seed): function test_geglu_backward (line 48) | def test_geglu_backward(torch_seed): function test_geglu_inplace_preservation (line 74) | def test_geglu_inplace_preservation(): FILE: tests/e2e/kernels/test_lora.py function mock_quantstate (line 22) | def mock_quantstate(): function sample_tensors (line 53) | def sample_tensors(): function mock_proj (line 78) | def mock_proj(): function test_get_lora_parameters (line 102) | def test_get_lora_parameters(mock_proj): function test_matmul_lora (line 126) | def test_matmul_lora(sample_tensors): function test_lora_mlp_direct (line 163) | def test_lora_mlp_direct(sample_tensors, activation_forward, activation_... function test_lora_mlp_with_adapters (line 216) | def test_lora_mlp_with_adapters( function test_lora_qkv (line 298) | def test_lora_qkv(sample_tensors): function test_lora_o (line 413) | def test_lora_o(sample_tensors): function test_with_quantization (line 440) | def test_with_quantization(sample_tensors, mock_quantstate): function test_shapes_and_dimensions (line 475) | def test_shapes_and_dimensions(batch, seq, hidden, rank, out): function test_gradient_flow (line 488) | def test_gradient_flow(sample_tensors): function test_inplace_operations (line 526) | def test_inplace_operations(sample_tensors, apply_function): FILE: tests/e2e/kernels/test_quantize.py function test_dequantize_null_state (line 9) | def test_dequantize_null_state(): function test_dequantize_shape_preservation (line 15) | def test_dequantize_shape_preservation(): function test_dequantize_transposed (line 46) | def test_dequantize_transposed(): function test_dequantize_output_tensor (line 75) | def test_dequantize_output_tensor(): FILE: tests/e2e/kernels/test_swiglu.py function test_swiglu_forward_shape (line 9) | def test_swiglu_forward_shape(): function test_swiglu_forward_values (line 21) | def test_swiglu_forward_values(): function test_swiglu_backward (line 35) | def test_swiglu_backward(): function test_swiglu_inplace_preservation (line 61) | def test_swiglu_inplace_preservation(): FILE: tests/e2e/multigpu/patched/test_sp.py class TestSequenceParallelism (line 15) | class TestSequenceParallelism: method _run_sequence_parallel_test (line 18) | def _run_sequence_parallel_test( method test_sequence_parallel_training (line 117) | def test_sequence_parallel_training( FILE: tests/e2e/multigpu/solo/test_flex.py function download_model (line 22) | def download_model(): class TestPackedFlex (line 27) | class TestPackedFlex: method test_loss_llama (line 33) | def test_loss_llama(self, temp_dir): FILE: tests/e2e/multigpu/solo/test_gdpo.py class TestGDPO (line 24) | class TestGDPO: method _utils_write_yaml_and_rewards (line 27) | def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""): method test_gdpo_multi_reward_lora (line 60) | def test_gdpo_multi_reward_lora(self, temp_dir, num_gpus): method test_gdpo_three_rewards (line 159) | def test_gdpo_three_rewards(self, temp_dir): method test_gdpo_single_reward_fallback (line 256) | def test_gdpo_single_reward_fallback(self, temp_dir): method test_gdpo_fft (line 351) | def test_gdpo_fft(self, temp_dir): method test_gdpo_sequence_parallel (line 443) | def test_gdpo_sequence_parallel(self, temp_dir): FILE: tests/e2e/multigpu/solo/test_grpo.py function start_vllm (line 25) | def start_vllm( function recursive_kill (line 130) | def recursive_kill(process: subprocess.Popen): class TestGRPO (line 145) | class TestGRPO: method _utils_write_yaml_and_rewards (line 150) | def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""): method test_llama_dora (line 177) | def test_llama_dora(self, temp_dir, num_gpus): method test_llama_lora_sp (line 270) | def test_llama_lora_sp(self, temp_dir): method test_llama_fft (line 368) | def test_llama_fft(self, temp_dir, num_gpus): FILE: tests/e2e/multigpu/test_dist_muon_fsdp2.py function verify_training_success (line 19) | def verify_training_success(temp_dir): class TestDistMuon (line 48) | class TestDistMuon: method test_fft_sft (line 52) | def test_fft_sft(self, temp_dir): method test_lora_sft (line 109) | def test_lora_sft(self, temp_dir): FILE: tests/e2e/multigpu/test_eval.py class TestMultiGPUEval (line 18) | class TestMultiGPUEval: method test_eval_sample_packing (line 23) | def test_eval_sample_packing(self, temp_dir): method test_eval (line 94) | def test_eval(self, temp_dir): FILE: tests/e2e/multigpu/test_fp8_fsdp2.py function verify_fp8_training_success (line 19) | def verify_fp8_training_success(temp_dir): class TestFP8FSDP2 (line 48) | class TestFP8FSDP2: method test_fp8_fsdp2_smoke (line 53) | def test_fp8_fsdp2_smoke(self, temp_dir): FILE: tests/e2e/multigpu/test_fsdp1.py function verify_training_success (line 20) | def verify_training_success(temp_dir): class TestFSDP1 (line 49) | class TestFSDP1: method test_fft_sft (line 56) | def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading): method test_lora_sft (line 126) | def test_lora_sft(self, temp_dir, adapter_config): method test_dpo_fft (line 190) | def test_dpo_fft(self, temp_dir): method test_dpo_lora (line 262) | def test_dpo_lora(self, temp_dir, adapter_config): FILE: tests/e2e/multigpu/test_fsdp2.py function verify_training_success (line 20) | def verify_training_success(temp_dir): class TestFSDP2 (line 49) | class TestFSDP2: method test_fft_sft (line 57) | def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading): method test_lora_sft (line 114) | def test_lora_sft(self, temp_dir, peft_use_dora): method test_lora_sft_kernels (line 180) | def test_lora_sft_kernels(self, temp_dir): method test_qlora_sft (line 243) | def test_qlora_sft(self, temp_dir): method test_qlora_sft_kernels (line 305) | def test_qlora_sft_kernels(self, temp_dir): method test_dpo_fft (line 370) | def test_dpo_fft(self, temp_dir): method test_dpo_lora (line 428) | def test_dpo_lora(self, temp_dir): FILE: tests/e2e/multigpu/test_gemma3.py function download_model (line 21) | def download_model(): class TestMultiGPUGemma3 (line 27) | class TestMultiGPUGemma3: method test_lora_ddp_packed (line 32) | def test_lora_ddp_packed(self, temp_dir): FILE: tests/e2e/multigpu/test_llama.py function download_model (line 23) | def download_model(): function transformers_version_eq (line 28) | def transformers_version_eq(required_version): class TestMultiGPULlama (line 32) | class TestMultiGPULlama: method test_lora_ddp (line 37) | def test_lora_ddp(self, temp_dir): method test_lora_ddp_packed (line 100) | def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps): method test_dpo_lora_ddp (line 162) | def test_dpo_lora_ddp(self, temp_dir): method test_dpo_qlora_ddp (line 241) | def test_dpo_qlora_ddp(self, temp_dir): method test_fsdp (line 324) | def test_fsdp(self, temp_dir, gradient_accumulation_steps): method test_fsdp_packed (line 398) | def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type): method test_fsdp2_packed (line 476) | def test_fsdp2_packed( method test_fsdp_qlora_prequant_packed (line 549) | def test_fsdp_qlora_prequant_packed(self, temp_dir): method test_ds_zero3_packed (line 645) | def test_ds_zero3_packed( method test_ds_zero2_packed (line 722) | def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, ... method test_ds_zero1_packed (line 798) | def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, ... method test_fix_untrained_tokens (line 868) | def test_fix_untrained_tokens(self, temp_dir): FILE: tests/e2e/multigpu/test_locking.py class TestFileLockLoader (line 15) | class TestFileLockLoader: method temp_dir (line 19) | def temp_dir(self): method cfg (line 25) | def cfg(self, temp_dir): method loader (line 30) | def loader(self, cfg): method test_load_first_process (line 34) | def test_load_first_process(self, loader): method test_load_subsequent_process (line 47) | def test_load_subsequent_process(self, loader): method test_load_concurrent_processes (line 60) | def test_load_concurrent_processes(self, cfg): method test_load_waiting_for_ready_flag (line 90) | def test_load_waiting_for_ready_flag(self, mock_sleep, loader): method test_complete_workflow_with_cleanup (line 128) | def test_complete_workflow_with_cleanup(self, loader): method test_multiple_processes_workflow (line 143) | def test_multiple_processes_workflow(self, loader): method test_load_exception_handling (line 164) | def test_load_exception_handling(self, loader): method test_file_lock_called (line 176) | def test_file_lock_called(self, loader): FILE: tests/e2e/multigpu/test_ray.py class TestMultiGPURay (line 21) | class TestMultiGPURay: method test_lora_ddp (line 27) | def test_lora_ddp(self, temp_dir): method test_ds_zero2_packed (line 90) | def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps): method test_sft_fsdp2_packed (line 149) | def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps): FILE: tests/e2e/multigpu/test_tp.py class TestTensorParallel (line 14) | class TestTensorParallel: method test_fft_sft (line 21) | def test_fft_sft(self, temp_dir): FILE: tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py function init_accelerate (line 58) | def init_accelerate(): function small_llama_model (line 64) | def small_llama_model(): function test_attention_patching_integration (line 84) | def test_attention_patching_integration(model_name, attention_cls): function test_swiglu_mlp_integration (line 109) | def test_swiglu_mlp_integration(small_llama_model): function test_geglu_model_integration (line 159) | def test_geglu_model_integration(): function test_model_specific_activation (line 202) | def test_model_specific_activation(model_name, expected_activation): function test_kernel_patch_conditions (line 224) | def test_kernel_patch_conditions(): function test_kernel_config_options (line 264) | def test_kernel_config_options(): function get_lora_config (line 338) | def get_lora_config(): function get_test_inputs (line 351) | def get_test_inputs(model, seq_length=20): function test_model_architecture (line 363) | def test_model_architecture(model_config): function test_kernel_training_integration (line 396) | def test_kernel_training_integration(temp_dir): function test_kernel_training_integration_auto_enable (line 442) | def test_kernel_training_integration_auto_enable(temp_dir): function test_kernel_training_integration_dropout_non_zero (line 513) | def test_kernel_training_integration_dropout_non_zero(temp_dir): FILE: tests/e2e/patched/test_4d_multipack_llama.py class Test4dMultipackLlama (line 15) | class Test4dMultipackLlama(unittest.TestCase): method test_sdp_lora_packing (line 21) | def test_sdp_lora_packing(self, temp_dir): method test_torch_lora_packing (line 68) | def test_torch_lora_packing(self, temp_dir): FILE: tests/e2e/patched/test_activation_checkpointing.py function fix_checkpoint_after_test (line 18) | def fix_checkpoint_after_test(): class TestActivationCheckpointing (line 23) | class TestActivationCheckpointing: method test_activation_checkpointing_offload (line 32) | def test_activation_checkpointing_offload( FILE: tests/e2e/patched/test_cli_integrations.py class TestPluginArgs (line 13) | class TestPluginArgs: method test_liger_plugin_args (line 18) | def test_liger_plugin_args(self, temp_dir): FILE: tests/e2e/patched/test_fa_xentropy.py class TestFAXentropyLlama (line 16) | class TestFAXentropyLlama: method test_lora_packing_fa_cross_entropy (line 25) | def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumu... FILE: tests/e2e/patched/test_falcon_samplepack.py class TestFalconPatched (line 17) | class TestFalconPatched(unittest.TestCase): method test_qlora (line 24) | def test_qlora(self, temp_dir): method test_ft (line 72) | def test_ft(self, temp_dir): FILE: tests/e2e/patched/test_flattening.py class TestFAFlattening (line 16) | class TestFAFlattening: method test_lora_packing_flattening (line 25) | def test_lora_packing_flattening(self, temp_dir, gradient_accumulation... FILE: tests/e2e/patched/test_fsdp2_qlora.py class TestFSDPPatchIntegration (line 7) | class TestFSDPPatchIntegration: method test_fsdp2_init_patches (line 11) | def test_fsdp2_init_patches(self): FILE: tests/e2e/patched/test_fused_llama.py class TestFusedLlama (line 19) | class TestFusedLlama(unittest.TestCase): method test_fft_packing (line 25) | def test_fft_packing(self, temp_dir): FILE: tests/e2e/patched/test_llama_s2_attention.py class TestLlamaShiftedSparseAttention (line 18) | class TestLlamaShiftedSparseAttention(unittest.TestCase): method test_lora_s2_attn (line 24) | def test_lora_s2_attn(self, temp_dir): method test_fft_s2_attn (line 72) | def test_fft_s2_attn(self, temp_dir): FILE: tests/e2e/patched/test_lora_llama_multipack.py class TestLoraLlama (line 17) | class TestLoraLlama(unittest.TestCase): method test_lora_packing (line 23) | def test_lora_packing(self, temp_dir): FILE: tests/e2e/patched/test_mistral_samplepack.py class TestMistral (line 15) | class TestMistral(unittest.TestCase): method test_lora_packing (line 22) | def test_lora_packing(self, temp_dir): method test_ft_packing (line 69) | def test_ft_packing(self, temp_dir): FILE: tests/e2e/patched/test_mixtral_samplepack.py class TestMixtral (line 15) | class TestMixtral(unittest.TestCase): method test_qlora (line 21) | def test_qlora(self, temp_dir): method test_ft (line 65) | def test_ft(self, temp_dir): FILE: tests/e2e/patched/test_model_patches.py class TestModelPatches (line 16) | class TestModelPatches(unittest.TestCase): method test_mixtral_multipack (line 22) | def test_mixtral_multipack(self, temp_dir): method test_mistral_multipack (line 57) | def test_mistral_multipack(self, temp_dir): FILE: tests/e2e/patched/test_peft_embeddings.py class TestLlamaPeftEmbeddings (line 12) | class TestLlamaPeftEmbeddings: method test_peft_embeddings_upcast (line 17) | def test_peft_embeddings_upcast(self, temp_dir): FILE: tests/e2e/patched/test_phi_multipack.py class TestPhiMultipack (line 15) | class TestPhiMultipack(unittest.TestCase): method test_ft_packed (line 21) | def test_ft_packed(self, temp_dir): method test_qlora_packed (line 68) | def test_qlora_packed(self, temp_dir): FILE: tests/e2e/patched/test_resume.py class TestResumeLlama (line 20) | class TestResumeLlama: method test_resume_lora_packed (line 26) | def test_resume_lora_packed(self, temp_dir): FILE: tests/e2e/patched/test_unsloth_integration.py class TestUnslothIntegration (line 11) | class TestUnslothIntegration(unittest.TestCase): method test_is_self_attn_patchable (line 14) | def test_is_self_attn_patchable(self): FILE: tests/e2e/patched/test_unsloth_qlora.py class TestUnslothQLoRA (line 18) | class TestUnslothQLoRA: method test_unsloth_llama_qlora_fa2 (line 27) | def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing): method test_unsloth_llama_qlora_unpacked (line 79) | def test_unsloth_llama_qlora_unpacked(self, temp_dir): method test_unsloth_llama_qlora_unpacked_no_fa2_fp16 (line 134) | def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_... FILE: tests/e2e/solo/test_flex.py class TestPackedFlex (line 17) | class TestPackedFlex(unittest.TestCase): method test_loss_llama (line 24) | def test_loss_llama(self, temp_dir): FILE: tests/e2e/solo/test_relora_llama.py class TestReLoraLlama (line 16) | class TestReLoraLlama(unittest.TestCase): method test_relora (line 22) | def test_relora(self, temp_dir): FILE: tests/e2e/test_activation_offloading.py class TestActivationOffloading (line 15) | class TestActivationOffloading: method test_activation_offloading (line 24) | def test_activation_offloading( FILE: tests/e2e/test_deepseekv3.py class TestDeepseekV3 (line 17) | class TestDeepseekV3: method test_lora_deepseekv3 (line 27) | def test_lora_deepseekv3(self, temp_dir, sample_packing): method test_fft_deepseekv3 (line 83) | def test_fft_deepseekv3(self, temp_dir, sample_packing): FILE: tests/e2e/test_diffusion.py class TestDiffusion (line 11) | class TestDiffusion: method test_diffusion_smoke_test (line 14) | def test_diffusion_smoke_test(self, temp_dir): method test_diffusion_sft_labels (line 72) | def test_diffusion_sft_labels(self, temp_dir): FILE: tests/e2e/test_dpo.py class TestDPOLlamaLora (line 17) | class TestDPOLlamaLora(unittest.TestCase): method test_dpo_lora (line 23) | def test_dpo_lora(self, temp_dir): method test_dpo_nll_lora (line 71) | def test_dpo_nll_lora(self, temp_dir): method test_dpo_use_weighting (line 120) | def test_dpo_use_weighting(self, temp_dir): method test_kto_pair_lora (line 170) | def test_kto_pair_lora(self, temp_dir): method test_ipo_lora (line 218) | def test_ipo_lora(self, temp_dir): method test_orpo_lora (line 266) | def test_orpo_lora(self, temp_dir): method test_kto_lora (line 318) | def test_kto_lora(self, temp_dir): FILE: tests/e2e/test_embeddings_lr.py class TestEmbeddingsLrScale (line 15) | class TestEmbeddingsLrScale(unittest.TestCase): method test_train_w_embedding_lr_scale (line 21) | def test_train_w_embedding_lr_scale(self, temp_dir): method test_train_w_embedding_lr (line 65) | def test_train_w_embedding_lr(self, temp_dir): FILE: tests/e2e/test_evaluate.py class TestE2eEvaluate (line 12) | class TestE2eEvaluate: method test_evaluate (line 15) | def test_evaluate(self, temp_dir): FILE: tests/e2e/test_falcon.py class TestFalcon (line 17) | class TestFalcon(unittest.TestCase): method test_lora (line 24) | def test_lora(self, temp_dir): method test_lora_added_vocab (line 75) | def test_lora_added_vocab(self, temp_dir): method test_ft (line 130) | def test_ft(self, temp_dir): FILE: tests/e2e/test_gemma2.py class TestGemma2 (line 15) | class TestGemma2: method test_lora_gemma2 (line 24) | def test_lora_gemma2(self, temp_dir, sample_packing): method test_fft_gemma2 (line 78) | def test_fft_gemma2(self, temp_dir, sample_packing): FILE: tests/e2e/test_gemma3_text.py class TestGemma3Text (line 15) | class TestGemma3Text: method test_lora_gemma3_text (line 24) | def test_lora_gemma3_text(self, temp_dir, sample_packing): method test_fft_gemma3_text (line 78) | def test_fft_gemma3_text(self, temp_dir, sample_packing): FILE: tests/e2e/test_imports.py class TestImports (line 8) | class TestImports(unittest.TestCase): method test_import_causal_trainer (line 13) | def test_import_causal_trainer(self): method test_import_rl_trainer (line 16) | def test_import_rl_trainer(self): FILE: tests/e2e/test_llama.py class TestLlama (line 15) | class TestLlama: method test_fft_trust_remote_code (line 20) | def test_fft_trust_remote_code(self, temp_dir): method test_fix_untrained_tokens (line 59) | def test_fix_untrained_tokens(self, temp_dir): method test_fix_untrained_tokens_already_trained (line 105) | def test_fix_untrained_tokens_already_trained(self, temp_dir): method test_batch_flattening (line 149) | def test_batch_flattening(self, tf32, temp_dir): FILE: tests/e2e/test_llama_pretrain.py class TestPretrainLlama (line 13) | class TestPretrainLlama: method test_pretrain (line 24) | def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_a... FILE: tests/e2e/test_llama_vision.py class TestLlamaVision (line 15) | class TestLlamaVision(unittest.TestCase): method test_lora_llama_vision_text_only_dataset (line 21) | def test_lora_llama_vision_text_only_dataset(self, temp_dir): method test_lora_llama_vision_multimodal_dataset (line 67) | def test_lora_llama_vision_multimodal_dataset(self, temp_dir): FILE: tests/e2e/test_load_model.py function fixture_temp_dir (line 14) | def fixture_temp_dir(): class TestLoadModelUtils (line 20) | class TestLoadModelUtils: method setup_method (line 25) | def setup_method(self): method test_convert_embedding_modules_dtype (line 71) | def test_convert_embedding_modules_dtype( FILE: tests/e2e/test_lora_llama.py class TestLoraLlama (line 15) | class TestLoraLlama(unittest.TestCase): method test_lora (line 21) | def test_lora(self, temp_dir): FILE: tests/e2e/test_mamba.py class TestMamba (line 18) | class TestMamba(unittest.TestCase): method test_fft (line 24) | def test_fft(self, temp_dir): FILE: tests/e2e/test_mistral.py class TestMistral (line 17) | class TestMistral(unittest.TestCase): method test_lora (line 23) | def test_lora(self, temp_dir): method test_ft (line 69) | def test_ft(self, temp_dir): FILE: tests/e2e/test_mixtral.py class TestMixtral (line 18) | class TestMixtral(unittest.TestCase): method test_qlora_w_fa2 (line 24) | def test_qlora_w_fa2(self, temp_dir): method test_qlora_wo_fa2 (line 79) | def test_qlora_wo_fa2(self, temp_dir): method test_16bit_lora_w_fa2 (line 134) | def test_16bit_lora_w_fa2(self, temp_dir): method test_16bit_lora_wo_fa2 (line 192) | def test_16bit_lora_wo_fa2(self, temp_dir): method test_ft (line 250) | def test_ft(self, temp_dir): FILE: tests/e2e/test_optimizers.py class TestCustomOptimizers (line 23) | class TestCustomOptimizers(unittest.TestCase): method test_optimi_adamw (line 29) | def test_optimi_adamw(self, temp_dir): method test_adopt_adamw (line 74) | def test_adopt_adamw(self, temp_dir): method test_muon (line 119) | def test_muon(self, temp_dir): method test_dion (line 165) | def test_dion(self, temp_dir): method test_fft_schedule_free_adamw (line 206) | def test_fft_schedule_free_adamw(self, temp_dir): method test_came_pytorch (line 243) | def test_came_pytorch(self, temp_dir): function test_flash_optimizers (line 300) | def test_flash_optimizers(tmp_path, optimizer_name, expected_class, lear... FILE: tests/e2e/test_packing_loss.py class TestPackedLlama (line 17) | class TestPackedLlama(unittest.TestCase): method test_loss_packed (line 23) | def test_loss_packed(self, temp_dir): FILE: tests/e2e/test_phi.py class TestPhi (line 15) | class TestPhi(unittest.TestCase): method test_phi_ft (line 21) | def test_phi_ft(self, temp_dir): method test_phi_qlora (line 66) | def test_phi_qlora(self, temp_dir): FILE: tests/e2e/test_preprocess.py class TestPreprocess (line 13) | class TestPreprocess: method test_w_deepspeed (line 16) | def test_w_deepspeed(self, temp_dir): FILE: tests/e2e/test_process_reward_model_smollm2.py class TestProcessRewardSmolLM2 (line 15) | class TestProcessRewardSmolLM2(unittest.TestCase): method test_prm (line 21) | def test_prm(self, temp_dir): FILE: tests/e2e/test_profiler.py function fixture_profiler_base_cfg (line 16) | def fixture_profiler_base_cfg(): class TestProfiler (line 45) | class TestProfiler: method test_profiler_saves (line 50) | def test_profiler_saves(self, profiler_base_cfg, temp_dir): method test_profiler_saves_w_start (line 64) | def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir): method test_profiler_saves_past_end (line 83) | def test_profiler_saves_past_end( method test_profiler_never_started (line 100) | def test_profiler_never_started(self, profiler_base_cfg, temp_dir): FILE: tests/e2e/test_qat.py class TestQATLlama (line 17) | class TestQATLlama: method test_qat (line 22) | def test_qat(self, temp_dir): method test_qat_dpo (line 70) | def test_qat_dpo(self, temp_dir): class TestMXFP4Schema (line 137) | class TestMXFP4Schema: method test_validate_mxfp4_dtype (line 140) | def test_validate_mxfp4_dtype(self): method test_qat_config_with_mxfp4 (line 144) | def test_qat_config_with_mxfp4(self): method test_qat_config_mxfp4_invalid_group_size (line 154) | def test_qat_config_mxfp4_invalid_group_size(self): FILE: tests/e2e/test_quantization.py function model (line 39) | def model(): class TestQuantization (line 104) | class TestQuantization: method test_get_ptq_config (line 115) | def test_get_ptq_config( method test_get_ptq_config_mxfp4 (line 123) | def test_get_ptq_config_mxfp4(self): method test_get_ptq_config_mxfp4_invalid_group_size (line 130) | def test_get_ptq_config_mxfp4_invalid_group_size(self): method test_get_ptq_config_int4_weight_only (line 138) | def test_get_ptq_config_int4_weight_only(self): method test_quantize_model_for_ptq (line 150) | def test_quantize_model_for_ptq( method test_quantize_model_for_ptq_fp8 (line 183) | def test_quantize_model_for_ptq_fp8( method test_quantize_model_for_ptq_nvfp4 (line 207) | def test_quantize_model_for_ptq_nvfp4( method test_prepare_model_for_qat (line 242) | def test_prepare_model_for_qat( method test_prepare_model_for_qat_mxfp4 (line 290) | def test_prepare_model_for_qat_mxfp4( method test_convert_qat_model (line 312) | def test_convert_qat_model(self, model): class TestQuantizationCallback (line 346) | class TestQuantizationCallback: method trainer_state (line 352) | def trainer_state(self): method test_qat_callback_fake_quant_after_n_steps (line 358) | def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_st... method test_qat_callback_fake_quant_after_n_steps_is_none (line 408) | def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, tr... FILE: tests/e2e/test_qwen.py class TestE2eQwen (line 15) | class TestE2eQwen: method test_dpo (line 21) | def test_dpo(self, base_model, temp_dir): FILE: tests/e2e/test_reward_model_smollm2.py class TestRewardModelLoraSmolLM2 (line 15) | class TestRewardModelLoraSmolLM2(unittest.TestCase): method test_rm_lora (line 21) | def test_rm_lora(self, temp_dir): FILE: tests/e2e/test_save_first_step.py class TestSaveFirstStepCallback (line 18) | class TestSaveFirstStepCallback(unittest.TestCase): method test_save_first_step (line 22) | def test_save_first_step(self, temp_dir): method test_no_save_first_step (line 61) | def test_no_save_first_step(self, temp_dir): FILE: tests/e2e/test_schedulers.py class TestCustomSchedulers (line 15) | class TestCustomSchedulers(unittest.TestCase): method test_rex_scheduler (line 21) | def test_rex_scheduler(self, temp_dir): FILE: tests/e2e/test_streaming.py class TestStreamingDatasets (line 15) | class TestStreamingDatasets: method test_streaming_dataset (line 22) | def test_streaming_dataset(self, temp_dir, sample_packing): FILE: tests/e2e/test_tokenizer.py function test_tokenizer_no_save_jinja_files (line 15) | def test_tokenizer_no_save_jinja_files(temp_dir): FILE: tests/e2e/utils.py function with_temp_dir (line 20) | def with_temp_dir(test_func): function most_recent_subdir (line 35) | def most_recent_subdir(path): function require_torch_2_4_1 (line 45) | def require_torch_2_4_1(test_case): function require_torch_2_5_1 (line 57) | def require_torch_2_5_1(test_case): function require_torch_2_6_0 (line 69) | def require_torch_2_6_0(test_case): function require_torch_2_7_0 (line 81) | def require_torch_2_7_0(test_case): function require_torch_2_8_0 (line 93) | def require_torch_2_8_0(test_case): function require_torch_lt_2_6_0 (line 105) | def require_torch_lt_2_6_0(test_case): function require_vllm (line 117) | def require_vllm(test_case): function require_llmcompressor (line 130) | def require_llmcompressor(test_case): function requires_sm_ge_100 (line 143) | def requires_sm_ge_100(test_case): function requires_cuda_ge_8_9 (line 152) | def requires_cuda_ge_8_9(test_case): function is_hopper (line 161) | def is_hopper(): function require_hopper (line 166) | def require_hopper(test_case): function supports_fp8 (line 170) | def supports_fp8(test_case): function check_tensorboard (line 177) | def check_tensorboard( function check_model_output_exists (line 202) | def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: FILE: tests/hf_offline_utils.py function reload_modules (line 10) | def reload_modules(hf_hub_offline): function enable_hf_offline (line 25) | def enable_hf_offline(test_func): function disable_hf_offline (line 56) | def disable_hf_offline(test_func): function hf_offline_context (line 88) | def hf_offline_context(hf_hub_offline): FILE: tests/integrations/test_diffusion.py function mock_tokenizer (line 16) | def mock_tokenizer(): function diffusion_config (line 26) | def diffusion_config(): function diffusion_trainer_instance (line 41) | def diffusion_trainer_instance(mock_tokenizer, diffusion_config): class TestDiffusionTrainer (line 52) | class TestDiffusionTrainer: method test_forward_process_basic (line 55) | def test_forward_process_basic(self, diffusion_trainer_instance): method test_forward_process_with_labels (line 78) | def test_forward_process_with_labels(self, diffusion_trainer_instance): method test_forward_process_with_attention_mask (line 106) | def test_forward_process_with_attention_mask(self, diffusion_trainer_i... method test_bidirectional_attention_mask_no_packing (line 120) | def test_bidirectional_attention_mask_no_packing(self, diffusion_train... method test_bidirectional_attention_mask_with_packing (line 131) | def test_bidirectional_attention_mask_with_packing( method test_compute_loss_basic (line 152) | def test_compute_loss_basic(self, diffusion_trainer_instance): method test_compute_loss_sft (line 177) | def test_compute_loss_sft(self, diffusion_trainer_instance): method test_compute_loss_no_masked_tokens (line 205) | def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): method test_cache_special_token_ids (line 227) | def test_cache_special_token_ids(self, mock_tokenizer): method test_cache_special_token_ids_no_tokenizer (line 234) | def test_cache_special_token_ids_no_tokenizer(self): method test_main_compute_loss_interface (line 242) | def test_main_compute_loss_interface(self, diffusion_trainer_instance): method test_missing_input_ids_raises_error (line 268) | def test_missing_input_ids_raises_error(self, diffusion_trainer_instan... FILE: tests/integrations/test_diffusion_callback.py class DummyTrainer (line 11) | class DummyTrainer: method __init__ (line 14) | def __init__(self, use_eval: bool): method get_train_dataloader (line 43) | def get_train_dataloader(self): method get_eval_dataloader (line 47) | def get_eval_dataloader(self): function test_callback_uses_correct_dataloader (line 53) | def test_callback_uses_correct_dataloader(monkeypatch, use_eval): FILE: tests/integrations/test_kd_chat_template.py class TestChatTemplateStrategyWithKDv2 (line 12) | class TestChatTemplateStrategyWithKDv2: method v2_strategy (line 16) | def v2_strategy(self): method test_v2_prepare_kd_fields_adds_target_token_ids (line 43) | def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy): method test_v2_prepare_kd_fields_handles_missing_field (line 58) | def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy): method test_v2_transform_requires_target_token_ids (line 67) | def test_v2_transform_requires_target_token_ids(self, v2_strategy): FILE: tests/integrations/test_liger.py function fixture_cfg (line 14) | def fixture_cfg(): class TestValidation (line 32) | class TestValidation: method inject_fixtures (line 40) | def inject_fixtures(self, caplog): method test_deprecated_swiglu (line 44) | def test_deprecated_swiglu(self, minimal_liger_cfg): method test_conflict_swiglu_ligergluactivation (line 63) | def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg): method test_use_token_scaling_require_flce (line 79) | def test_use_token_scaling_require_flce(self, minimal_liger_cfg): FILE: tests/integrations/test_routing_parity.py function _require_triton (line 21) | def _require_triton(): function _make_softmax_block (line 30) | def _make_softmax_block(T=8, H=16, E=4, K=2): function _make_sigmoid_block (line 43) | def _make_sigmoid_block( function hidden_states (line 75) | def hidden_states(T, H): class TestSoftmaxRoutingParity (line 84) | class TestSoftmaxRoutingParity: method _require (line 88) | def _require(self): method test_weights_match (line 91) | def test_weights_match(self): method test_logits_not_returned_by_scattermoe (line 121) | def test_logits_not_returned_by_scattermoe(self): method test_no_renorm (line 129) | def test_no_renorm(self): method test_various_expert_counts (line 150) | def test_various_expert_counts(self): class TestSigmoidRoutingParity (line 181) | class TestSigmoidRoutingParity: method _require (line 185) | def _require(self): method test_weights_match_with_groups (line 188) | def test_weights_match_with_groups(self): method test_weights_match_no_groups (line 224) | def test_weights_match_no_groups(self): method test_bias_on_block_parity (line 252) | def test_bias_on_block_parity(self): method test_scaling_factor_parity (line 279) | def test_scaling_factor_parity(self): method test_no_renorm_parity (line 307) | def test_no_renorm_parity(self): class TestSharedExpertParity (line 341) | class TestSharedExpertParity: method _require (line 345) | def _require(self): method _get_both_fns (line 348) | def _get_both_fns(self): method test_shared_expert_singular (line 358) | def test_shared_expert_singular(self): method test_shared_experts_plural (line 366) | def test_shared_experts_plural(self): method test_shared_mlp (line 374) | def test_shared_mlp(self): method test_no_shared_expert (line 382) | def test_no_shared_expert(self): method test_shared_expert_gate_only_in_scattermoe (line 390) | def test_shared_expert_gate_only_in_scattermoe(self): class TestRouteDispatcherParity (line 425) | class TestRouteDispatcherParity: method _require (line 429) | def _require(self): method test_route_dispatches_softmax (line 432) | def test_route_dispatches_softmax(self): method test_route_dispatches_sigmoid (line 453) | def test_route_dispatches_sigmoid(self): FILE: tests/integrations/test_scattermoe_autotune_telemetry.py function _make_mock_config (line 24) | def _make_mock_config(kwargs, num_warps=4, num_stages=3): function _make_mock_kernel (line 29) | def _make_mock_kernel(cache=None): function _make_mock_lora_ops (line 36) | def _make_mock_lora_ops( function _real_lora_ops_module_names (line 49) | def _real_lora_ops_module_names(): class TestAutotuneCollector (line 70) | class TestAutotuneCollector: method test_empty_cache_returns_empty_list (line 78) | def test_empty_cache_returns_empty_list(self): method test_populated_cache_returns_configs (line 90) | def test_populated_cache_returns_configs(self): method test_multiple_kernels_and_keys (line 113) | def test_multiple_kernels_and_keys(self): method test_extra_key_elements_stored (line 135) | def test_extra_key_elements_stored(self): method test_no_module_in_sys_modules_returns_empty (line 156) | def test_no_module_in_sys_modules_returns_empty(self): method test_finds_module_under_hash_suffixed_name (line 166) | def test_finds_module_under_hash_suffixed_name(self): class TestAutotuneReportCallback (line 195) | class TestAutotuneReportCallback: method test_reports_once_on_first_step (line 198) | def test_reports_once_on_first_step(self): method test_retries_until_step_5_then_gives_up (line 232) | def test_retries_until_step_5_then_gives_up(self): method test_reports_on_retry_when_data_arrives (line 251) | def test_reports_on_retry_when_data_arrives(self): method test_includes_gpu_info (line 292) | def test_includes_gpu_info(self): method test_skips_send_when_telemetry_disabled (line 337) | def test_skips_send_when_telemetry_disabled(self): class TestKernelsPluginCallbackRegistration (line 369) | class TestKernelsPluginCallbackRegistration: method test_scattermoe_registers_callback (line 372) | def test_scattermoe_registers_callback(self): method test_no_scattermoe_no_callback (line 388) | def test_no_scattermoe_no_callback(self): FILE: tests/integrations/test_scattermoe_lora.py class TestKernelsArgsValidator (line 29) | class TestKernelsArgsValidator: method test_disables_lora_mlp_kernel_when_scattermoe (line 36) | def test_disables_lora_mlp_kernel_when_scattermoe(self): method test_mlp_kernel_disabled_without_lora (line 49) | def test_mlp_kernel_disabled_without_lora(self): method test_lora_mlp_kernel_false_unchanged (line 62) | def test_lora_mlp_kernel_false_unchanged(self): method test_no_change_when_scattermoe_disabled (line 74) | def test_no_change_when_scattermoe_disabled(self): class TestParallelExpertsScaling (line 87) | class TestParallelExpertsScaling: method test_scaling_zero_preserved (line 90) | def test_scaling_zero_preserved(self): method test_scaling_none_defaults_to_one (line 125) | def test_scaling_none_defaults_to_one(self): method test_scaling_positive_preserved (line 154) | def test_scaling_positive_preserved(self): class TestSingle2ScatterBounds (line 192) | class TestSingle2ScatterBounds: method test_non_aligned_k (line 195) | def test_non_aligned_k(self): method test_non_aligned_n (line 215) | def test_non_aligned_n(self): method test_non_aligned_both (line 234) | def test_non_aligned_both(self): class TestGroupCoeffNone (line 260) | class TestGroupCoeffNone: method test_group_with_none_coeff (line 263) | def test_group_with_none_coeff(self): method test_group_with_coeff (line 275) | def test_group_with_coeff(self): class TestLayerReturnValues (line 293) | class TestLayerReturnValues: method test_hf_scatter_moe_returns_single_tensor (line 296) | def test_hf_scatter_moe_returns_single_tensor(self): method test_scatter_moe_gated_mlp_docstring_no_router_logits (line 312) | def test_scatter_moe_gated_mlp_docstring_no_router_logits(self): function _make_softmax_gate (line 333) | def _make_softmax_gate(E=4, H=16, K=2): function _make_sigmoid_gate_with_bias (line 343) | def _make_sigmoid_gate_with_bias(E=16, H=16): function _make_sigmoid_moe_block (line 351) | def _make_sigmoid_moe_block( function _skip_without_triton (line 383) | def _skip_without_triton(): class TestSigmoidRoutingInScatterMoE (line 387) | class TestSigmoidRoutingInScatterMoE: method _require_triton (line 391) | def _require_triton(self): method test_output_shapes (line 394) | def test_output_shapes(self): method test_weights_nonnegative (line 412) | def test_weights_nonnegative(self): method test_group_selection_restricts_experts (line 426) | def test_group_selection_restricts_experts(self): method test_scaling_factor_applied (line 448) | def test_scaling_factor_applied(self): method test_bias_on_gate (line 468) | def test_bias_on_gate(self): method test_bias_on_block (line 483) | def test_bias_on_block(self): method test_gate_lora_delta_applied (line 498) | def test_gate_lora_delta_applied(self): method test_no_bias_does_not_crash (line 520) | def test_no_bias_does_not_crash(self): method test_missing_topk_group_defaults_to_n_group (line 546) | def test_missing_topk_group_defaults_to_n_group(self): class TestRoutingStrategyDetection (line 576) | class TestRoutingStrategyDetection: method _require_triton (line 580) | def _require_triton(self): method test_softmax_for_qwen_style (line 583) | def test_softmax_for_qwen_style(self): method test_sigmoid_for_glm_style (line 602) | def test_sigmoid_for_glm_style(self): method test_sigmoid_for_minimax_m2_style (line 618) | def test_sigmoid_for_minimax_m2_style(self): class TestGenericSharedExpert (line 639) | class TestGenericSharedExpert: method _require_triton (line 643) | def _require_triton(self): method test_shared_expert_singular (line 646) | def test_shared_expert_singular(self): method test_shared_experts_plural (line 659) | def test_shared_experts_plural(self): method test_shared_mlp (line 672) | def test_shared_mlp(self): method test_shared_expert_with_gate (line 685) | def test_shared_expert_with_gate(self): method test_no_shared_expert (line 703) | def test_no_shared_expert(self): FILE: tests/integrations/test_scattermoe_lora_kernels.py function _requires_cuda (line 36) | def _requires_cuda(): function _setup (line 48) | def _setup(E, K, N, T, top_k, R, seed=42): function _reference_fwd (line 61) | def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E): function _reference_dX (line 82) | def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E): function _reference_bwd_lora (line 102) | def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling): class TestScatter2ScatterLoRAForward (line 142) | class TestScatter2ScatterLoRAForward: method config (line 146) | def config(self, request): method test_matches_reference (line 149) | def test_matches_reference(self, config): method test_output_shape (line 168) | def test_output_shape(self, config): class TestScatter2ScatterLoRADX (line 189) | class TestScatter2ScatterLoRADX: method config (line 193) | def config(self, request): method test_matches_reference (line 196) | def test_matches_reference(self, config): class TestGroupBwdLoRA (line 223) | class TestGroupBwdLoRA: method config (line 227) | def config(self, request): method test_matches_reference (line 230) | def test_matches_reference(self, config): method test_zero_expert_tokens (line 260) | def test_zero_expert_tokens(self): class TestScatterMoELoRAAutograd (line 294) | class TestScatterMoELoRAAutograd: method config (line 298) | def config(self, request): method test_gradients_exist_and_finite (line 301) | def test_gradients_exist_and_finite(self, config): method test_split_matches_fused (line 337) | def test_split_matches_fused(self): method test_scaling_zero_gives_base_only (line 377) | def test_scaling_zero_gives_base_only(self): FILE: tests/integrations/test_sonicmoe.py class TestKernelsArgs (line 20) | class TestKernelsArgs: method test_mutual_exclusivity_raises (line 21) | def test_mutual_exclusivity_raises(self): method test_sonicmoe_only (line 25) | def test_sonicmoe_only(self): method test_scattermoe_only (line 30) | def test_scattermoe_only(self): method test_neither_set (line 35) | def test_neither_set(self): method test_disables_mlp_kernel_when_sonicmoe (line 40) | def test_disables_mlp_kernel_when_sonicmoe(self): class TestConcatenatedToInterleaved (line 47) | class TestConcatenatedToInterleaved: method sample_tensor (line 49) | def sample_tensor(self): method test_interleave_rows_alternate (line 56) | def test_interleave_rows_alternate(self, sample_tensor): method test_interleave_handles_list_input (line 74) | def test_interleave_handles_list_input(self, sample_tensor): method test_reverse_op_type (line 83) | def test_reverse_op_type(self): class TestInterleavedToConcatenated (line 89) | class TestInterleavedToConcatenated: method interleaved_tensor (line 91) | def interleaved_tensor(self): method test_deinterleave_gate_up_separated (line 101) | def test_deinterleave_gate_up_separated(self, interleaved_tensor): method test_reverse_op_type (line 118) | def test_reverse_op_type(self): class TestRoundTrip (line 124) | class TestRoundTrip: method concat_tensor (line 126) | def concat_tensor(self): method test_interleave_then_deinterleave_is_identity (line 132) | def test_interleave_then_deinterleave_is_identity(self, concat_tensor): method test_reverse_op_chain_is_identity (line 145) | def test_reverse_op_chain_is_identity(self, concat_tensor): method test_various_shapes (line 159) | def test_various_shapes(self): class TestWeightConverterRegistration (line 177) | class TestWeightConverterRegistration: method test_register_appends_interleave_op (line 178) | def test_register_appends_interleave_op(self): method test_double_registration_is_idempotent (line 196) | def test_double_registration_is_idempotent(self): method test_register_unsupported_model_type_warns (line 215) | def test_register_unsupported_model_type_warns(self): function _make_qwen_moe_block (line 220) | def _make_qwen_moe_block(T=8, H=16, E=4, K=2): function _make_glm_moe_block (line 231) | def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1): function _make_minimax_m2_moe_block (line 249) | def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4): class TestSoftmaxTopkRouting (line 270) | class TestSoftmaxTopkRouting: method test_output_shapes (line 271) | def test_output_shapes(self): method test_scores_are_float32 (line 282) | def test_scores_are_float32(self): method test_token_indices_sorted_ascending (line 289) | def test_token_indices_sorted_ascending(self): method test_expert_indices_in_range (line 299) | def test_expert_indices_in_range(self): method test_renormalized_scores_sum_to_one (line 308) | def test_renormalized_scores_sum_to_one(self): class TestSigmoidTopkRouting (line 317) | class TestSigmoidTopkRouting: method test_output_shapes (line 318) | def test_output_shapes(self): method test_scores_are_float32 (line 329) | def test_scores_are_float32(self): method test_token_indices_sorted_ascending (line 336) | def test_token_indices_sorted_ascending(self): method test_expert_indices_in_range (line 345) | def test_expert_indices_in_range(self): method test_scores_are_nonnegative (line 354) | def test_scores_are_nonnegative(self): method test_scaling_factor_applied (line 362) | def test_scaling_factor_applied(self): method test_group_selection_restricts_experts (line 375) | def test_group_selection_restricts_experts(self): class TestMiniMaxM2SigmoidRouting (line 391) | class TestMiniMaxM2SigmoidRouting: method test_output_shapes (line 394) | def test_output_shapes(self): method test_bias_on_block_not_gate (line 406) | def test_bias_on_block_not_gate(self): FILE: tests/integrations/test_sonicmoe_gradients.py function _make_softmax_moe_block (line 20) | def _make_softmax_moe_block(weight): function _make_sigmoid_moe_block (line 31) | def _make_sigmoid_moe_block(weight, bias): class TestSoftmaxTopkRoutingGradcheck (line 46) | class TestSoftmaxTopkRoutingGradcheck: method test_gradcheck_wrt_gate_weight (line 49) | def test_gradcheck_wrt_gate_weight(self): method test_gradcheck_wrt_hidden_states (line 64) | def test_gradcheck_wrt_hidden_states(self): method test_gradcheck_wrt_router_logits (line 79) | def test_gradcheck_wrt_router_logits(self): method test_no_norm_variant (line 94) | def test_no_norm_variant(self): class TestSigmoidTopkRoutingGradcheck (line 111) | class TestSigmoidTopkRoutingGradcheck: method test_gradcheck_wrt_gate_weight (line 114) | def test_gradcheck_wrt_gate_weight(self): method test_gradcheck_wrt_hidden_states (line 130) | def test_gradcheck_wrt_hidden_states(self): method test_gradcheck_wrt_bias (line 146) | def test_gradcheck_wrt_bias(self): FILE: tests/integrations/test_swanlab.py class TestSwanLabConfigValidators (line 37) | class TestSwanLabConfigValidators: method test_valid_swanlab_mode_cloud (line 40) | def test_valid_swanlab_mode_cloud(self): method test_valid_swanlab_mode_local (line 45) | def test_valid_swanlab_mode_local(self): method test_valid_swanlab_mode_offline (line 50) | def test_valid_swanlab_mode_offline(self): method test_valid_swanlab_mode_disabled (line 55) | def test_valid_swanlab_mode_disabled(self): method test_invalid_swanlab_mode (line 60) | def test_invalid_swanlab_mode(self): method test_swanlab_mode_none_allowed (line 72) | def test_swanlab_mode_none_allowed(self): method test_valid_swanlab_project (line 77) | def test_valid_swanlab_project(self): method test_swanlab_project_none_allowed (line 82) | def test_swanlab_project_none_allowed(self): method test_empty_swanlab_project_rejected (line 87) | def test_empty_swanlab_project_rejected(self): method test_whitespace_only_project_rejected (line 95) | def test_whitespace_only_project_rejected(self): method test_use_swanlab_true_requires_project (line 103) | def test_use_swanlab_true_requires_project(self): method test_use_swanlab_true_with_project_valid (line 112) | def test_use_swanlab_true_with_project_valid(self): method test_use_swanlab_false_no_project_valid (line 118) | def test_use_swanlab_false_no_project_valid(self): method test_use_swanlab_none_no_project_valid (line 124) | def test_use_swanlab_none_no_project_valid(self): class TestSwanLabPluginRegister (line 132) | class TestSwanLabPluginRegister: method test_register_without_use_swanlab (line 135) | def test_register_without_use_swanlab(self): method test_register_use_swanlab_missing_project (line 142) | def test_register_use_swanlab_missing_project(self): method test_register_use_swanlab_with_project_valid (line 155) | def test_register_use_swanlab_with_project_valid(self): method test_register_invalid_mode (line 162) | def test_register_invalid_mode(self): method test_register_valid_modes (line 179) | def test_register_valid_modes(self): method test_register_auto_enable_swanlab (line 193) | def test_register_auto_enable_swanlab(self): method test_register_cloud_mode_without_api_key_warns (line 202) | def test_register_cloud_mode_without_api_key_warns(self, caplog): class TestMultiLoggerDetection (line 222) | class TestMultiLoggerDetection: method test_single_logger_no_warning (line 225) | def test_single_logger_no_warning(self, caplog): method test_two_loggers_warning (line 237) | def test_two_loggers_warning(self, caplog): method test_three_loggers_error (line 254) | def test_three_loggers_error(self, caplog): method test_multi_logger_with_comet (line 275) | def test_multi_logger_with_comet(self, caplog): method test_multi_logger_with_comet_project (line 291) | def test_multi_logger_with_comet_project(self, caplog): class TestSwanLabPluginPreModelLoad (line 309) | class TestSwanLabPluginPreModelLoad: method test_pre_model_load_disabled (line 312) | def test_pre_model_load_disabled(self): method test_pre_model_load_import_error (line 321) | def test_pre_model_load_import_error(self): method test_pre_model_load_non_main_process_skips (line 339) | def test_pre_model_load_non_main_process_skips( method test_pre_model_load_distributed_logging (line 357) | def test_pre_model_load_distributed_logging( class TestSwanLabInitKwargs (line 381) | class TestSwanLabInitKwargs: method test_custom_branding_added_to_config (line 384) | def test_custom_branding_added_to_config(self): method test_api_key_passed_directly (line 403) | def test_api_key_passed_directly(self): method test_private_deployment_hosts_passed_directly (line 423) | def test_private_deployment_hosts_passed_directly(self): method test_full_private_deployment_init (line 447) | def test_full_private_deployment_init(self, mock_is_main_process): method test_env_vars_not_set_for_api_params (line 484) | def test_env_vars_not_set_for_api_params(self): class TestLarkNotificationIntegration (line 525) | class TestLarkNotificationIntegration: method test_lark_callback_registration_with_webhook_only (line 528) | def test_lark_callback_registration_with_webhook_only(self): method test_lark_callback_registration_with_secret (line 564) | def test_lark_callback_registration_with_secret(self): method test_lark_callback_not_registered_without_webhook (line 598) | def test_lark_callback_not_registered_without_webhook(self): method test_lark_import_error_handled_gracefully (line 621) | def test_lark_import_error_handled_gracefully(self, caplog): method test_lark_warning_for_missing_secret (line 658) | def test_lark_warning_for_missing_secret(self, caplog): class TestSwanLabPluginIntegration (line 692) | class TestSwanLabPluginIntegration: method test_full_lifecycle_valid_config (line 695) | def test_full_lifecycle_valid_config(self): method test_lifecycle_with_multi_logger_warning (line 724) | def test_lifecycle_with_multi_logger_warning(self, caplog): method test_lifecycle_invalid_config_fails_early (line 741) | def test_lifecycle_invalid_config_fails_early(self): method test_full_lifecycle_with_lark_notifications (line 754) | def test_full_lifecycle_with_lark_notifications(self): class TestCompletionLogger (line 795) | class TestCompletionLogger: method test_completion_logger_initialization (line 798) | def test_completion_logger_initialization(self): method test_add_dpo_completion (line 806) | def test_add_dpo_completion(self): method test_add_kto_completion (line 828) | def test_add_kto_completion(self): method test_add_orpo_completion (line 850) | def test_add_orpo_completion(self): method test_add_grpo_completion (line 871) | def test_add_grpo_completion(self): method test_memory_bounded_buffer (line 892) | def test_memory_bounded_buffer(self): method test_log_to_swanlab_when_not_initialized (line 913) | def test_log_to_swanlab_when_not_initialized(self): method test_log_to_swanlab_success (line 929) | def test_log_to_swanlab_success(self): method test_clear_buffer (line 957) | def test_clear_buffer(self): method test_repr (line 973) | def test_repr(self): class TestSwanLabRLHFCompletionCallback (line 992) | class TestSwanLabRLHFCompletionCallback: method test_callback_initialization (line 995) | def test_callback_initialization(self): method test_trainer_type_detection_dpo (line 1010) | def test_trainer_type_detection_dpo(self): method test_trainer_type_detection_kto (line 1029) | def test_trainer_type_detection_kto(self): method test_on_train_end_logs_completions (line 1047) | def test_on_train_end_logs_completions(self): class TestSwanLabPluginCompletionIntegration (line 1074) | class TestSwanLabPluginCompletionIntegration: method test_completion_callback_registered_for_dpo_trainer (line 1077) | def test_completion_callback_registered_for_dpo_trainer(self): method test_completion_callback_not_registered_for_non_rlhf_trainer (line 1114) | def test_completion_callback_not_registered_for_non_rlhf_trainer(self): method test_completion_callback_not_registered_when_disabled (line 1141) | def test_completion_callback_not_registered_when_disabled(self): class TestSwanLabProfiling (line 1170) | class TestSwanLabProfiling: method test_profiling_context_logs_duration (line 1173) | def test_profiling_context_logs_duration(self): method test_profiling_context_skips_when_swanlab_disabled (line 1197) | def test_profiling_context_skips_when_swanlab_disabled(self): method test_profiling_context_skips_when_swanlab_not_initialized (line 1211) | def test_profiling_context_skips_when_swanlab_not_initialized(self): method test_profiling_decorator (line 1228) | def test_profiling_decorator(self): method test_profiling_config (line 1256) | def test_profiling_config(self): method test_profiling_config_when_disabled (line 1280) | def test_profiling_config_when_disabled(self): method test_profiling_context_advanced (line 1289) | def test_profiling_context_advanced(self): method test_profiling_with_exception (line 1318) | def test_profiling_with_exception(self): FILE: tests/monkeypatch/test_llama_attn_hijack_flash.py class TestMonkeyPatchUtils (line 17) | class TestMonkeyPatchUtils(unittest.TestCase): method test_get_cu_seqlens_1d (line 22) | def test_get_cu_seqlens_1d(self): method test_get_cu_seqlens_from_pos_ids_1d (line 27) | def test_get_cu_seqlens_from_pos_ids_1d(self): method test_get_cu_seqlens_from_pos_ids_2d (line 34) | def test_get_cu_seqlens_from_pos_ids_2d(self): method test_get_max_seqlen_in_batch (line 48) | def test_get_max_seqlen_in_batch(self): method test_get_unpad_data (line 53) | def test_get_unpad_data(self): FILE: tests/monkeypatch/test_pixtral_flash_attention_patch.py class TestPixtralFlashAttentionPatchIntegration (line 7) | class TestPixtralFlashAttentionPatchIntegration: method test_pixtral_flash_attention_patch (line 11) | def test_pixtral_flash_attention_patch(self): FILE: tests/monkeypatch/test_qwen3_next_modeling_patch.py class TestQwen3NextModelingPatchIntegration (line 10) | class TestQwen3NextModelingPatchIntegration: method test_qwen3_next_decoder_layer_patch (line 14) | def test_qwen3_next_decoder_layer_patch(self): method test_qwen3_next_gateddelta_layer_patch (line 44) | def test_qwen3_next_gateddelta_layer_patch(self): method test_qwen3_next_imports_patch (line 74) | def test_qwen3_next_imports_patch(self): method test_qwen3_next_modeling_packing_patch (line 89) | def test_qwen3_next_modeling_packing_patch(self): function test_get_cu_seqlens_utility (line 100) | def test_get_cu_seqlens_utility(): FILE: tests/monkeypatch/test_trainer_accelerator_args.py class TestTrainerAcceleratorArgs (line 12) | class TestTrainerAcceleratorArgs(unittest.TestCase): method test_check_create_accelerate_code_is_patchable (line 17) | def test_check_create_accelerate_code_is_patchable(self): FILE: tests/monkeypatch/test_trainer_context_parallel_patch.py function restore_trainer_prepare_method (line 14) | def restore_trainer_prepare_method(): function test_patch_attention_guard (line 36) | def test_patch_attention_guard(restore_trainer_prepare_method): function test_patch_is_idempotent (line 58) | def test_patch_is_idempotent(restore_trainer_prepare_method): FILE: tests/monkeypatch/test_trainer_loss_calc.py class TestTrainerLossCalc (line 11) | class TestTrainerLossCalc(unittest.TestCase): method test_trainer_loss_calc_is_patchable (line 16) | def test_trainer_loss_calc_is_patchable(self): FILE: tests/monkeypatch/test_trl_vllm.py class TestSplitTensorDict (line 19) | class TestSplitTensorDict(unittest.TestCase): method setUp (line 22) | def setUp(self): method test_scalar_int_preserved (line 27) | def test_scalar_int_preserved(self): method test_scalar_float_preserved (line 34) | def test_scalar_float_preserved(self): method test_scalar_bool_preserved (line 40) | def test_scalar_bool_preserved(self): method test_none_preserved (line 46) | def test_none_preserved(self): method test_tensor_split (line 52) | def test_tensor_split(self): method test_0d_tensor_preserved (line 61) | def test_0d_tensor_preserved(self): method test_list_split (line 67) | def test_list_split(self): class TestShuffleSequenceDict (line 74) | class TestShuffleSequenceDict(unittest.TestCase): method setUp (line 77) | def setUp(self): method test_scalar_int_preserved (line 82) | def test_scalar_int_preserved(self): method test_scalar_float_preserved (line 87) | def test_scalar_float_preserved(self): method test_scalar_bool_preserved (line 92) | def test_scalar_bool_preserved(self): method test_none_preserved (line 97) | def test_none_preserved(self): method test_tensor_permuted (line 102) | def test_tensor_permuted(self): method test_list_permuted (line 111) | def test_list_permuted(self): method test_0d_tensor_preserved (line 118) | def test_0d_tensor_preserved(self): class TestExtractLogprobs (line 124) | class TestExtractLogprobs(unittest.TestCase): method setUp (line 127) | def setUp(self): method _make_output (line 132) | def _make_output(self, logprob_values): method test_nan_replaced_with_zero (line 155) | def test_nan_replaced_with_zero(self): method test_normal_values_preserved (line 163) | def test_normal_values_preserved(self): method test_none_logprobs_returns_none (line 169) | def test_none_logprobs_returns_none(self): method test_token_ids_extracted (line 183) | def test_token_ids_extracted(self): class TestPatchApplication (line 189) | class TestPatchApplication(unittest.TestCase): method test_batch_update_added_to_client (line 192) | def test_batch_update_added_to_client(self): method test_extract_logprobs_patched (line 200) | def test_extract_logprobs_patched(self): method test_utils_patched (line 211) | def test_utils_patched(self): method test_patch_idempotent (line 226) | def test_patch_idempotent(self): class TestBatchUpdateChunking (line 236) | class TestBatchUpdateChunking(unittest.TestCase): method test_no_chunk_single_batch (line 239) | def test_no_chunk_single_batch(self): method test_chunk_splits_params (line 259) | def test_chunk_splits_params(self): FILE: tests/monkeypatch/test_voxtral_modeling_patch.py class TestVoxtralModelingPatchIntegration (line 6) | class TestVoxtralModelingPatchIntegration: method test_voxtral_conditional_generation_patch (line 10) | def test_voxtral_conditional_generation_patch(self): FILE: tests/patched/test_validation.py function fixture_cfg (line 22) | def fixture_cfg(): class BaseValidation (line 39) | class BaseValidation: method inject_fixtures (line 47) | def inject_fixtures(self, caplog): class TestValidation (line 51) | class TestValidation(BaseValidation): method test_defaults (line 56) | def test_defaults(self, minimal_cfg): method test_zero3_qlora_use_reentrant_false (line 68) | def test_zero3_qlora_use_reentrant_false(self, minimal_cfg): method test_deepspeed_empty (line 87) | def test_deepspeed_empty(self, minimal_cfg): method test_deepspeed_not_set (line 101) | def test_deepspeed_not_set(self, minimal_cfg): method test_datasets_min_length (line 115) | def test_datasets_min_length(self): method test_datasets_min_length_empty (line 132) | def test_datasets_min_length_empty(self): method test_pretrain_dataset_min_length (line 147) | def test_pretrain_dataset_min_length(self): method test_valid_pretrain_dataset (line 165) | def test_valid_pretrain_dataset(self): method test_valid_sft_dataset (line 184) | def test_valid_sft_dataset(self): method test_batch_size_unused_warning (line 202) | def test_batch_size_unused_warning(self): method test_batch_size_more_params (line 222) | def test_batch_size_more_params(self): method test_lr_as_float (line 240) | def test_lr_as_float(self, minimal_cfg): method test_model_config_remap (line 254) | def test_model_config_remap(self, minimal_cfg): method test_model_type_remap (line 267) | def test_model_type_remap(self, minimal_cfg): method test_reward_model_defaults (line 280) | def test_reward_model_defaults(self, minimal_cfg): method test_process_reward_model_defaults (line 294) | def test_process_reward_model_defaults(self, minimal_cfg): method test_model_revision_remap (line 308) | def test_model_revision_remap(self, minimal_cfg): method test_qlora (line 321) | def test_qlora(self, minimal_cfg): method test_qlora_merge (line 378) | def test_qlora_merge(self, minimal_cfg): method test_hf_use_auth_token (line 425) | def test_hf_use_auth_token(self, minimal_cfg): method test_gradient_accumulations_or_batch_size (line 449) | def test_gradient_accumulations_or_batch_size(self): method test_falcon_fsdp (line 470) | def test_falcon_fsdp(self, minimal_cfg): method test_mpt_gradient_checkpointing (line 512) | def test_mpt_gradient_checkpointing(self, minimal_cfg): method test_flash_optimum (line 529) | def test_flash_optimum(self, minimal_cfg): method test_adamw_hyperparams (line 594) | def test_adamw_hyperparams(self, minimal_cfg): method test_deprecated_packing (line 656) | def test_deprecated_packing(self, minimal_cfg): method test_packing (line 671) | def test_packing(self, minimal_cfg): method test_packing_autoset (line 690) | def test_packing_autoset(self, minimal_cfg): method test_merge_lora_no_bf16_fail (line 710) | def test_merge_lora_no_bf16_fail(self, minimal_cfg): method test_no_conflict_save_strategy (line 744) | def test_no_conflict_save_strategy(self, minimal_cfg): method test_no_conflict_eval_strategy (line 820) | def test_no_conflict_eval_strategy(self, minimal_cfg): method test_eval_table_size_conflict_eval_packing (line 963) | def test_eval_table_size_conflict_eval_packing(self, minimal_cfg): method test_load_in_x_bit_without_adapter (line 1020) | def test_load_in_x_bit_without_adapter(self, minimal_cfg): method test_warmup_step_no_conflict (line 1075) | def test_warmup_step_no_conflict(self, minimal_cfg): method test_unfrozen_parameters_w_peft_layers_to_transform (line 1114) | def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_... method test_hub_model_id_save_value_warns_save_stragey_no (line 1134) | def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_c... method test_hub_model_id_save_value_warns_random_value (line 1141) | def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): method test_hub_model_id_save_value_steps (line 1150) | def test_hub_model_id_save_value_steps(self, minimal_cfg): method test_hub_model_id_save_value_epochs (line 1160) | def test_hub_model_id_save_value_epochs(self, minimal_cfg): method test_hub_model_id_save_value_none (line 1170) | def test_hub_model_id_save_value_none(self, minimal_cfg): method test_hub_model_id_save_value_no_set_save_strategy (line 1177) | def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): method test_dpo_beta_deprecation (line 1184) | def test_dpo_beta_deprecation(self, minimal_cfg): method test_eval_strategy_remap (line 1193) | def test_eval_strategy_remap(self, minimal_cfg): method test_torch_version_adopt_req (line 1211) | def test_torch_version_adopt_req(self, minimal_cfg): method test_cfg_throws_error_with_s2_attention_and_sample_packing (line 1243) | def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, m... class TestTorchCompileValidation (line 1258) | class TestTorchCompileValidation(BaseValidation): method test_torch_compile_auto (line 1263) | def test_torch_compile_auto(self, minimal_cfg): class TestSampleOptimConfigValidation (line 1298) | class TestSampleOptimConfigValidation(BaseValidation): method test_batch_flattening_auto_enables (line 1303) | def test_batch_flattening_auto_enables(self, minimal_cfg): method test_batch_flattening_auto_no_fa (line 1319) | def test_batch_flattening_auto_no_fa(self, minimal_cfg): method test_batch_flattening_auto_mbsz_1 (line 1335) | def test_batch_flattening_auto_mbsz_1(self, minimal_cfg): method test_batch_flattening_auto_packing (line 1351) | def test_batch_flattening_auto_packing(self, minimal_cfg): class TestValidationCheckModelConfig (line 1368) | class TestValidationCheckModelConfig(BaseValidation): method test_llama_add_tokens_adapter (line 1373) | def test_llama_add_tokens_adapter(self, minimal_cfg): method test_phi_add_tokens_adapter (line 1420) | def test_phi_add_tokens_adapter(self, minimal_cfg): class TestValidationWandb (line 1468) | class TestValidationWandb(BaseValidation): method test_wandb_set_run_id_to_name (line 1473) | def test_wandb_set_run_id_to_name(self, minimal_cfg): method test_wandb_sets_env (line 1506) | def test_wandb_sets_env(self, minimal_cfg): method test_wandb_set_disabled (line 1542) | def test_wandb_set_disabled(self, minimal_cfg): class TestValidationComet (line 1565) | class TestValidationComet(BaseValidation): method test_comet_sets_env (line 1570) | def test_comet_sets_env(self, minimal_cfg): class TestValidationMLflow (line 1666) | class TestValidationMLflow(BaseValidation): method test_hf_mlflow_artifacts_config_sets_env (line 1671) | def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg): method test_mlflow_not_used_by_default (line 1694) | def test_mlflow_not_used_by_default(self, minimal_cfg): class TestDataloaderValidation (line 1721) | class TestDataloaderValidation(BaseValidation): method test_dataloader_auto_defaults (line 1726) | def test_dataloader_auto_defaults(self, minimal_cfg): FILE: tests/prompt_strategies/conftest.py function fixture_assistant_dataset (line 16) | def fixture_assistant_dataset(): function fixture_sharegpt_dataset (line 32) | def fixture_sharegpt_dataset(): function fixture_basic_dataset (line 48) | def fixture_basic_dataset(): function fixture_toolcalling_dataset (line 65) | def fixture_toolcalling_dataset(): function fixture_llama3_tokenizer (line 110) | def fixture_llama3_tokenizer( function fixture_smollm2_tokenizer (line 120) | def fixture_smollm2_tokenizer(): function fixture_mistralv03_tokenizer (line 127) | def fixture_mistralv03_tokenizer( function fixture_phi35_tokenizer (line 138) | def fixture_phi35_tokenizer(): function fixture_phi4_tokenizer (line 145) | def fixture_phi4_tokenizer(): function fixture_gemma2_tokenizer (line 151) | def fixture_gemma2_tokenizer(): function fixture_magistral_tokenizer (line 158) | def fixture_magistral_tokenizer(): function fixture_devstral_tokenizer (line 166) | def fixture_devstral_tokenizer(): function fixture_devstral_1_1_tokenizer (line 174) | def fixture_devstral_1_1_tokenizer(): function qwen3_tokenizer_fixture (line 183) | def qwen3_tokenizer_fixture( function fixture_mistralv03_chat_template_jinja_w_system (line 192) | def fixture_mistralv03_chat_template_jinja_w_system() -> str: function fixture_gemma2_chat_template_jinja_w_system (line 197) | def fixture_gemma2_chat_template_jinja_w_system() -> str: function fixture_llama3_2_vision_with_hardcoded_date (line 202) | def fixture_llama3_2_vision_with_hardcoded_date() -> str: function fixture_chat_template_jinja_with_optional_fields (line 223) | def fixture_chat_template_jinja_with_optional_fields() -> str: function basic_jinja_template_analyzer (line 233) | def basic_jinja_template_analyzer(): function mistral_jinja_template_analyzer (line 247) | def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_j... FILE: tests/prompt_strategies/messages/test_chat.py class TestMessagesChatLlama3 (line 14) | class TestMessagesChatLlama3: method test_llama3_load (line 19) | def test_llama3_load(self, llama3_tokenizer, assistant_dataset): FILE: tests/prompt_strategies/test_alpaca.py function fixture_alpaca_dataset (line 18) | def fixture_alpaca_dataset(): function fixture_tokenizer (line 32) | def fixture_tokenizer(): class TestAlpacaChatml (line 52) | class TestAlpacaChatml: method test_no_double_im_end (line 57) | def test_no_double_im_end(self, alpaca_dataset, tokenizer): method test_no_train_on_input (line 79) | def test_no_train_on_input(self, alpaca_dataset, tokenizer): method test_w_train_on_input (line 101) | def test_w_train_on_input(self, alpaca_dataset, tokenizer): FILE: tests/prompt_strategies/test_chat_template_ds_schema_unification.py function fixture_messages_w_tools (line 15) | def fixture_messages_w_tools(): function qwen3_chat_template_strategy (line 26) | def qwen3_chat_template_strategy(qwen3_tokenizer): class TestSchemaUnification (line 40) | class TestSchemaUnification: method test_schema_unification_single_prompt (line 45) | def test_schema_unification_single_prompt( method test_schema_unification_batched (line 55) | def test_schema_unification_batched( FILE: tests/prompt_strategies/test_chat_template_utils.py function fixture_llama3_tokenizer (line 21) | def fixture_llama3_tokenizer(): class TestGetChatTemplateUtils (line 27) | class TestGetChatTemplateUtils: method test_known_chat_template (line 32) | def test_known_chat_template(self): method test_invalid_chat_template (line 36) | def test_invalid_chat_template(self): method test_tokenizer_default_no_tokenizer (line 41) | def test_tokenizer_default_no_tokenizer(self): method test_tokenizer_default_no_chat_template_on_tokenizer (line 45) | def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_... method test_tokenizer_default_with_chat_template_on_tokenizer (line 49) | def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama... method test_tokenizer_default_fallback_no_tokenizer (line 56) | def test_tokenizer_default_fallback_no_tokenizer(self): method test_tokenizer_default_fallback_no_chat_template_on_tokenizer (line 60) | def test_tokenizer_default_fallback_no_chat_template_on_tokenizer( method test_tokenizer_default_fallback_with_chat_template_on_tokenizer (line 68) | def test_tokenizer_default_fallback_with_chat_template_on_tokenizer( method test_jinja_template_mode (line 77) | def test_jinja_template_mode(self): method test_jinja_template_mode_no_jinja_template (line 82) | def test_jinja_template_mode_no_jinja_template(self): method test_extract_chat_template_args (line 86) | def test_extract_chat_template_args(self): FILE: tests/prompt_strategies/test_chat_templates.py class TestAssistantChatTemplateLlama3 (line 20) | class TestAssistantChatTemplateLlama3: method test_llama3_load (line 25) | def test_llama3_load(self, llama3_tokenizer, assistant_dataset): method test_llama3 (line 74) | def test_llama3(self, llama3_tokenizer, assistant_dataset): method test_phi35 (line 116) | def test_phi35(self, phi35_tokenizer, assistant_dataset): method test_llama3_with_training_data (line 174) | def test_llama3_with_training_data(self, llama3_tokenizer, assistant_d... class TestSharegptChatTemplateLlama3 (line 230) | class TestSharegptChatTemplateLlama3: method test_llama3_assistant (line 235) | def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): method test_llama3_human (line 295) | def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): method test_llama3_system_human (line 355) | def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): class TestAssistantToolCallingChatTemplateLlama32Vision (line 420) | class TestAssistantToolCallingChatTemplateLlama32Vision: method test_llama32vision_train_on_assistant (line 425) | def test_llama32vision_train_on_assistant( method test_llama32vision_train_on_tools (line 492) | def test_llama32vision_train_on_tools( FILE: tests/prompt_strategies/test_chat_templates_advanced.py class TestChatTemplateConfigurations (line 49) | class TestChatTemplateConfigurations: method setup_tokenizer (line 55) | def setup_tokenizer( method _should_skip_turn (line 88) | def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_... method test_train_on_inputs_true (line 104) | def test_train_on_inputs_true( method test_train_on_inputs_false (line 163) | def test_train_on_inputs_false( method test_roles_to_train_human_assistant_only (line 228) | def test_roles_to_train_human_assistant_only( method test_roles_to_train_all (line 293) | def test_roles_to_train_all( method test_empty_roles_to_train (line 348) | def test_empty_roles_to_train( method test_train_on_eos_all (line 388) | def test_train_on_eos_all( method test_train_on_eos_turn (line 434) | def test_train_on_eos_turn( method test_train_on_eos_last (line 513) | def test_train_on_eos_last( method test_train_on_eos_none (line 565) | def test_train_on_eos_none( method test_drop_system_message (line 611) | def test_drop_system_message( method test_custom_roles (line 651) | def test_custom_roles( method test_message_field_training (line 728) | def test_message_field_training( method test_get_chat_template_variables (line 919) | def test_get_chat_template_variables( method test_eot_tokens_conflict_with_eos_token (line 967) | def test_eot_tokens_conflict_with_eos_token( method test_eot_token_backward_compatibility (line 1014) | def test_eot_token_backward_compatibility( method test_token_not_in_template (line 1054) | def test_token_not_in_template( method test_custom_eot_tokens (line 1101) | def test_custom_eot_tokens( method test_multiple_train_on_eot_settings (line 1180) | def test_multiple_train_on_eot_settings( class TestChatTemplateToolCalling (line 1276) | class TestChatTemplateToolCalling: method test_tool_calling_with_llama4_template (line 1281) | def test_tool_calling_with_llama4_template( FILE: tests/prompt_strategies/test_chat_templates_mistral.py function test_mistral_chat_template (line 24) | def test_mistral_chat_template( function test_magistral_tokenizer_pad_method (line 312) | def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralT... function test_magistral_tool_calling (line 442) | def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): function test_magistral_tokenizer_call_method (line 755) | def test_magistral_tokenizer_call_method( FILE: tests/prompt_strategies/test_chat_templates_thinking.py function messages_w_reasoning_fixture (line 15) | def messages_w_reasoning_fixture(): class TestSplitThinking (line 58) | class TestSplitThinking: method test_splits_think (line 63) | def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer): FILE: tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py function qwen3_instruct_chat_template_strategy (line 17) | def qwen3_instruct_chat_template_strategy(qwen3_tokenizer): class TestQwen3IdenticalConversationArgs (line 47) | class TestQwen3IdenticalConversationArgs: method fixture_conversation_dict_args_dataset (line 53) | def fixture_conversation_dict_args_dataset(self): method fixture_conversation_str_args_dataset (line 83) | def fixture_conversation_str_args_dataset(self): method fixture_conversation_mixed_time_types_dataset (line 114) | def fixture_conversation_mixed_time_types_dataset(self): method test_dict_and_str_args_produce_identical_output (line 152) | def test_dict_and_str_args_produce_identical_output( method test_str_args_with_mixed_time_types_no_error (line 192) | def test_str_args_with_mixed_time_types_no_error( class TestQwen3IdenticalToolsParameters (line 217) | class TestQwen3IdenticalToolsParameters: method fixture_tools_dict_params_dataset (line 223) | def fixture_tools_dict_params_dataset(self): method fixture_tools_str_params_dataset (line 280) | def fixture_tools_str_params_dataset(self): method fixture_tools_mixed_type_params_dataset (line 333) | def fixture_tools_mixed_type_params_dataset(self): method test_dict_and_str_params_produce_equivalent_output (line 409) | def test_dict_and_str_params_produce_equivalent_output( method test_str_params_with_mixed_types_no_error (line 477) | def test_str_params_with_mixed_types_no_error( FILE: tests/prompt_strategies/test_dpo_chat_templates.py function fixture_assistant_dataset (line 18) | def fixture_assistant_dataset(): function fixture_custom_assistant_dataset (line 50) | def fixture_custom_assistant_dataset(): function fixture_argilla_chat_dataset (line 82) | def fixture_argilla_chat_dataset(): function fixture_phi3_tokenizer (line 113) | def fixture_phi3_tokenizer(): function fixture_gemma_tokenizer (line 121) | def fixture_gemma_tokenizer(): class TestAssistantDPOChatTemplateLlama3 (line 127) | class TestAssistantDPOChatTemplateLlama3: method test_llama3_defaults (line 132) | def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): method test_llama3_configured (line 156) | def test_llama3_configured(self, llama3_tokenizer, custom_assistant_da... class TestAssistantDPOChatTemplatePhi3 (line 191) | class TestAssistantDPOChatTemplatePhi3: method test_phi3_defaults (line 196) | def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): class TestAssistantDPOChatTemplateGemma (line 220) | class TestAssistantDPOChatTemplateGemma: method test_gemma_defaults (line 225) | def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): class TestArgillaChatDPOChatTemplate (line 249) | class TestArgillaChatDPOChatTemplate: method test_llama3_argilla_chat (line 254) | def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_data... method test_phi3_argilla_chat (line 276) | def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset): FILE: tests/prompt_strategies/test_dpo_chatml.py function fixture_cfg (line 18) | def fixture_cfg(): class TestDPOChatml (line 35) | class TestDPOChatml: method test_default (line 42) | def test_default(self, minimal_dpo_cfg): FILE: tests/prompt_strategies/test_jinja_template_analyzer.py class TestJinjaTemplateAnalyzer (line 13) | class TestJinjaTemplateAnalyzer: method test_basic_variable_extraction (line 18) | def test_basic_variable_extraction(self, basic_jinja_template_analyzer): method test_mixtral_variable_extraction (line 26) | def test_mixtral_variable_extraction(self, mistral_jinja_template_anal... method test_message_property_access (line 50) | def test_message_property_access(self, basic_jinja_template_analyzer): method test_detailed_analysis (line 60) | def test_detailed_analysis(self, basic_jinja_template_analyzer): method test_nested_property_access (line 76) | def test_nested_property_access(self): method test_loop_variable_handling (line 88) | def test_loop_variable_handling(self): method test_conditional_variable_usage (line 107) | def test_conditional_variable_usage(self): method test_complex_expressions (line 124) | def test_complex_expressions(self): method test_basic_msg_vars (line 142) | def test_basic_msg_vars(self, basic_jinja_template_analyzer): method test_mixtral_msg_vars (line 149) | def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer): FILE: tests/prompt_strategies/test_raw_io.py function fixture_sharegpt_dataset (line 18) | def fixture_sharegpt_dataset(): function fixture_tokenizer (line 46) | def fixture_tokenizer(): class TestRawInputOutputPrompts (line 59) | class TestRawInputOutputPrompts: method test_segment_prompts (line 64) | def test_segment_prompts(self, segments_dataset, tokenizer): FILE: tests/prompt_strategies/test_stepwise.py class TestStepWiseSupervisedPromptTokenizingStrategy (line 16) | class TestStepWiseSupervisedPromptTokenizingStrategy: method stepwise_supervised_dataset (line 22) | def stepwise_supervised_dataset(self): method tokenizer (line 38) | def tokenizer(self): method test_stepwise_supervised_dataset (line 41) | def test_stepwise_supervised_dataset(self, tokenizer, stepwise_supervi... FILE: tests/telemetry/conftest.py function del_track_env (line 7) | def del_track_env(monkeypatch): FILE: tests/telemetry/test_callbacks.py function calc_expected_metrics (line 14) | def calc_expected_metrics(step, last_step, current_time, last_time, star... function mock_time (line 28) | def mock_time(): function mock_telemetry_manager (line 36) | def mock_telemetry_manager(): function mock_runtime_metrics_tracker (line 45) | def mock_runtime_metrics_tracker(): function training_args (line 65) | def training_args(): function trainer_state (line 71) | def trainer_state(): function trainer_control (line 81) | def trainer_control(): function callback (line 88) | def callback(mock_telemetry_manager, mock_runtime_metrics_tracker): class TestTelemetryCallback (line 93) | class TestTelemetryCallback: method test_initialization (line 96) | def test_initialization(self, callback, mock_runtime_metrics_tracker): method test_on_train_begin (line 105) | def test_on_train_begin( method test_on_train_end (line 120) | def test_on_train_end( method test_on_epoch_begin (line 146) | def test_on_epoch_begin( method test_on_epoch_end (line 164) | def test_on_epoch_end( method test_on_step_end_no_report (line 180) | def test_on_step_end_no_report( method test_on_step_end_report_interval_steps (line 206) | def test_on_step_end_report_interval_steps( method test_on_step_end_report_time_elapsed (line 256) | def test_on_step_end_report_time_elapsed( method test_on_step_end_first_step (line 300) | def test_on_step_end_first_step( method test_log_history_empty (line 341) | def test_log_history_empty( FILE: tests/telemetry/test_errors.py function reset_error_flag (line 13) | def reset_error_flag(monkeypatch): function example_stack_trace (line 23) | def example_stack_trace(): function windows_stack_trace (line 37) | def windows_stack_trace(): function mixed_stack_trace (line 51) | def mixed_stack_trace(): function venv_stack_trace (line 67) | def venv_stack_trace(): function dist_packages_stack_trace (line 83) | def dist_packages_stack_trace(): function project_stack_trace (line 99) | def project_stack_trace(): function test_sanitize_stack_trace (line 112) | def test_sanitize_stack_trace(example_stack_trace): function test_sanitize_windows_paths (line 129) | def test_sanitize_windows_paths(windows_stack_trace): function test_sanitize_mixed_paths (line 155) | def test_sanitize_mixed_paths(mixed_stack_trace): function test_sanitize_venv_paths (line 169) | def test_sanitize_venv_paths(venv_stack_trace): function test_sanitize_dist_packages (line 185) | def test_sanitize_dist_packages(dist_packages_stack_trace): function test_sanitize_project_paths (line 203) | def test_sanitize_project_paths(project_stack_trace): function mock_telemetry_manager (line 220) | def mock_telemetry_manager(): function test_send_errors_successful_execution (line 229) | def test_send_errors_successful_execution(mock_telemetry_manager): function test_send_errors_with_exception (line 241) | def test_send_errors_with_exception(mock_telemetry_manager): function test_send_errors_nested_calls (line 262) | def test_send_errors_nested_calls(mock_telemetry_manager): function test_send_errors_telemetry_disable (line 282) | def test_send_errors_telemetry_disable(): function test_error_handled_reset (line 300) | def test_error_handled_reset(): function test_module_path_resolution (line 324) | def test_module_path_resolution(mock_telemetry_manager): FILE: tests/telemetry/test_manager.py function mock_whitelist (line 15) | def mock_whitelist(tmp_path): function telemetry_manager_class (line 28) | def telemetry_manager_class(): function manager (line 40) | def manager(telemetry_manager_class, mock_whitelist): function test_singleton_instance (line 55) | def test_singleton_instance(telemetry_manager_class): function test_telemetry_enabled_by_default (line 68) | def test_telemetry_enabled_by_default(telemetry_manager_class): function test_telemetry_enabled_with_explicit_opt_in (line 79) | def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class): function test_telemetry_disabled_with_axolotl_do_not_track (line 89) | def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_... function test_telemetry_disabled_with_do_not_track (line 99) | def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): function test_telemetry_disabled_for_non_main_process (line 111) | def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): function test_is_whitelisted (line 121) | def test_is_whitelisted(telemetry_manager_class, mock_whitelist): function test_system_info_collection (line 140) | def test_system_info_collection(manager): function test_send_event (line 152) | def test_send_event(telemetry_manager_class): function test_send_system_info (line 174) | def test_send_system_info(telemetry_manager_class): function test_redacted_properties (line 187) | def test_redacted_properties(telemetry_manager_class): function test_disable_telemetry (line 234) | def test_disable_telemetry(manager): function test_exception_handling_during_send (line 242) | def test_exception_handling_during_send(manager): function test_shutdown (line 257) | def test_shutdown(manager): FILE: tests/telemetry/test_runtime_metrics.py function mock_time (line 13) | def mock_time(): function mock_telemetry_manager (line 23) | def mock_telemetry_manager(): function mock_psutil (line 35) | def mock_psutil(): function mock_torch (line 48) | def mock_torch(): class TestRuntimeMetrics (line 62) | class TestRuntimeMetrics: method test_initialization (line 65) | def test_initialization(self): method test_elapsed_time (line 78) | def test_elapsed_time(self, mock_time): method test_epoch_time (line 87) | def test_epoch_time(self): method test_average_epoch_time (line 102) | def test_average_epoch_time(self): method test_steps_per_second (line 123) | def test_steps_per_second(self, mock_time): method test_to_dict_basic (line 137) | def test_to_dict_basic(self, mock_time): method test_to_dict_with_epochs (line 157) | def test_to_dict_with_epochs(self, mock_time): method test_to_dict_with_gpu_memory (line 179) | def test_to_dict_with_gpu_memory(self, mock_time): class TestRuntimeMetricsTracker (line 197) | class TestRuntimeMetricsTracker: method test_initialization (line 201) | def test_initialization(self, mock_time, mock_telemetry_manager): method test_start_epoch (line 209) | def test_start_epoch( method test_end_epoch (line 229) | def test_end_epoch(self, mock_time, mock_telemetry_manager): method test_update_step (line 245) | def test_update_step( method test_update_memory_metrics (line 270) | def test_update_memory_metrics( method test_get_memory_metrics (line 323) | def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_teleme... FILE: tests/test_chunked_xentropy.py function chunked_fixtures (line 13) | def chunked_fixtures(): function test_chunked_forward (line 25) | def test_chunked_forward(chunked_fixtures): FILE: tests/test_context_parallel_batch_size.py function fixture_cp_base_cfg (line 13) | def fixture_cp_base_cfg(min_base_cfg): class TestContextParallelBatchSize (line 26) | class TestContextParallelBatchSize: method test_batch_size_with_context_parallelism (line 38) | def test_batch_size_with_context_parallelism( FILE: tests/test_convert.py class TestJsonParser (line 17) | class TestJsonParser: method test_parse_valid_json_array (line 18) | def test_parse_valid_json_array(self): method test_parse_valid_json_object (line 23) | def test_parse_valid_json_object(self): method test_parse_invalid_json_raises (line 28) | def test_parse_invalid_json_raises(self): class TestJsonlSerializer (line 34) | class TestJsonlSerializer: method test_serialize_single_item (line 35) | def test_serialize_single_item(self): method test_serialize_multiple_items (line 40) | def test_serialize_multiple_items(self): method test_serialize_empty_list (line 48) | def test_serialize_empty_list(self): class TestFileReaderWriter (line 54) | class TestFileReaderWriter: method test_read_write_roundtrip (line 55) | def test_read_write_roundtrip(self, tmp_path): class TestStdoutWriter (line 66) | class TestStdoutWriter: method test_write_to_stdout (line 67) | def test_write_to_stdout(self, capsys): class TestJsonToJsonlConverter (line 74) | class TestJsonToJsonlConverter: method test_convert_json_to_jsonl (line 75) | def test_convert_json_to_jsonl(self, tmp_path): FILE: tests/test_data.py class TestEncodePretraining (line 15) | class TestEncodePretraining(unittest.TestCase): method setUp (line 21) | def setUp(self): method test_encode_pretraining (line 33) | def test_encode_pretraining(self): method test_md5 (line 61) | def test_md5(self): method test_excess_length_strategy (line 67) | def test_excess_length_strategy(self): FILE: tests/test_datasets.py class TestDatasetPreparation (line 29) | class TestDatasetPreparation: method tokenizer (line 33) | def tokenizer( method dataset_fixture (line 40) | def dataset_fixture(self): method test_load_hub (line 53) | def test_load_hub(self, tokenizer): method test_load_local_hub (line 82) | def test_load_local_hub(self, tokenizer): method test_load_from_save_to_disk (line 129) | def test_load_from_save_to_disk(self, tokenizer, dataset_fixture): method test_load_from_dir_of_parquet (line 161) | def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture): method test_load_from_dir_of_json (line 200) | def test_load_from_dir_of_json(self, tokenizer, dataset_fixture): method test_load_from_single_parquet (line 239) | def test_load_from_single_parquet(self, tokenizer, dataset_fixture): method test_load_from_single_json (line 272) | def test_load_from_single_json(self, tokenizer, dataset_fixture): method test_load_hub_with_dpo (line 306) | def test_load_hub_with_dpo(self): method test_load_hub_with_revision (line 330) | def test_load_hub_with_revision(self, tokenizer): method test_load_hub_with_revision_with_dpo (line 363) | def test_load_hub_with_revision_with_dpo( method test_load_local_hub_with_revision (line 398) | def test_load_local_hub_with_revision( method test_loading_local_dataset_folder (line 453) | def test_loading_local_dataset_folder(self, tokenizer): FILE: tests/test_dict.py class DictDefaultTest (line 10) | class DictDefaultTest(unittest.TestCase): method test_dict_default (line 15) | def test_dict_default(self): method test_dict_or_operator (line 40) | def test_dict_or_operator(self): method test_dict_missingkey (line 67) | def test_dict_missingkey(self): method test_dict_or (line 72) | def test_dict_or(self): method test_dict_nested_missingparentkey (line 79) | def test_dict_nested_missingparentkey(self): method test_dict_shorthand_assignment (line 91) | def test_dict_shorthand_assignment(self): FILE: tests/test_exact_deduplication.py function verify_deduplication (line 24) | def verify_deduplication(actual_dataset, expected_dataset, dataset_name): class TestDeduplicateIndividualFunctions (line 49) | class TestDeduplicateIndividualFunctions(unittest.TestCase): method setUp (line 52) | def setUp(self): method test_deduplication (line 71) | def test_deduplication(self): method test_exact_duplicates (line 80) | def test_exact_duplicates(self): method test_partial_duplicates (line 102) | def test_partial_duplicates(self): method test_combined_duplicates_empty (line 128) | def test_combined_duplicates_empty(self): method test_combined_duplicates_one (line 159) | def test_combined_duplicates_one(self): class TestDeduplicateRLDataset (line 197) | class TestDeduplicateRLDataset: method cfg (line 201) | def cfg(self): method test_load_with_deduplication (line 219) | def test_load_with_deduplication( method test_load_without_deduplication (line 247) | def test_load_without_deduplication( class TestDeduplicateNonRL (line 277) | class TestDeduplicateNonRL(unittest.TestCase): method setUp (line 281) | def setUp(self) -> None: method test_prepare_dataset_with_deduplication_train (line 310) | def test_prepare_dataset_with_deduplication_train(self): method test_prepare_dataset_with_deduplication_eval (line 337) | def test_prepare_dataset_with_deduplication_eval(self): method test_prepare_dataset_without_deduplication (line 364) | def test_prepare_dataset_without_deduplication(self): class TestWrongCollisions (line 396) | class TestWrongCollisions(unittest.TestCase): method setUp (line 399) | def setUp(self): method test_deduplication_dataset_only (line 416) | def test_deduplication_dataset_only(self): FILE: tests/test_freeze.py class TestFreezeLayersExcept (line 19) | class TestFreezeLayersExcept(unittest.TestCase): method setUp (line 24) | def setUp(self): method test_freeze_layers_with_dots_in_name (line 27) | def test_freeze_layers_with_dots_in_name(self): method test_freeze_layers_without_dots_in_name (line 38) | def test_freeze_layers_without_dots_in_name(self): method test_freeze_layers_regex_patterns (line 49) | def test_freeze_layers_regex_patterns(self): method test_all_layers_frozen (line 61) | def test_all_layers_frozen(self): method test_all_layers_unfrozen (line 72) | def test_all_layers_unfrozen(self): method test_freeze_layers_with_range_pattern_start_end (line 83) | def test_freeze_layers_with_range_pattern_start_end(self): method test_freeze_layers_with_range_pattern_single_index (line 109) | def test_freeze_layers_with_range_pattern_single_index(self): method test_freeze_layers_with_range_pattern_start_omitted (line 124) | def test_freeze_layers_with_range_pattern_start_omitted(self): method test_freeze_layers_with_range_pattern_end_omitted (line 150) | def test_freeze_layers_with_range_pattern_end_omitted(self): method test_freeze_layers_with_range_pattern_merge_included (line 176) | def test_freeze_layers_with_range_pattern_merge_included(self): method test_freeze_layers_with_range_pattern_merge_intersect (line 202) | def test_freeze_layers_with_range_pattern_merge_intersect(self): method test_freeze_layers_with_range_pattern_merge_separate (line 228) | def test_freeze_layers_with_range_pattern_merge_separate(self): method _assert_gradient_output (line 257) | def _assert_gradient_output(self, expected): class _SubLayerModule (line 271) | class _SubLayerModule(nn.Module): method __init__ (line 272) | def __init__(self): class _TestModel (line 277) | class _TestModel(nn.Module): method __init__ (line 278) | def __init__(self): FILE: tests/test_loaders.py class TestModelsUtils (line 15) | class TestModelsUtils: method setup_method (line 18) | def setup_method(self) -> None: method test_set_device_map_config (line 45) | def test_set_device_map_config(self): method test_set_quantization_config (line 64) | def test_set_quantization_config( method test_message_property_mapping (line 105) | def test_message_property_mapping(self): method test_get_parallel_config_kwargs (line 192) | def test_get_parallel_config_kwargs( FILE: tests/test_logging_config_file_capture.py function read (line 7) | def read(path: str) -> str: function _reset_logging_state (line 13) | def _reset_logging_state(): function test_axolotl_logs_captured_at_all_levels (line 25) | def test_axolotl_logs_captured_at_all_levels(monkeypatch): function test_third_party_logs_filtered_and_warning_captured (line 49) | def test_third_party_logs_filtered_and_warning_captured(monkeypatch): function test_prepare_debug_log_idempotent_and_no_duplicate (line 82) | def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch): FILE: tests/test_lora.py class TestLoRALoad (line 25) | class TestLoRALoad: method test_load_lora_weights (line 30) | def test_load_lora_weights(self): method test_load_lora_weights_empty_dropout (line 50) | def test_load_lora_weights_empty_dropout(self): FILE: tests/test_normalize_config.py class NormalizeConfigTestCase (line 16) | class NormalizeConfigTestCase(unittest.TestCase): method _get_base_cfg (line 21) | def _get_base_cfg(self): method test_base_model_config_set_when_empty (line 40) | def test_base_model_config_set_when_empty(self): method test_chat_template_chatml (line 47) | def test_chat_template_chatml(self): method test_bf16_auto_setter_available (line 71) | def test_bf16_auto_setter_available(self, mock_bf16_avail): method test_bf16_auto_setter_not_available (line 82) | def test_bf16_auto_setter_not_available(self, mock_bf16_avail): method test_bf16_disables_fp16 (line 94) | def test_bf16_disables_fp16(self, mock_bf16_avail): method test_migrate_fsdp_config (line 105) | def test_migrate_fsdp_config(self): method test_migrate_fsdp_config_no_fsdp_config (line 152) | def test_migrate_fsdp_config_no_fsdp_config(self): method test_migrate_fsdp_config_empty_fsdp_config (line 161) | def test_migrate_fsdp_config_empty_fsdp_config(self): method test_migrate_fsdp_config_mixed_keys (line 170) | def test_migrate_fsdp_config_mixed_keys(self): FILE: tests/test_opentelemetry_callback.py function mock_otel_config (line 11) | def mock_otel_config(): function mock_trainer_state (line 23) | def mock_trainer_state(): function mock_training_args (line 34) | def mock_training_args(): function mock_trainer_control (line 42) | def mock_trainer_control(): class TestOpenTelemetryConfig (line 49) | class TestOpenTelemetryConfig: method test_config_schema_valid (line 52) | def test_config_schema_valid(self): method test_config_defaults (line 68) | def test_config_defaults(self): method test_config_disabled_by_default (line 80) | def test_config_disabled_by_default(self): class TestOpenTelemetryCallback (line 89) | class TestOpenTelemetryCallback: method test_callback_import (line 92) | def test_callback_import(self): method test_callback_graceful_fallback (line 98) | def test_callback_graceful_fallback(self, mock_otel_config): method test_callback_initialization_enabled (line 109) | def test_callback_initialization_enabled(self, mock_otel_config): method test_metrics_server_lifecycle (line 126) | def test_metrics_server_lifecycle( method test_metrics_recording (line 155) | def test_metrics_recording( method test_evaluation_metrics (line 189) | def test_evaluation_metrics( method test_thread_safety (line 222) | def test_thread_safety(self, mock_otel_config): class TestOpenTelemetryIntegration (line 239) | class TestOpenTelemetryIntegration: method test_availability_check (line 242) | def test_availability_check(self): method test_prometheus_endpoint_basic (line 249) | def test_prometheus_endpoint_basic( class TestOpenTelemetryCallbackMethods (line 296) | class TestOpenTelemetryCallbackMethods: method test_step_end_callback (line 299) | def test_step_end_callback( method test_epoch_end_callback (line 325) | def test_epoch_end_callback( FILE: tests/test_packed_batch_sampler.py function fixture_tokenizer (line 19) | def fixture_tokenizer(): class TestBatchedSamplerPacking (line 25) | class TestBatchedSamplerPacking: method test_packing (line 42) | def test_packing( FILE: tests/test_packed_dataset.py class TestPacking (line 17) | class TestPacking(unittest.TestCase): method setUp (line 23) | def setUp(self) -> None: method test_lora_packing (line 34) | def test_lora_packing(self, temp_dir): FILE: tests/test_packed_pretraining.py class TestPretrainingPacking (line 16) | class TestPretrainingPacking: method random_text (line 22) | def random_text(self): method test_packing_stream_dataset (line 50) | def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text): FILE: tests/test_perplexity.py function metric (line 13) | def metric(tokenizer): function model (line 18) | def model(): function tokenizer (line 25) | def tokenizer(): function test_perplexity_longer_than_stride (line 31) | def test_perplexity_longer_than_stride(model, metric): function test_perplexity_short (line 42) | def test_perplexity_short(model, metric): FILE: tests/test_prompt_tokenizers.py class TestPromptTokenizationStrategies (line 56) | class TestPromptTokenizationStrategies: method test_no_sys_prompt (line 62) | def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): method test_alpaca (line 84) | def test_alpaca(self, tokenizer_huggyllama_w_special_tokens): class TestInstructionWSystemPromptTokenizingStrategy (line 103) | class TestInstructionWSystemPromptTokenizingStrategy: method test_system_alpaca (line 109) | def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens): class Llama2ChatTokenizationTest (line 134) | class Llama2ChatTokenizationTest: method test_llama2_chat_integration (line 140) | def test_llama2_chat_integration(self, tokenizer_llama2_7b): method compare_with_transformers_integration (line 166) | def compare_with_transformers_integration(self, tokenizer_llama2_7b): class OrpoTokenizationTest (line 210) | class OrpoTokenizationTest: method test_orpo_integration (line 214) | def test_orpo_integration( FILE: tests/test_prompters.py class AlpacaPrompterTest (line 14) | class AlpacaPrompterTest(unittest.TestCase): method test_prompt_style_w_none (line 19) | def test_prompt_style_w_none(self): method test_prompt_style_w_instruct (line 25) | def test_prompt_style_w_instruct(self): method test_prompt_style_w_phi (line 45) | def test_prompt_style_w_phi(self): method test_prompt_style_w_chat (line 58) | def test_prompt_style_w_chat(self): method test_system_prompt (line 78) | def test_system_prompt(self): class UnpromptedPrompterTest (line 95) | class UnpromptedPrompterTest(unittest.TestCase): method test_prompt_style_w_none (line 100) | def test_prompt_style_w_none(self): method test_prompt_style_w_instruct (line 107) | def test_prompt_style_w_instruct(self): method test_prompt_style_w_chat (line 116) | def test_prompt_style_w_chat(self): class MultipleChoiceExplainPrompterTest (line 126) | class MultipleChoiceExplainPrompterTest(unittest.TestCase): method test_prompt_style_w_chat (line 131) | def test_prompt_style_w_chat(self): FILE: tests/test_revision_parameter.py class TestRevisionParameter (line 10) | class TestRevisionParameter: method test_load_tokenizer_passes_revision (line 18) | def test_load_tokenizer_passes_revision( method test_load_tokenizer_omits_revision_when_unset (line 43) | def test_load_tokenizer_omits_revision_when_unset( method test_modify_tokenizer_files_passes_revision (line 65) | def test_modify_tokenizer_files_passes_revision( method test_modify_tokenizer_files_defaults_revision_to_main (line 81) | def test_modify_tokenizer_files_defaults_revision_to_main( method test_load_processor_passes_revision (line 95) | def test_load_processor_passes_revision(self, mock_auto_processor): method test_load_processor_omits_revision_when_unset (line 117) | def test_load_processor_omits_revision_when_unset(self, mock_auto_proc... FILE: tests/test_save_deduplicated.py class TestSFTSaveDeduplicatedBeforeSave (line 13) | class TestSFTSaveDeduplicatedBeforeSave: method test_dedup_called_before_save_sft (line 22) | def test_dedup_called_before_save_sft( method test_no_dedup_when_disabled_sft (line 90) | def test_no_dedup_when_disabled_sft( class TestRLSaveDeduplicatedBeforeSave (line 136) | class TestRLSaveDeduplicatedBeforeSave: method test_dedup_called_before_save_rl (line 147) | def test_dedup_called_before_save_rl( FILE: tests/test_schedulers.py class TestCosineConstantLr (line 13) | class TestCosineConstantLr(unittest.TestCase): method setUp (line 18) | def setUp(self): method test_schedulers (line 33) | def test_schedulers(self): FILE: tests/test_streaming.py class TestStreamingConfig (line 16) | class TestStreamingConfig(unittest.TestCase): method test_streaming_multipack_buffer_size_deprecation (line 19) | def test_streaming_multipack_buffer_size_deprecation(self): method test_streaming_multipack_buffer_size_new (line 43) | def test_streaming_multipack_buffer_size_new(self): method test_both_buffer_sizes_raises_error (line 60) | def test_both_buffer_sizes_raises_error(self): class TestStreamingDatasetPreparation (line 80) | class TestStreamingDatasetPreparation(unittest.TestCase): method setUp (line 83) | def setUp(self): method test_prepare_datasets_with_streaming_true (line 89) | def test_prepare_datasets_with_streaming_true(self, mock_prepare_strea... method test_prepare_datasets_with_pretraining_dataset (line 105) | def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_... method test_prepare_datasets_without_streaming (line 120) | def test_prepare_datasets_without_streaming(self, mock_prepare_standard): class TestStreamingWithSamplePacking (line 135) | class TestStreamingWithSamplePacking(unittest.TestCase): method setUp (line 138) | def setUp(self): method test_streaming_sft_with_sample_packing_sets_split (line 144) | def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_... method test_multipack_attn_forced_true_for_sft (line 166) | def test_multipack_attn_forced_true_for_sft(self): method test_multipack_attn_respects_config_for_pretraining (line 201) | def test_multipack_attn_respects_config_for_pretraining(self): FILE: tests/test_tensor_parallel_batch_size.py function fixture_tp_base_cfg (line 13) | def fixture_tp_base_cfg(min_base_cfg): class TestTensorParallelBatchSize (line 25) | class TestTensorParallelBatchSize: method test_batch_size_with_tensor_parallelism (line 37) | def test_batch_size_with_tensor_parallelism( FILE: tests/test_tokenizers.py class TestTokenizers (line 15) | class TestTokenizers: method test_default_use_fast (line 22) | def test_default_use_fast(self): method test_dont_use_fast (line 33) | def test_dont_use_fast(self): method test_special_tokens_modules_to_save (line 44) | def test_special_tokens_modules_to_save(self): method test_add_additional_special_tokens (line 79) | def test_add_additional_special_tokens(self): method test_added_tokens_overrides (line 96) | def test_added_tokens_overrides(self, temp_dir): method test_added_tokens_overrides_gemma3 (line 122) | def test_added_tokens_overrides_gemma3(self, temp_dir): method test_added_tokens_overrides_with_toolargeid (line 147) | def test_added_tokens_overrides_with_toolargeid(self, temp_dir): FILE: tests/test_train.py function fixture_train_base_cfg (line 10) | def fixture_train_base_cfg(min_base_cfg): class TestTrain (line 23) | class TestTrain: method test_batch_size_ddp (line 33) | def test_batch_size_ddp( FILE: tests/test_triton_kernels.py function _ref_entropy (line 28) | def _ref_entropy(logits): function _ref_selective_log_softmax (line 34) | def _ref_selective_log_softmax(logits, index): class TestEntropyFromLogits (line 51) | class TestEntropyFromLogits: method test_correctness_various_shapes (line 62) | def test_correctness_various_shapes(self, B, L): method test_2d_input (line 73) | def test_2d_input(self): method test_large_vocab (line 82) | def test_large_vocab(self): method test_uniform_distribution (line 91) | def test_uniform_distribution(self): method test_peaked_distribution (line 106) | def test_peaked_distribution(self): method test_bfloat16 (line 115) | def test_bfloat16(self): method test_float16 (line 124) | def test_float16(self): method test_non_contiguous_3d_transpose (line 133) | def test_non_contiguous_3d_transpose(self): method test_non_contiguous_3d_slice (line 145) | def test_non_contiguous_3d_slice(self): method test_many_rows_beyond_max_grid (line 157) | def test_many_rows_beyond_max_grid(self): method test_entropy_non_negative (line 166) | def test_entropy_non_negative(self): class TestSelectiveLogSoftmax (line 179) | class TestSelectiveLogSoftmax: method test_correctness_various_shapes (line 191) | def test_correctness_various_shapes(self, B, L, K): method test_squeezed_index (line 205) | def test_squeezed_index(self): method test_large_vocab (line 217) | def test_large_vocab(self): method test_bfloat16 (line 227) | def test_bfloat16(self): method test_fp32_tight_tolerance (line 239) | def test_fp32_tight_tolerance(self): method test_all_same_index (line 250) | def test_all_same_index(self): method test_last_index (line 260) | def test_last_index(self): method test_output_always_nonpositive (line 270) | def test_output_always_nonpositive(self): method test_many_rows_beyond_max_grid (line 280) | def test_many_rows_beyond_max_grid(self): class TestSelectiveLogSoftmaxBackward (line 296) | class TestSelectiveLogSoftmaxBackward: method test_gradient_matches_reference (line 306) | def test_gradient_matches_reference(self, B, L, V, K): method test_gradient_bfloat16_full_vocab (line 330) | def test_gradient_bfloat16_full_vocab(self): method test_gradient_k1_squeezed (line 348) | def test_gradient_k1_squeezed(self): class TestSelectiveLogSoftmaxOOBSafety (line 376) | class TestSelectiveLogSoftmaxOOBSafety: method test_negative_indices_no_crash (line 379) | def test_negative_indices_no_crash(self): method test_index_exceeds_vocab_no_crash (line 395) | def test_index_exceeds_vocab_no_crash(self): method test_mixed_valid_invalid_multi_index (line 415) | def test_mixed_valid_invalid_multi_index(self): method test_oob_backward_no_crash (line 438) | def test_oob_backward_no_crash(self): method test_oob_backward_valid_rows_correct (line 454) | def test_oob_backward_valid_rows_correct(self): FILE: tests/test_utils_tee.py function _dummy_cfg (line 5) | def _dummy_cfg(output_dir: str, append: bool = False): function read (line 20) | def read(path: str) -> str: function test_file_only_stream_writes_after_prepare (line 25) | def test_file_only_stream_writes_after_prepare(monkeypatch): function test_stdout_is_mirrored_after_prepare (line 50) | def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch): function test_truncate_vs_append_behavior (line 72) | def test_truncate_vs_append_behavior(monkeypatch): FILE: tests/test_validation_dataset.py function fixture_cfg (line 16) | def fixture_cfg(): class BaseValidation (line 27) | class BaseValidation: method inject_fixtures (line 35) | def inject_fixtures(self, caplog): class TestValidationCheckDatasetConfig (line 39) | class TestValidationCheckDatasetConfig(BaseValidation): method test_dataset_config_no_drop_param (line 44) | def test_dataset_config_no_drop_param(self, minimal_cfg): method test_dataset_default_chat_template_no_drop_param (line 82) | def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg): method test_dataset_partial_default_chat_template_no_drop_param (line 137) | def test_dataset_partial_default_chat_template_no_drop_param(self, min... method test_dataset_chatml_chat_template_no_drop_param (line 193) | def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg): method test_dataset_sharegpt_deprecation (line 250) | def test_dataset_sharegpt_deprecation(self, minimal_cfg): method test_message_property_mappings (line 306) | def test_message_property_mappings(self, minimal_cfg): class TestOptimizerValidation (line 326) | class TestOptimizerValidation(BaseValidation): method test_muon_deepspeed (line 331) | def test_muon_deepspeed(self, minimal_cfg): method test_muon_fsdp (line 349) | def test_muon_fsdp(self, minimal_cfg): FILE: tests/utils/callbacks/test_dynamic_checkpoint.py class TestDynamicCheckpointCallbackInit (line 14) | class TestDynamicCheckpointCallbackInit: method test_callback_disabled_by_default (line 17) | def test_callback_disabled_by_default(self): method test_callback_disabled_when_none (line 29) | def test_callback_disabled_when_none(self): method test_callback_enabled_when_configured (line 41) | def test_callback_enabled_when_configured(self): method test_default_trigger_filename (line 54) | def test_default_trigger_filename(self): method test_check_interval_default (line 66) | def test_check_interval_default(self): class TestDynamicCheckpointFileDetection (line 79) | class TestDynamicCheckpointFileDetection: method test_trigger_file_detected_and_deleted (line 82) | def test_trigger_file_detected_and_deleted(self): method test_check_interval_honored (line 114) | def test_check_interval_honored(self): method test_no_file_no_trigger (line 151) | def test_no_file_no_trigger(self): method test_file_deletion_error_handling (line 178) | def test_file_deletion_error_handling(self): class TestDynamicCheckpointMultiGPU (line 212) | class TestDynamicCheckpointMultiGPU: method test_only_rank_0_checks_file (line 215) | def test_only_rank_0_checks_file(self): method test_broadcast_synchronization (line 255) | def test_broadcast_synchronization(self): class TestDynamicCheckpointSignalHandling (line 298) | class TestDynamicCheckpointSignalHandling: method test_signal_trigger_via_callback (line 301) | def test_signal_trigger_via_callback(self): method test_signal_not_registered_when_disabled (line 345) | def test_signal_not_registered_when_disabled(self): class TestDynamicCheckpointDisabled (line 365) | class TestDynamicCheckpointDisabled: method test_disabled_callback_does_nothing (line 368) | def test_disabled_callback_does_nothing(self): FILE: tests/utils/data/test_utils.py class TestHandleLongSeqInDataset (line 14) | class TestHandleLongSeqInDataset(unittest.TestCase): method test_drop_strategy_removes_long_sequences (line 19) | def test_drop_strategy_removes_long_sequences(self): method test_drop_strategy_is_default (line 50) | def test_drop_strategy_is_default(self): method test_truncate_strategy_truncates_long_sequences (line 74) | def test_truncate_strategy_truncates_long_sequences(self): method test_truncate_strategy_truncates_all_auxiliary_fields (line 117) | def test_truncate_strategy_truncates_all_auxiliary_fields(self): method test_raise_strategy_raises_on_long_sequences (line 159) | def test_raise_strategy_raises_on_long_sequences(self): method test_min_sequence_len_filters_short_sequences (line 182) | def test_min_sequence_len_filters_short_sequences(self): method test_dataset_without_input_ids_column (line 211) | def test_dataset_without_input_ids_column(self): method test_truncate_filters_short_before_truncating (line 233) | def test_truncate_filters_short_before_truncating(self): method test_case_insensitive_strategy (line 280) | def test_case_insensitive_strategy(self): method test_raise_strategy_silently_drops_short_sequences (line 304) | def test_raise_strategy_silently_drops_short_sequences(self): method test_drop_boundary_sequence_equal_to_sequence_len (line 330) | def test_drop_boundary_sequence_equal_to_sequence_len(self): method test_truncate_boundary_sequence_equal_to_sequence_len (line 356) | def test_truncate_boundary_sequence_equal_to_sequence_len(self): method test_empty_dataset (line 381) | def test_empty_dataset(self): method test_all_sequences_dropped_returns_empty_dataset (line 398) | def test_all_sequences_dropped_returns_empty_dataset(self): method test_iterable_dataset_skips_processing (line 422) | def test_iterable_dataset_skips_processing(self): method test_truncate_with_partial_auxiliary_fields (line 446) | def test_truncate_with_partial_auxiliary_fields(self): method test_min_sample_len_defaults_to_two_when_not_set (line 478) | def test_min_sample_len_defaults_to_two_when_not_set(self): method test_invalid_strategy_falls_through_to_drop (line 505) | def test_invalid_strategy_falls_through_to_drop(self): FILE: tests/utils/lora/test_config_validation_lora.py class TestLoRAConfigValidation (line 7) | class TestLoRAConfigValidation: method test_basic_configuration_validation (line 10) | def test_basic_configuration_validation(self): method test_qlora_4bit_validation (line 46) | def test_qlora_4bit_validation(self): method test_lora_kernels_trust_remote_code_incompatible (line 97) | def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field): method test_lora_kernels_trust_remote_code_false (line 114) | def test_lora_kernels_trust_remote_code_false(self): FILE: tests/utils/lora/test_freeze_lora.py class TestLoRAParameterFreezing (line 13) | class TestLoRAParameterFreezing: method setup_method (line 16) | def setup_method(self): method create_mock_lora_layer (line 19) | def create_mock_lora_layer( method test_parameter_freezing_adapters_disabled (line 47) | def test_parameter_freezing_adapters_disabled(self): method test_parameter_freezing_adapters_merged (line 61) | def test_parameter_freezing_adapters_merged(self): method test_parameter_freezing_no_adapters (line 76) | def test_parameter_freezing_no_adapters(self): method test_parameter_active_adapters_enabled (line 91) | def test_parameter_active_adapters_enabled(self): method test_parameter_shapes_consistency (line 107) | def test_parameter_shapes_consistency(self): method test_parameter_dtypes_consistency (line 121) | def test_parameter_dtypes_consistency(self): method test_quantization_state_handling (line 134) | def test_quantization_state_handling(self): method test_multiple_adapters_active_adapter_selection (line 145) | def test_multiple_adapters_active_adapter_selection(self): class TestLoRAParameterFreezingIntegration (line 167) | class TestLoRAParameterFreezingIntegration: method test_parameter_freezing_with_real_lora_layer (line 173) | def test_parameter_freezing_with_real_lora_layer(self): method test_parameter_freezing_gradient_behavior (line 209) | def test_parameter_freezing_gradient_behavior(self): FILE: tests/utils/lora/test_merge_lora.py class TestAdapterMergeUnmerge (line 9) | class TestAdapterMergeUnmerge: method setup_method (line 12) | def setup_method(self): method create_mock_base_model (line 16) | def create_mock_base_model(self, vocab_size=1000, hidden_size=256): method create_mock_lora_model (line 38) | def create_mock_lora_model(self, base_model, r=8, alpha=16): method test_basic_lora_merge_unmerge_cycle (line 86) | def test_basic_lora_merge_unmerge_cycle(self): method test_merge_weight_calculation_accuracy (line 109) | def test_merge_weight_calculation_accuracy(self): method test_cli_do_merge_functionality (line 121) | def test_cli_do_merge_functionality(self, mock_load_model, tmp_path): method test_quantized_model_merge_compatibility (line 148) | def test_quantized_model_merge_compatibility(self): method test_memory_efficient_merge_with_cpu_offload (line 162) | def test_memory_efficient_merge_with_cpu_offload(self, tmp_path): FILE: tests/utils/schemas/validation/test_activation_offloading.py class TestActivationOffloading (line 7) | class TestActivationOffloading: method test_gc_converts_offload_wo_lora (line 12) | def test_gc_converts_offload_wo_lora(self, min_base_cfg): method test_ac_offload_impl_noop_wo_adapter (line 24) | def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg): FILE: tests/utils/schemas/validation/test_default_values.py class TestDefaultConfigValues (line 7) | class TestDefaultConfigValues: method test_pad_to_sequence_len (line 10) | def test_pad_to_sequence_len(self, min_base_cfg): FILE: tests/utils/schemas/validation/test_fsdp.py class TestFSDPValidation (line 11) | class TestFSDPValidation: method test_fsdp_version_from_fsdp_config (line 16) | def test_fsdp_version_from_fsdp_config(self, min_base_cfg): method test_fsdp_version_in_fsdp_config (line 27) | def test_fsdp_version_in_fsdp_config(self, min_base_cfg): method test_fsdp_offload_w_8bit_optim (line 40) | def test_fsdp_offload_w_8bit_optim(self, min_base_cfg): method test_fsdp2_w_8bit_optim (line 53) | def test_fsdp2_w_8bit_optim(self, min_base_cfg): method test_fsdp2_w_cpu_ram_efficient_loading (line 67) | def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg): method test_fsdp2_cpu_offload_pin_memory_requires_offload_params (line 80) | def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, mi... method test_fsdp1_cpu_offload_pin_memory_not_supported (line 94) | def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg): method test_fsdp2_cpu_offload_pin_memory_w_offload_params (line 108) | def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_... method test_fsdp_prefixes_removed (line 120) | def test_fsdp_prefixes_removed(self, min_base_cfg): method test_muon_fsdp1_rejected (line 139) | def test_muon_fsdp1_rejected(self, min_base_cfg): method test_fsdp2_dpo (line 159) | def test_fsdp2_dpo(self, min_base_cfg, rl): FILE: tests/utils/schemas/validation/test_moe_quant.py function gpu_caps (line 10) | def gpu_caps(): function env_caps (line 21) | def env_caps(): class TestQuantizeMoeExpertsValidation (line 25) | class TestQuantizeMoeExpertsValidation: method test_requires_adapter (line 28) | def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps): method test_requires_quantization (line 39) | def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps): method test_valid_qlora_4bit (line 51) | def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps): method test_valid_lora_8bit (line 64) | def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps): method test_false_skips_validation (line 77) | def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps): method test_rejects_lora_target_linear (line 88) | def test_rejects_lora_target_linear(self, min_base_cfg, gpu_caps, env_... method test_default_is_false (line 102) | def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps): class TestLoraTargetParametersDropout (line 109) | class TestLoraTargetParametersDropout: method test_rejects_nonzero_dropout (line 112) | def test_rejects_nonzero_dropout(self, min_base_cfg): method test_zero_dropout_passes (line 126) | def test_zero_dropout_passes(self, min_base_cfg): class TestPeftPatchIdempotency (line 141) | class TestPeftPatchIdempotency: method test_double_call_does_not_stack_wrappers (line 144) | def test_double_call_does_not_stack_wrappers(self): class TestMoeAdapterTrainMergeRoundtrip (line 165) | class TestMoeAdapterTrainMergeRoundtrip: method _make_classes (line 173) | def _make_classes(): method _make_quantized_model (line 202) | def _make_quantized_model(): method _make_plain_model (line 232) | def _make_plain_model(): method test_train_save_merge_no_size_mismatch (line 237) | def test_train_save_merge_no_size_mismatch(self, tmp_path): FILE: tests/utils/test_grpo_rw_fnc.py function test_get_rollout_func_loads_successfully (line 8) | def test_get_rollout_func_loads_successfully(): function test_get_rollout_func_invalid_module_raises_error (line 15) | def test_get_rollout_func_invalid_module_raises_error(): FILE: tests/utils/test_import_helper.py function test_get_cls_from_module_str (line 10) | def test_get_cls_from_module_str(): function test_get_cls_from_module_str_empty_string (line 15) | def test_get_cls_from_module_str_empty_string(): function test_get_cls_from_module_str_whitespace_only (line 20) | def test_get_cls_from_module_str_whitespace_only(): function test_get_cls_from_module_str_invalid_format (line 25) | def test_get_cls_from_module_str_invalid_format(): function test_get_cls_from_module_str_nonexistent_module (line 30) | def test_get_cls_from_module_str_nonexistent_module(): function test_get_cls_from_module_str_nonexistent_class (line 35) | def test_get_cls_from_module_str_nonexistent_class(): FILE: tests/utils/test_mistral3_processor.py function mock_tokenizer (line 14) | def mock_tokenizer(): function processor (line 20) | def processor(mock_tokenizer): class TestMistral3ProcessorInit (line 24) | class TestMistral3ProcessorInit: method test_tokenizer_is_set (line 25) | def test_tokenizer_is_set(self, processor, mock_tokenizer): method test_chat_template_is_none (line 28) | def test_chat_template_is_none(self, processor): method test_audio_tokenizer_is_none (line 31) | def test_audio_tokenizer_is_none(self, processor): class TestApplyChatTemplateTokenized (line 35) | class TestApplyChatTemplateTokenized: method batched_conversations (line 39) | def batched_conversations(self): method test_returns_batch_feature_with_pixel_values (line 51) | def test_returns_batch_feature_with_pixel_values( method test_returns_batch_feature_without_pixel_values (line 72) | def test_returns_batch_feature_without_pixel_values( class TestApplyChatTemplateNotTokenized (line 89) | class TestApplyChatTemplateNotTokenized: method test_single_conversation_returns_unwrapped (line 90) | def test_single_conversation_returns_unwrapped(self, processor, mock_t... method test_batched_conversations_returns_list (line 106) | def test_batched_conversations_returns_list(self, processor, mock_toke... class TestCall (line 126) | class TestCall: method test_delegates_to_tokenizer (line 127) | def test_delegates_to_tokenizer(self, processor, mock_tokenizer): class TestReturnTensorsValidation (line 139) | class TestReturnTensorsValidation: method test_rejects_non_pt_return_tensors (line 140) | def test_rejects_non_pt_return_tensors(self, processor): FILE: tests/utils/test_train.py function test_determine_last_checkpoint (line 9) | def test_determine_last_checkpoint(temp_dir):