SYMBOL INDEX (178 symbols across 17 files) FILE: mlx_lm_lora/synthetic_prompts.py function parse_args (line 38) | def parse_args(): function load_documents (line 125) | def load_documents(docs_dir: str) -> List[Dict[str, str]]: function create_generation_prompt (line 169) | def create_generation_prompt( function clean_latex_for_json (line 210) | def clean_latex_for_json(text: str) -> str: function generate_dataset (line 223) | def generate_dataset(args): function main (line 434) | def main(): FILE: mlx_lm_lora/train.py function load_reward_functions_from_file (line 139) | def load_reward_functions_from_file(file_path): function calculate_iters (line 156) | def calculate_iters(train_set, batch_size, epochs) -> int: function load_reference_model (line 166) | def load_reference_model(args): function load_judge_model (line 177) | def load_judge_model(args, reference_model=None): function build_parser (line 194) | def build_parser(): function train_model (line 504) | def train_model( function evaluate_model (line 772) | def evaluate_model( function build_lora_config (line 1046) | def build_lora_config(args): function run (line 1060) | def run(args, training_callback: TrainingCallback = None): function main (line 1168) | def main(args=None): FILE: mlx_lm_lora/train_judge.py function load_reward_functions_from_file (line 82) | def load_reward_functions_from_file(file_path): function calculate_iters (line 100) | def calculate_iters(train_set, batch_size, epochs) -> int: function build_parser (line 110) | def build_parser(): function train_model (line 257) | def train_model( function evaluate_model (line 345) | def evaluate_model(args, model: nn.Module, tokenizer, test_set): function run (line 359) | def run(args, training_callback: TrainingCallback = None): function main (line 430) | def main(args=None): FILE: mlx_lm_lora/trainer/cpo_trainer.py function get_token_scores (line 19) | def get_token_scores(model, x, mask, cache=None): function compute_score (line 25) | def compute_score(scores, mask, loss_type): function cpo_loss (line 30) | def cpo_loss( function iterate_cpo_batches (line 80) | def iterate_cpo_batches(dataset, batch_size, max_seq_length, train=False): function evaluate_cpo (line 139) | def evaluate_cpo( function train_cpo (line 210) | def train_cpo( FILE: mlx_lm_lora/trainer/datasets.py class GRPODataset (line 10) | class GRPODataset: method __init__ (line 11) | def __init__( method __getitem__ (line 44) | def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]: method __len__ (line 47) | def __len__(self) -> int: method process (line 50) | def process(self, d): class PreferenceDataset (line 54) | class PreferenceDataset: method __init__ (line 55) | def __init__( method __getitem__ (line 69) | def __getitem__(self, idx: int): method __len__ (line 72) | def __len__(self): method process (line 75) | def process(self, d): class JudgeDataset (line 79) | class JudgeDataset: method __init__ (line 80) | def __init__( method process (line 100) | def process(self, d): method __getitem__ (line 132) | def __getitem__(self, idx: int): method __len__ (line 135) | def __len__(self): class PromptDataset (line 139) | class PromptDataset: method __init__ (line 140) | def __init__( method process (line 150) | def process(self, d): method __getitem__ (line 170) | def __getitem__(self, idx: int): method __len__ (line 173) | def __len__(self): class DPODataset (line 177) | class DPODataset: method __init__ (line 178) | def __init__( method __getitem__ (line 217) | def __getitem__(self, idx: int): method __len__ (line 220) | def __len__(self): method process (line 223) | def process(self, d): class ORPODataset (line 227) | class ORPODataset: method __init__ (line 228) | def __init__( method _extract_content (line 320) | def _extract_content(self, data): method __len__ (line 339) | def __len__(self): method process (line 342) | def process(self, d): method __getitem__ (line 345) | def __getitem__(self, idx: int): class TextDataset (line 353) | class TextDataset: method __init__ (line 358) | def __init__( method process (line 368) | def process(self, d): method __getitem__ (line 374) | def __getitem__(self, idx: int): method __len__ (line 377) | def __len__(self): class ChatDataset (line 381) | class ChatDataset: method __init__ (line 387) | def __init__( method process (line 399) | def process(self, d): method __getitem__ (line 418) | def __getitem__(self, idx: int): method __len__ (line 421) | def __len__(self): class CompletionsDataset (line 425) | class CompletionsDataset: method __init__ (line 432) | def __init__( method process (line 446) | def process(self, d): method __getitem__ (line 469) | def __getitem__(self, idx: int): method __len__ (line 472) | def __len__(self): class ConcatenatedDataset (line 476) | class ConcatenatedDataset: method __init__ (line 477) | def __init__(self, data: List[Any]): method __getitem__ (line 481) | def __getitem__(self, idx: int): method process (line 491) | def process(self, d): method __len__ (line 494) | def __len__(self): class CacheDataset (line 498) | class CacheDataset: method __init__ (line 499) | def __init__(self, data: Any): method itemlen (line 503) | def itemlen(self, idx: int): method __getitem__ (line 506) | def __getitem__(self, idx: int): method __len__ (line 511) | def __len__(self): function create_dataset (line 515) | def create_dataset( function load_local_dataset (line 612) | def load_local_dataset( function load_hf_dataset (line 629) | def load_hf_dataset( function load_custom_hf_dataset (line 656) | def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): function load_dataset (line 715) | def load_dataset(args, tokenizer: PreTrainedTokenizer): FILE: mlx_lm_lora/trainer/dpo_trainer.py class DPOTrainingArgs (line 25) | class DPOTrainingArgs(SFTTrainingArgs): function get_token_scores (line 44) | def get_token_scores(model, x, mask, cache=None): function compute_score (line 50) | def compute_score(scores, mask, loss_type): function dpo_loss (line 55) | def dpo_loss( function iterate_dpo_batches (line 110) | def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False): function evaluate_dpo (line 169) | def evaluate_dpo( function train_dpo (line 261) | def train_dpo( FILE: mlx_lm_lora/trainer/grpo_reward_functions.py function register_reward_function (line 12) | def register_reward_function(name: str = None): function get_reward_function (line 38) | def get_reward_function(name: str) -> RewardFunctions: function get_default_reward_functions (line 58) | def get_default_reward_functions() -> List[RewardFunctions]: function list_available_reward_functions (line 71) | def list_available_reward_functions() -> List[str]: function r1_extract_xml_answer (line 78) | def r1_extract_xml_answer(text: str) -> str: function r1_int_reward_func (line 89) | def r1_int_reward_func( function r1_accuracy_reward_func (line 99) | def r1_accuracy_reward_func( function r1_soft_format_reward_func (line 111) | def r1_soft_format_reward_func( function r1_strict_format_reward_func (line 145) | def r1_strict_format_reward_func( function r1_count_xml (line 156) | def r1_count_xml( FILE: mlx_lm_lora/trainer/grpo_trainer.py class GRPOTrainingArgs (line 28) | class GRPOTrainingArgs(SFTTrainingArgs): function get_per_token_logps (line 91) | def get_per_token_logps(model: nn.Module, inputs, lengths): function generate_grpo (line 112) | def generate_grpo( function calculate_rewards_and_advantages (line 205) | def calculate_rewards_and_advantages( function grpo_loss (line 351) | def grpo_loss( function iterate_grpo_batches (line 540) | def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False): function evaluate_grpo (line 593) | def evaluate_grpo( function train_grpo (line 737) | def train_grpo( FILE: mlx_lm_lora/trainer/judge.py class LLMPairwiseJudge (line 175) | class LLMPairwiseJudge: method __init__ (line 176) | def __init__( method judge (line 188) | def judge( class LLMPPOJudge (line 233) | class LLMPPOJudge: method __init__ (line 234) | def __init__( method judge (line 246) | def judge( class HumanPairwiseJudge (line 315) | class HumanPairwiseJudge: method __init__ (line 316) | def __init__( method judge (line 322) | def judge( FILE: mlx_lm_lora/trainer/online_dpo_trainer.py class OnlineDPOTrainingArgs (line 24) | class OnlineDPOTrainingArgs(SFTTrainingArgs): function generate_for_online_dpo (line 61) | def generate_for_online_dpo( function compute_score (line 99) | def compute_score(scores, mask, loss_type): function online_dpo_loss (line 106) | def online_dpo_loss( function iterate_online_dpo_batches (line 162) | def iterate_online_dpo_batches(dataset, batch_size, max_seq_length, trai... function evaluate_online_dpo (line 190) | def evaluate_online_dpo( function train_online_dpo (line 367) | def train_online_dpo( FILE: mlx_lm_lora/trainer/orpo_trainer.py class ORPOTrainingArgs (line 25) | class ORPOTrainingArgs(SFTTrainingArgs): function get_logps (line 35) | def get_logps(model, tokens, mask, cache=None): function orpo_loss (line 54) | def orpo_loss( function iterate_orpo_batches (line 95) | def iterate_orpo_batches(dataset, batch_size, max_seq_length, train=False): function evaluate_orpo (line 170) | def evaluate_orpo( function train_orpo (line 228) | def train_orpo( FILE: mlx_lm_lora/trainer/ppo_trainer.py class PPOTrainingArgs (line 25) | class PPOTrainingArgs(OnlineDPOTrainingArgs): function ppo_loss (line 31) | def ppo_loss( function evaluate_ppo (line 113) | def evaluate_ppo( function train_ppo (line 289) | def train_ppo( FILE: mlx_lm_lora/trainer/rlhf_reinforce_trainer.py class RLHFReinforceTrainingArgs (line 22) | class RLHFReinforceTrainingArgs(SFTTrainingArgs): function compute_kl_penalty (line 35) | def compute_kl_penalty(logits_policy, logits_ref, masks): function rlhf_reinforce_loss (line 44) | def rlhf_reinforce_loss( function get_model_logits (line 90) | def get_model_logits(model, tokens, masks): function evaluate_rlhf_reinforce (line 97) | def evaluate_rlhf_reinforce( function train_rlhf_reinforce (line 220) | def train_rlhf_reinforce( FILE: mlx_lm_lora/trainer/sft_trainer.py function reset_prompt_cache (line 25) | def reset_prompt_cache(cache): function _find_cache_offset (line 59) | def _find_cache_offset(cache): function grad_checkpoint (line 76) | def grad_checkpoint(layer): class SFTTrainingArgs (line 93) | class SFTTrainingArgs: function _symmetric_fake_quantize_tensor (line 162) | def _symmetric_fake_quantize_tensor(x, bits: int, group_size: int): function _install_qat_hooks (line 202) | def _install_qat_hooks(model, args: SFTTrainingArgs): function default_loss (line 243) | def default_loss(model, batch, lengths, cache=None): function iterate_batches (line 260) | def iterate_batches( function evaluate_sft (line 313) | def evaluate_sft( function train_sft (line 368) | def train_sft( FILE: mlx_lm_lora/trainer/xpo_trainer.py class XPOTrainingArgs (line 25) | class XPOTrainingArgs(OnlineDPOTrainingArgs): function get_current_alpha (line 34) | def get_current_alpha( function xpo_loss (line 45) | def xpo_loss( function iterate_online_dpo_batches (line 129) | def iterate_online_dpo_batches(dataset, batch_size, max_seq_length, trai... function evaluate_xpo (line 157) | def evaluate_xpo( function train_xpo (line 335) | def train_xpo( FILE: mlx_lm_lora/utils.py function calculate_iters (line 20) | def calculate_iters(train_set, batch_size, epochs) -> int: function find_lmstudio_models_path (line 30) | def find_lmstudio_models_path() -> Path: function save_pretrained (line 45) | def save_pretrained( function save_pretrained_merged (line 113) | def save_pretrained_merged( function from_pretrained (line 169) | def from_pretrained( function push_to_hub (line 245) | def push_to_hub( function save_to_lmstudio_merged (line 301) | def save_to_lmstudio_merged( function save_pretrained_merged_vision (line 336) | def save_pretrained_merged_vision( FILE: mlx_lm_lora/visuals.py class Colors (line 1) | class Colors: function print_banner (line 26) | def print_banner(): function print_info (line 46) | def print_info(message): function print_success (line 51) | def print_success(message): function print_warning (line 56) | def print_warning(message): function print_error (line 61) | def print_error(message): function print_section (line 66) | def print_section(title):