SYMBOL INDEX (1986 symbols across 186 files) FILE: examples/datasets/deepmath_103k.py class ScriptArguments (line 23) | class ScriptArguments: function process_example (line 50) | def process_example(example): FILE: examples/datasets/hh-rlhf-helpful-base.py class ScriptArguments (line 24) | class ScriptArguments: function common_start (line 49) | def common_start(str1: str, str2: str) -> str: function extract_dialogue (line 61) | def extract_dialogue(example: str) -> list[dict[str, str]]: FILE: examples/datasets/llava_instruct_mix.py class ScriptArguments (line 24) | class ScriptArguments: function process_example (line 51) | def process_example(example): function filter_long_examples (line 61) | def filter_long_examples(example): function split_prompt_completion (line 66) | def split_prompt_completion(example): FILE: examples/datasets/lm-human-preferences-descriptiveness.py class ScriptArguments (line 23) | class ScriptArguments: function samples_not_all_same (line 51) | def samples_not_all_same(example): function to_prompt_completion (line 55) | def to_prompt_completion(example, tokenizer): FILE: examples/datasets/lm-human-preferences-sentiment.py class ScriptArguments (line 23) | class ScriptArguments: function to_prompt_completion (line 50) | def to_prompt_completion(example, tokenizer): FILE: examples/datasets/math_shepherd.py class ScriptArguments (line 25) | class ScriptArguments: function process_example (line 52) | def process_example(example): FILE: examples/datasets/prm800k.py class ScriptArguments (line 23) | class ScriptArguments: function process_example (line 50) | def process_example(example): function process_batch (line 89) | def process_batch(examples): FILE: examples/datasets/rlaif-v.py class ScriptArguments (line 23) | class ScriptArguments: function to_conversational (line 50) | def to_conversational(example): FILE: examples/datasets/tldr.py class ScriptArguments (line 23) | class ScriptArguments: function to_prompt_completion (line 50) | def to_prompt_completion(example): FILE: examples/datasets/tldr_preference.py class ScriptArguments (line 23) | class ScriptArguments: function to_preference (line 50) | def to_preference(example): FILE: examples/datasets/ultrafeedback-prompt.py class ScriptArguments (line 23) | class ScriptArguments: function to_unpaired_preference (line 50) | def to_unpaired_preference(example): function drop_long_prompt (line 55) | def drop_long_prompt(example): FILE: examples/datasets/ultrafeedback.py class ScriptArguments (line 23) | class ScriptArguments: function to_unpaired_preference (line 87) | def to_unpaired_preference(example, model_name, aspect): FILE: examples/scripts/async_grpo.py function format_sample (line 39) | def format_sample(sample): function main (line 43) | def main() -> None: FILE: examples/scripts/bco.py function embed_prompt (line 90) | def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.Long... FILE: examples/scripts/evals/judge_tldr.py class ScriptArguments (line 54) | class ScriptArguments: FILE: examples/scripts/grpo_2048.py class Game2048Env (line 32) | class Game2048Env: method reset (line 33) | def reset(self, **kwargs) -> str: method move (line 41) | def move(self, direction: str) -> str: method _spawn (line 60) | def _spawn(self) -> None: method _merge_line (line 68) | def _merge_line(line: list[int]) -> tuple[list[int], int]: method _apply_move (line 85) | def _apply_move(self, direction: str) -> tuple[bool, int]: method _can_move (line 117) | def _can_move(self) -> bool: method _render (line 128) | def _render(self) -> str: function reward_score (line 132) | def reward_score(environments, **kwargs): function main (line 136) | def main() -> None: FILE: examples/scripts/grpo_agent.py function query_reward (line 56) | def query_reward(completions, answer, **kwargs): function correctness_reward (line 123) | def correctness_reward(completions, answer, **kwargs): function structure_reward (line 151) | def structure_reward(completions, **kwargs): class TimeoutError (line 193) | class TimeoutError(Exception): function timeout (line 200) | def timeout(seconds): function query_biogrid (line 214) | def query_biogrid(sql_command: str) -> list[tuple]: function format_example (line 240) | def format_example(example): FILE: examples/scripts/grpo_vlm.py function make_conversation (line 119) | def make_conversation(example): function filter_big_images (line 129) | def filter_big_images(example): function convert_to_rgb (line 135) | def convert_to_rgb(example): FILE: examples/scripts/gspo.py function make_conversation (line 106) | def make_conversation(example): FILE: examples/scripts/gspo_vlm.py function make_conversation (line 108) | def make_conversation(example): function filter_big_images (line 118) | def filter_big_images(example): function convert_to_rgb (line 124) | def convert_to_rgb(example): FILE: examples/scripts/mpo_vlm.py function ensure_rgb (line 104) | def ensure_rgb(example): FILE: examples/scripts/nemo_gym/train_multi_environment.py class NeMoGymGRPOConfig (line 40) | class NeMoGymGRPOConfig(GRPOConfig): function get_agent_servers (line 45) | def get_agent_servers( function reward_fn (line 76) | def reward_fn(completions: list[str], **kwargs) -> list[float]: function call_nemo_gym_agents (line 82) | async def call_nemo_gym_agents( function nemo_gym_rollout_func (line 139) | def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> d... function load_dataset_from_jsonl (line 295) | def load_dataset_from_jsonl(path: str) -> Dataset: function main (line 312) | def main(): FILE: examples/scripts/online_dpo_vlm.py function make_conversation (line 162) | def make_conversation(example): function filter_big_images (line 173) | def filter_big_images(example): function convert_to_rgb (line 179) | def convert_to_rgb(example): FILE: examples/scripts/openenv/browsergym.py function parse_args (line 97) | def parse_args() -> argparse.Namespace: function sanitize_name (line 278) | def sanitize_name(name: str) -> str: function make_user_prompt (line 309) | def make_user_prompt(goal: str, step_num: int, axtree: str, error: str =... function parse_action (line 330) | def parse_action(response_text: str) -> str: function rollout_once (line 342) | def rollout_once( function reward_completion (line 456) | def reward_completion(completions: list[str], **kwargs) -> list[float]: function main (line 469) | def main() -> None: FILE: examples/scripts/openenv/browsergym_llm.py function parse_args (line 79) | def parse_args() -> argparse.Namespace: function sanitize_name (line 239) | def sanitize_name(name: str) -> str: function make_user_prompt (line 273) | def make_user_prompt(goal: str, step_num: int, axtree: str, error: str =... function parse_action (line 294) | def parse_action(response_text: str) -> str: function rollout_once (line 306) | def rollout_once( function reward_completion (line 393) | def reward_completion(completions: list[str], **kwargs) -> list[float]: function main (line 406) | def main() -> None: FILE: examples/scripts/openenv/carla.py function parse_args (line 58) | def parse_args(): class CarlaGRPOEnv (line 113) | class CarlaGRPOEnv: method __init__ (line 114) | def __init__(self): method _describe (line 119) | def _describe(obs) -> str: method _advance (line 132) | def _advance(self, ticks: int = SIM_TICKS): method reset (line 141) | def reset(self, **kwargs) -> str | None: method observe (line 146) | def observe(self) -> str: method emergency_stop (line 157) | def emergency_stop(self) -> str: method lane_change (line 169) | def lane_change(self, direction: str) -> str: function reward_func (line 185) | def reward_func(completions, environments, **kwargs): FILE: examples/scripts/openenv/catch.py function parse_args (line 91) | def parse_args(): function start_env_server (line 135) | def start_env_server(env_host: str, env_port: int): function reward_from_env (line 200) | def reward_from_env(completions, **kwargs): function main (line 205) | def main(): FILE: examples/scripts/openenv/echo.py function reward_func (line 42) | def reward_func(completions, environments, **kwargs): class MyEchoEnv (line 46) | class MyEchoEnv: method __init__ (line 47) | def __init__(self): method reset (line 50) | def reset(self, **kwargs) -> None | str: method step (line 54) | def step(self, message: str) -> str: method get_reward (line 68) | def get_reward(self) -> float: FILE: examples/scripts/openenv/sudoku.py function parse_args (line 113) | def parse_args() -> argparse.Namespace: function resolve_system_prompt (line 189) | def resolve_system_prompt(path: str) -> str: function sanitize_name (line 196) | def sanitize_name(name: str) -> str: function extract_sudoku_move (line 200) | def extract_sudoku_move(text: str) -> str: function is_valid_board_state (line 217) | def is_valid_board_state(board_str: str) -> bool: function parse_board (line 222) | def parse_board(board_str: str) -> list[list[int]]: function count_filled_cells (line 244) | def count_filled_cells(board_str: str) -> int: function get_valid_numbers (line 252) | def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[... function extract_empty_cells_with_candidates (line 279) | def extract_empty_cells_with_candidates( function extract_empty_cells (line 304) | def extract_empty_cells(board_str: str) -> list[tuple[int, int]]: function extract_board_only (line 325) | def extract_board_only(text: str) -> str: function make_compact_prompt (line 353) | def make_compact_prompt( function check_move_targets_empty_cell (line 417) | def check_move_targets_empty_cell(move: str, board_str: str) -> bool: function extract_feedback (line 431) | def extract_feedback(observation) -> dict: function rollout_once (line 457) | def rollout_once( function reward_empty_cell (line 671) | def reward_empty_cell(completions: list[str], **kwargs) -> list[float]: function reward_valid_moves (line 679) | def reward_valid_moves(completions: list[str], **kwargs) -> list[float]: function reward_correct (line 687) | def reward_correct(completions: list[str], **kwargs) -> list[float]: function reward_repetition (line 695) | def reward_repetition(completions: list[str], **kwargs) -> list[float]: function reward_progress (line 703) | def reward_progress(completions: list[str], **kwargs) -> list[float]: function main (line 716) | def main() -> None: FILE: examples/scripts/openenv/wordle.py class WordleEnv (line 50) | class WordleEnv: method __init__ (line 51) | def __init__(self): method reset (line 54) | def reset(self, **kwargs) -> None | str: method guess (line 63) | def guess(self, guess: str) -> str: function reward (line 91) | def reward(environments, **kwargs) -> list[float]: function main (line 95) | def main() -> None: FILE: examples/scripts/ppo/ppo.py function prepare_dataset (line 134) | def prepare_dataset(dataset, tokenizer): FILE: examples/scripts/ppo/ppo_tldr.py function prepare_dataset (line 137) | def prepare_dataset(dataset, tokenizer): FILE: examples/scripts/rloo.py function main (line 51) | def main(): FILE: examples/scripts/rloo_vlm.py function make_conversation (line 119) | def make_conversation(example): function filter_big_images (line 129) | def filter_big_images(example): function convert_to_rgb (line 135) | def convert_to_rgb(example): FILE: examples/scripts/sft_gemma3.py function main (line 42) | def main(): FILE: examples/scripts/sft_gpt_oss.py function main (line 63) | def main(script_args, training_args, model_args): FILE: examples/scripts/sft_nemotron_3.py function main (line 70) | def main(script_args, training_args, model_args): FILE: examples/scripts/sft_tiny_aya_tool_calling.py function create_conversation (line 80) | def create_conversation(sample): function main (line 101) | def main(): FILE: examples/scripts/sft_video_llm.py function download_video (line 73) | def download_video(url: str, cache_dir: str) -> str: function prepare_dataset (line 94) | def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str... function collate_fn (line 124) | def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: class CustomScriptArguments (line 165) | class CustomScriptArguments(ScriptArguments): FILE: examples/scripts/sft_vlm_gemma3.py function process_vision_info (line 83) | def process_vision_info(messages: list[dict]) -> list[Image.Image]: function format_data (line 102) | def format_data(samples: dict[str, any]) -> dict[str, list]: function prepare_dataset (line 127) | def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetD... function main (line 143) | def main(): FILE: scripts/add_copyrights.py function get_tracked_python_files (line 37) | def get_tracked_python_files(): function check_and_add_copyright (line 52) | def check_and_add_copyright(file_path): function main (line 73) | def main(): FILE: scripts/generate_harmony_dataset.py class ScriptArguments (line 22) | class ScriptArguments: function main (line 49) | def main(test_size, push_to_hub, repo_id): FILE: scripts/generate_tiny_models.py function push_to_hub (line 106) | def push_to_hub(model, tokenizer, generation_config, prefix=None, suffix... function init_weights_tiny_model (line 127) | def init_weights_tiny_model(model): FILE: scripts/generate_toolcall_dataset.py class ScriptArguments (line 24) | class ScriptArguments: function main (line 51) | def main(test_size, push_to_hub, repo_id): FILE: scripts/generate_zen_dataset.py class ScriptArguments (line 22) | class ScriptArguments: function main (line 49) | def main(test_size, push_to_hub, repo_id): FILE: scripts/generate_zen_image_dataset.py class ScriptArguments (line 26) | class ScriptArguments: function main (line 53) | def main(test_size, push_to_hub, repo_id): FILE: scripts/generate_zen_multi_image_dataset.py class ScriptArguments (line 26) | class ScriptArguments: function main (line 53) | def main(test_size, push_to_hub, repo_id): FILE: scripts/log_reports.py function process_log_file (line 34) | def process_log_file(log): function main (line 65) | def main(slack_channel_name): FILE: tests/conftest.py function apply_model_revisions (line 44) | def apply_model_revisions(monkeypatch): function cleanup_gpu (line 77) | def cleanup_gpu(): FILE: tests/distributed/test_distributed.py function run_command (line 30) | def run_command(command: list[str], env: dict[str, str]) -> None: function get_config_path (line 36) | def get_config_path(lazy_shared_datadir): class TestDistributed (line 44) | class TestDistributed( method test_sft (line 68) | def test_sft(self, config, get_config_path): method test_dpo (line 103) | def test_dpo(self, config, get_config_path): method test_sft_dataset_streaming (line 138) | def test_sft_dataset_streaming(self, config, get_config_path): method test_sft_peft (line 175) | def test_sft_peft(self, config, get_config_path): method test_reward (line 211) | def test_reward(self, config, get_config_path): method test_rloo (line 240) | def test_rloo(self, config, get_config_path): method test_grpo (line 270) | def test_grpo(self, config, get_config_path): FILE: tests/experimental/test_async_grpo_trainer.py function dummy_reward_func (line 29) | def dummy_reward_func(completions, **kwargs): class _StubRolloutWorker (line 33) | class _StubRolloutWorker: method __init__ (line 36) | def __init__(self, tokenizer, dataset, num_generations: int = 8, sampl... method _make_sample_iter (line 42) | def _make_sample_iter(self, tokenizer, dataset, num_generations): method _fill_queue (line 70) | def _fill_queue(self): method start (line 74) | def start(self): method update_model_version (line 77) | def update_model_version(self, version): method stop (line 81) | def stop(self): method pause (line 84) | def pause(self): method resume (line 87) | def resume(self): method send_weights (line 90) | def send_weights(self, iterator): class TestAsyncGRPOTrainer (line 94) | class TestAsyncGRPOTrainer(TrlTestCase): method test_init_minimal (line 95) | def test_init_minimal(self): method test_training (line 106) | def test_training(self): FILE: tests/experimental/test_bco_trainer.py class TestBCOTrainer (line 35) | class TestBCOTrainer(TrlTestCase): method test_train (line 48) | def test_train(self, config_name): method test_train_with_precompute (line 84) | def test_train_with_precompute(self): method test_train_eval (line 121) | def test_train_eval(self): method test_init_with_ref_model_is_model (line 149) | def test_init_with_ref_model_is_model(self): method test_tokenize_and_process_tokens (line 172) | def test_tokenize_and_process_tokens(self): method test_train_without_providing_ref_model (line 226) | def test_train_without_providing_ref_model(self): method test_train_udm (line 260) | def test_train_udm(self): method test_train_without_providing_ref_model_with_lora (line 310) | def test_train_without_providing_ref_model_with_lora(self): method test_generate_during_eval_no_wandb (line 348) | def test_generate_during_eval_no_wandb(self): method test_lora_train_and_save (line 379) | def test_lora_train_and_save(self): method test_compute_metrics (line 411) | def test_compute_metrics(self): FILE: tests/experimental/test_cpo_trainer.py class TestCPOTrainer (line 25) | class TestCPOTrainer(TrlTestCase): method setup_method (line 26) | def setup_method(self): method test_cpo_trainer (line 48) | def test_cpo_trainer(self, name, loss_type, config_name): method test_cpo_trainer_with_lora (line 103) | def test_cpo_trainer_with_lora(self, config_name): method test_compute_metrics (line 151) | def test_compute_metrics(self): method test_alphapo_trainer (line 181) | def test_alphapo_trainer(self): FILE: tests/experimental/test_dppo_trainer.py class TestDPPODivergenceMask (line 24) | class TestDPPODivergenceMask: method make_trainer (line 28) | def make_trainer(divergence_type="binary_tv", epsilon=0.2, epsilon_hig... method compute_divergence_mask (line 41) | def compute_divergence_mask( method test_binary_tv_no_masking_within_threshold (line 60) | def test_binary_tv_no_masking_within_threshold(self): method test_binary_tv_masks_positive_advantage_high_divergence (line 72) | def test_binary_tv_masks_positive_advantage_high_divergence(self): method test_binary_tv_masks_negative_advantage_low_divergence (line 83) | def test_binary_tv_masks_negative_advantage_low_divergence(self): method test_binary_tv_respects_completion_mask (line 94) | def test_binary_tv_respects_completion_mask(self): method test_topk_tv_requires_topk_inputs (line 105) | def test_topk_tv_requires_topk_inputs(self): class TestDPPOTrainer (line 132) | class TestDPPOTrainer(TrlTestCase): method test_training_binary (line 134) | def test_training_binary(self, divergence_type): method test_training_conversational (line 164) | def test_training_conversational(self, config_name): FILE: tests/experimental/test_gkd_trainer.py class TestGKDTrainerGenerateOnPolicy (line 28) | class TestGKDTrainerGenerateOnPolicy(TrlTestCase): method setup_class (line 30) | def setup_class(cls): method test_generate_on_policy_outputs_deterministic (line 43) | def test_generate_on_policy_outputs_deterministic(self): method test_generate_on_policy_outputs (line 91) | def test_generate_on_policy_outputs(self): class TestGeneralizedJSDLoss (line 126) | class TestGeneralizedJSDLoss(TrlTestCase): method setup_method (line 127) | def setup_method(self): method test_uniform_distribution (line 134) | def test_uniform_distribution(self): method test_generalized_jsd_loss_edge_cases (line 139) | def test_generalized_jsd_loss_edge_cases(self): method test_output_shape (line 158) | def test_output_shape(self): method test_beta_values (line 163) | def test_beta_values(self): method test_temperature_scaling (line 168) | def test_temperature_scaling(self): method test_reduction_methods (line 173) | def test_reduction_methods(self): method test_symmetry (line 186) | def test_symmetry(self): method test_zero_loss_for_identical_inputs (line 195) | def test_zero_loss_for_identical_inputs(self): class TestGKDTrainer (line 201) | class TestGKDTrainer(TrlTestCase): method setup_method (line 202) | def setup_method(self): method test_gkd_trainer (line 209) | def test_gkd_trainer(self): method test_gkd_trainer_with_liger (line 240) | def test_gkd_trainer_with_liger(self): method test_generation_config_init (line 265) | def test_generation_config_init(self): FILE: tests/experimental/test_gold_trainer.py function openr1_examples (line 27) | def openr1_examples(): function countdown_examples (line 40) | def countdown_examples(): function _teacher_inputs_from_collator (line 52) | def _teacher_inputs_from_collator(student_tok, teacher_tok, batch): function _assert_alignment_covers_completion (line 76) | def _assert_alignment_covers_completion(loss_fn, batch, teacher_input_id... function test_chatml_collator_preserves_completion_llama (line 97) | def test_chatml_collator_preserves_completion_llama(llama_tokenizer, qwe... function test_chatml_collator_preserves_completion_llama_countdown (line 142) | def test_chatml_collator_preserves_completion_llama_countdown(llama_toke... function test_chatml_collator_preserves_completion_smollm (line 187) | def test_chatml_collator_preserves_completion_smollm(smollm_tokenizer, q... function build_config (line 231) | def build_config(**overrides): function llama_tokenizer (line 250) | def llama_tokenizer(): function qwen_tokenizer (line 258) | def qwen_tokenizer(): function smollm_tokenizer (line 266) | def smollm_tokenizer(): function encode_prompt_completion (line 273) | def encode_prompt_completion(tokenizer, prompt, completion): function pad_tokens (line 284) | def pad_tokens(ids, pad_id, target_length): function pad_labels (line 288) | def pad_labels(labels, target_length): function test_process_completions_to_buffer_left_pads_prompt_retokenization (line 292) | def test_process_completions_to_buffer_left_pads_prompt_retokenization(): function test_alignment_groups_cover_all_tokens (line 374) | def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokeniz... function test_merge_probabilities_multiplies_split_tokens (line 389) | def test_merge_probabilities_multiplies_split_tokens(): function test_initialize_vocabulary_mapping_contains_common_tokens (line 411) | def test_initialize_vocabulary_mapping_contains_common_tokens(llama_toke... function test_get_start_and_size_answers_skips_prompt_tokens (line 431) | def test_get_start_and_size_answers_skips_prompt_tokens(): function test_generate_on_policy_outputs_masks_prompt (line 450) | def test_generate_on_policy_outputs_masks_prompt(llama_tokenizer): function test_generate_on_policy_outputs_masks_prompt_smollm (line 502) | def test_generate_on_policy_outputs_masks_prompt_smollm(smollm_tokenizer... function test_generalized_jsd_loss_accepts_probability_inputs (line 550) | def test_generalized_jsd_loss_accepts_probability_inputs(): function test_uldloss_handles_llama_student_qwen_teacher_sequence (line 570) | def test_uldloss_handles_llama_student_qwen_teacher_sequence(llama_token... function test_uldloss_handles_smollm_student_qwen_teacher_sequence (line 619) | def test_uldloss_handles_smollm_student_qwen_teacher_sequence(smollm_tok... function test_uldloss_hybrid_config_beta_zero (line 668) | def test_uldloss_hybrid_config_beta_zero(llama_tokenizer, qwen_tokenizer): FILE: tests/experimental/test_grpo_with_replay_buffer_trainer.py class TestReplayBuffer (line 29) | class TestReplayBuffer: method setup_method (line 30) | def setup_method(self): method test_add (line 33) | def test_add(self): method test_add_more_than_maxlen (line 53) | def test_add_more_than_maxlen(self): method test_sample (line 75) | def test_sample(self): class TestUpdateWithReplayBuffer (line 97) | class TestUpdateWithReplayBuffer: method setup_method (line 98) | def setup_method(self): method _prepopulate_buffer (line 112) | def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False): method _make_inputs (line 136) | def _make_inputs(self, group_advantages, with_pixels=False, with_logpr... method test_update_with_replay_buffer_no_variance (line 149) | def test_update_with_replay_buffer_no_variance(self): method test_update_with_replay_buffer_with_variance (line 164) | def test_update_with_replay_buffer_with_variance(self): method test_update_with_mixed_variance (line 174) | def test_update_with_mixed_variance(self): method test_update_with_inputs_different_seq_len (line 193) | def test_update_with_inputs_different_seq_len(self): class TestGRPOWithReplayBufferTrainer (line 255) | class TestGRPOWithReplayBufferTrainer(TrlTestCase): method test_training_with_replay_buffer (line 256) | def test_training_with_replay_buffer(self, scale_rewards): FILE: tests/experimental/test_gspo_token_trainer.py class TestGSPOTokenTrainer (line 30) | class TestGSPOTokenTrainer(TrlTestCase): method test_training (line 31) | def test_training(self): FILE: tests/experimental/test_judges.py class RandomBinaryJudge (line 28) | class RandomBinaryJudge(BaseBinaryJudge): method judge (line 33) | def judge(self, prompts, completions, gold_completions=None, shuffle_o... class TestJudges (line 37) | class TestJudges(TrlTestCase): method _get_prompts_and_pairwise_completions (line 38) | def _get_prompts_and_pairwise_completions(self): method _get_prompts_and_single_completions (line 43) | def _get_prompts_and_single_completions(self): method test_all_true_judge (line 48) | def test_all_true_judge(self): method test_hugging_face_judge (line 56) | def test_hugging_face_judge(self): method load_pair_rm_judge (line 64) | def load_pair_rm_judge(self): method test_pair_rm_judge (line 83) | def test_pair_rm_judge(self): method test_pair_rm_judge_return_scores (line 100) | def test_pair_rm_judge_return_scores(self): FILE: tests/experimental/test_kto_trainer.py class TestKTOTrainer (line 27) | class TestKTOTrainer(TrlTestCase): method setup_method (line 28) | def setup_method(self): method test_kto_trainer (line 44) | def test_kto_trainer(self, config_name, loss_type, pre_compute, eval_d... method test_kto_trainer_with_ref_model_is_model (line 82) | def test_kto_trainer_with_ref_model_is_model(self): method test_tokenize_and_process_tokens (line 101) | def test_tokenize_and_process_tokens(self): method test_kto_trainer_without_providing_ref_model (line 175) | def test_kto_trainer_without_providing_ref_model(self): method test_kto_trainer_without_providing_ref_model_with_lora (line 212) | def test_kto_trainer_without_providing_ref_model_with_lora(self): method test_kto_trainer_generate_during_eval_no_wandb (line 261) | def test_kto_trainer_generate_during_eval_no_wandb(self): method test_kto_trainer_with_liger (line 292) | def test_kto_trainer_with_liger(self): method test_compute_metrics (line 322) | def test_compute_metrics(self): FILE: tests/experimental/test_merge_model_callback.py class TestMergeModelCallback (line 29) | class TestMergeModelCallback(TrlTestCase): method setup_method (line 30) | def setup_method(self): method test_callback (line 37) | def test_callback(self): method test_every_checkpoint (line 59) | def test_every_checkpoint(self): FILE: tests/experimental/test_minillm_trainer.py class TestMiniLLMTrainer (line 25) | class TestMiniLLMTrainer(TrlTestCase): method test_train (line 26) | def test_train(self): FILE: tests/experimental/test_modeling_value_head.py class TestReferenceModel (line 24) | class TestReferenceModel(TrlTestCase): method setup_method (line 25) | def setup_method(self): method test_independent_reference (line 31) | def test_independent_reference(self): method test_shared_layers (line 65) | def test_shared_layers(self): FILE: tests/experimental/test_nash_md_trainer.py class TestGeometricMixtureWrapper (line 33) | class TestGeometricMixtureWrapper(TrlTestCase): method setup_method (line 34) | def setup_method(self): method test_forward (line 45) | def test_forward(self): method test_mixture_coefficient (line 55) | def test_mixture_coefficient(self): method test_prepare_inputs_for_generation (line 70) | def test_prepare_inputs_for_generation(self): class TestNashMDTrainer (line 81) | class TestNashMDTrainer(TrlTestCase): method setup_method (line 82) | def setup_method(self): method test_nash_md_trainer_training (line 91) | def test_nash_md_trainer_training(self, config_name): method test_training_with_peft (line 120) | def test_training_with_peft(self): method test_training_with_peft_and_ref_model (line 148) | def test_training_with_peft_and_ref_model(self): method test_training_pre_pefted_model_implicit_ref_with_reward_model (line 177) | def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): method test_nash_md_trainer_judge_training (line 209) | def test_nash_md_trainer_judge_training(self, config_name): FILE: tests/experimental/test_online_dpo_trainer.py class TestOnlineDPOTrainer (line 44) | class TestOnlineDPOTrainer(TrlTestCase): method setup_method (line 45) | def setup_method(self): method test_training (line 58) | def test_training(self, config_name): method test_training_model_str (line 83) | def test_training_model_str(self): method test_training_with_ref_model (line 108) | def test_training_with_ref_model(self): method test_ref_model_is_model (line 134) | def test_ref_model_is_model(self): method test_training_with_peft (line 156) | def test_training_with_peft(self): method test_training_with_peft_and_ref_model (line 185) | def test_training_with_peft_and_ref_model(self): method test_training_with_judge (line 216) | def test_training_with_judge(self, config_name): method test_training_with_vllm_server (line 244) | def test_training_with_vllm_server(self, config_name): method test_training_with_vllm_colocate (line 285) | def test_training_with_vllm_colocate(self): method test_vllm_config_validation (line 344) | def test_vllm_config_validation(self): method test_generation_config_setup (line 371) | def test_generation_config_setup(self): method test_training_with_transformers_paged (line 410) | def test_training_with_transformers_paged(self, config_name): method test_training_with_reward_funcs (line 439) | def test_training_with_reward_funcs(self, config_name): class TestOnlineDPOVisionTrainer (line 472) | class TestOnlineDPOVisionTrainer(TrlTestCase): method test_online_dpo_vlm_trainer (line 480) | def test_online_dpo_vlm_trainer(self, model_id): FILE: tests/experimental/test_orpo_trainer.py class TestORPOTrainer (line 25) | class TestORPOTrainer(TrlTestCase): method setup_method (line 26) | def setup_method(self): method test_orpo_trainer (line 45) | def test_orpo_trainer(self, name, config_name): method test_orpo_trainer_with_lora (line 98) | def test_orpo_trainer_with_lora(self, config_name): method test_compute_metrics (line 145) | def test_compute_metrics(self): FILE: tests/experimental/test_ppo_trainer.py class TestBatchGeneration (line 74) | class TestBatchGeneration(TrlTestCase): method setup_method (line 75) | def setup_method(self): method test_mini_batch_generation (line 95) | def test_mini_batch_generation(self): method test_single_batch_generation (line 114) | def test_single_batch_generation(self): class BaseTester (line 134) | class BaseTester: class VHeadModelTester (line 135) | class VHeadModelTester(TrlTestCase): method setup_method (line 140) | def setup_method(self): method test_value_head (line 143) | def test_value_head(self): method test_value_head_shape (line 151) | def test_value_head_shape(self): method test_value_head_init_random (line 159) | def test_value_head_init_random(self): method test_value_head_not_str (line 168) | def test_value_head_not_str(self): method test_from_save_trl (line 178) | def test_from_save_trl(self): method test_from_save_trl_sharded (line 194) | def test_from_save_trl_sharded(self): method test_from_save_transformers_sharded (line 209) | def test_from_save_transformers_sharded(self): method test_from_save_transformers (line 229) | def test_from_save_transformers(self): class TestCausalLMValueHeadModel (line 264) | class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): method teardown_method (line 273) | def teardown_method(self): method test_inference (line 277) | def test_inference(self): method test_dropout_config (line 293) | def test_dropout_config(self): method test_dropout_kwargs (line 305) | def test_dropout_kwargs(self): method test_generate (line 323) | def test_generate(self, model_name): method test_transformers_bf16_kwargs (line 334) | def test_transformers_bf16_kwargs(self): method test_push_to_hub (line 359) | def test_push_to_hub(self): class TestSeq2SeqValueHeadModel (line 378) | class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): method teardown_method (line 387) | def teardown_method(self): method test_inference (line 391) | def test_inference(self): method test_dropout_config (line 408) | def test_dropout_config(self): method test_dropout_kwargs (line 420) | def test_dropout_kwargs(self): method test_generate (line 438) | def test_generate(self, model_name): method test_push_to_hub (line 451) | def test_push_to_hub(self): method test_transformers_bf16_kwargs (line 469) | def test_transformers_bf16_kwargs(self): class TestPeftModel (line 493) | class TestPeftModel(TrlTestCase): method setup_method (line 494) | def setup_method(self): method test_create_peft_model (line 504) | def test_create_peft_model(self): method test_peft_requires_grad (line 513) | def test_peft_requires_grad(self): method test_check_peft_model_nb_trainable_params (line 525) | def test_check_peft_model_nb_trainable_params(self): method test_create_peft_model_from_config (line 543) | def test_create_peft_model_from_config(self): method test_create_bnb_peft_model_from_config (line 562) | def test_create_bnb_peft_model_from_config(self): method test_save_pretrained_peft (line 588) | def test_save_pretrained_peft(self): method test_load_pretrained_peft (line 622) | def test_load_pretrained_peft(self): method test_continue_training_peft_model (line 647) | def test_continue_training_peft_model(self): class TestCore (line 662) | class TestCore(TrlTestCase): method setup_method (line 667) | def setup_method(self): method test_masked_mean (line 672) | def test_masked_mean(self): method test_masked_var (line 675) | def test_masked_var(self): method test_masked_whiten (line 678) | def test_masked_whiten(self): class TestPPOTrainer (line 689) | class TestPPOTrainer(TrlTestCase): method setup_method (line 690) | def setup_method(self): method test_basic_training (line 715) | def test_basic_training(self): method test_peft_training (line 767) | def test_peft_training(self): FILE: tests/experimental/test_prm_trainer.py class TestComputeAccuracy (line 34) | class TestComputeAccuracy(TrlTestCase): method test_token_classification_task (line 35) | def test_token_classification_task(self): method test_token_classification_task_with_ignored_tokens_0 (line 49) | def test_token_classification_task_with_ignored_tokens_0(self): method test_token_classification_task_with_ignored_tokens_1 (line 63) | def test_token_classification_task_with_ignored_tokens_1(self): method test_rewards_comparison_task (line 77) | def test_rewards_comparison_task(self, caplog): class TestTokenizeRow (line 101) | class TestTokenizeRow(TrlTestCase): method setup_method (line 102) | def setup_method(self): method test_tokenize_row_no_truncation (line 125) | def test_tokenize_row_no_truncation(self): method test_tokenize_row_train_on_last_step_only (line 149) | def test_tokenize_row_train_on_last_step_only(self): method test_tokenize_row_completion_truncation (line 172) | def test_tokenize_row_completion_truncation(self): method test_tokenize_row_prompt_completion_truncation (line 196) | def test_tokenize_row_prompt_completion_truncation(self): method test_tokenize_row_multi_token_separator (line 220) | def test_tokenize_row_multi_token_separator(self): class TestPRMTrainer (line 245) | class TestPRMTrainer(TrlTestCase): method setup_method (line 246) | def setup_method(self): method test_train_full (line 252) | def test_train_full(self, train_on_last_step_only): method test_train_full_pretokenized (line 272) | def test_train_full_pretokenized(self): method test_train_lora (line 326) | def test_train_lora(self): method test_tags (line 370) | def test_tags(self): FILE: tests/experimental/test_utils.py class TestDataCollatorForChatML (line 24) | class TestDataCollatorForChatML(TrlTestCase): method setup_method (line 25) | def setup_method(self): method test_data_collator_for_chatml (line 50) | def test_data_collator_for_chatml(self): FILE: tests/experimental/test_winrate_callback.py class HalfPairwiseJudge (line 30) | class HalfPairwiseJudge(BasePairwiseJudge): method judge (line 33) | def judge(self, prompts, completions, shuffle_order=True, return_score... class TrainerWithRefModel (line 41) | class TrainerWithRefModel(Trainer): method __init__ (line 44) | def __init__(self, model, ref_model, args, train_dataset, eval_dataset... class TestWinRateCallback (line 56) | class TestWinRateCallback(TrlTestCase): method setup_method (line 57) | def setup_method(self): method test_basic (line 86) | def test_basic(self): method test_without_ref_model (line 112) | def test_without_ref_model(self): method test_soft_judge (line 138) | def test_soft_judge(self): method test_lora (line 182) | def test_lora(self): FILE: tests/experimental/test_xpo_trainer.py class TestXPOTrainer (line 31) | class TestXPOTrainer(TrlTestCase): method setup_method (line 32) | def setup_method(self): method test_xpo_trainer_training (line 41) | def test_xpo_trainer_training(self, config_name): method test_training_with_peft (line 70) | def test_training_with_peft(self): method test_training_with_peft_and_ref_model (line 98) | def test_training_with_peft_and_ref_model(self): method test_training_pre_pefted_model_implicit_ref (line 127) | def test_training_pre_pefted_model_implicit_ref(self): method test_xpo_trainer_judge_training (line 157) | def test_xpo_trainer_judge_training(self, config_name): FILE: tests/experimental/testing_utils.py class RandomPairwiseJudge (line 19) | class RandomPairwiseJudge(BasePairwiseJudge): method judge (line 24) | def judge(self, prompts, completions, shuffle_order=True, return_score... FILE: tests/test_activation_offloading.py class TestActivationOffloading (line 30) | class TestActivationOffloading(TrlTestCase): method test_offloading_with_peft_models (line 33) | def test_offloading_with_peft_models(self) -> None: method test_noop_manager_with_offloading (line 80) | def test_noop_manager_with_offloading(self): method test_min_offload_size (line 110) | def test_min_offload_size(self): method test_real_hf_model (line 127) | def test_real_hf_model(self): method test_tensor_deduplication (line 159) | def test_tensor_deduplication(self): method test_parameter_filtering (line 200) | def test_parameter_filtering(self): FILE: tests/test_callbacks.py class TestLogCompletionsCallback (line 27) | class TestLogCompletionsCallback(TrlTestCase): method setup_method (line 28) | def setup_method(self): method test_basic_wandb (line 45) | def test_basic_wandb(self): method test_basic_comet (line 82) | def test_basic_comet(self): class TestBEMACallback (line 120) | class TestBEMACallback(TrlTestCase): method setup_method (line 121) | def setup_method(self): method test_model_saved (line 136) | def test_model_saved(self): method test_update_frequency_0 (line 154) | def test_update_frequency_0(self): method test_update_frequency_1 (line 174) | def test_update_frequency_1(self): method test_update_frequency_2 (line 194) | def test_update_frequency_2(self): method test_no_bema (line 214) | def test_no_bema(self): method test_no_ema (line 227) | def test_no_ema(self): FILE: tests/test_chat_template_utils.py class TestCloneChatTemplate (line 33) | class TestCloneChatTemplate(TrlTestCase): method test_clone (line 34) | def test_clone(self): method test_clone_with_resize (line 45) | def test_clone_with_resize(self): method test_clone_with_resize_and_extra_tokens_already_in_vocab (line 60) | def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): method test_apply_new_chat_template (line 80) | def test_apply_new_chat_template(self): method test_clone_with_sequence_classification_model (line 99) | def test_clone_with_sequence_classification_model(self): class TestAddResponseSchema (line 126) | class TestAddResponseSchema: method test_add_response_schema (line 127) | def test_add_response_schema(self, tokenizer_name): class TestIsChatTemplatePrefixPreserving (line 146) | class TestIsChatTemplatePrefixPreserving: method test_prefix_preserving_template (line 147) | def test_prefix_preserving_template(self): method test_non_prefix_preserving_template (line 165) | def test_non_prefix_preserving_template(self): class TestGetTrainingChatTemplate (line 231) | class TestGetTrainingChatTemplate: method test_new_chat_template_is_prefix_preserving (line 232) | def test_new_chat_template_is_prefix_preserving(self, tokenizer_name): method test_behavior_unchanged_single_user_no_generation_prompt (line 238) | def test_behavior_unchanged_single_user_no_generation_prompt(self, tok... method test_behavior_unchanged_single_user_with_generation_prompt (line 246) | def test_behavior_unchanged_single_user_with_generation_prompt(self, t... method test_behavior_unchanged_single_user_and_final_assistant_plain_content (line 259) | def test_behavior_unchanged_single_user_and_final_assistant_plain_cont... method test_behavior_unchanged_final_assistant_with_reasoning_content (line 270) | def test_behavior_unchanged_final_assistant_with_reasoning_content(sel... method test_behavior_unchanged_final_assistant_with_existing_think_tags (line 285) | def test_behavior_unchanged_final_assistant_with_existing_think_tags(s... method test_behavior_unchanged_assistant_with_tool_calls (line 299) | def test_behavior_unchanged_assistant_with_tool_calls(self, tokenizer_... method test_behavior_unchanged_with_tools_with_and_without_system_message (line 314) | def test_behavior_unchanged_with_tools_with_and_without_system_message... method test_behavior_unchanged_with_tools_with_system_message (line 339) | def test_behavior_unchanged_with_tools_with_system_message(self, token... method test_behavior_unchanged_generation_prompt_with_enable_thinking_false (line 364) | def test_behavior_unchanged_generation_prompt_with_enable_thinking_fal... class TestParseResponse (line 394) | class TestParseResponse: method test_parse_response (line 395) | def test_parse_response(self, tokenizer_name): method test_parse_response_with_reasoning_content (line 408) | def test_parse_response_with_reasoning_content(self, tokenizer_name): method test_parse_response_tool_call (line 425) | def test_parse_response_tool_call(self, tokenizer_name): method test_parse_response_tool_call_with_content (line 439) | def test_parse_response_tool_call_with_content(self, tokenizer_name): method test_parse_response_tool_call_without_arguments (line 453) | def test_parse_response_tool_call_without_arguments(self, tokenizer_na... method test_parse_response_multiple_tool_calls (line 467) | def test_parse_response_multiple_tool_calls(self, tokenizer_name): method test_parse_response_malformed_tool_call (line 484) | def test_parse_response_malformed_tool_call(self, tokenizer_name): FILE: tests/test_cli.py function test_help_no_type_error (line 26) | def test_help_no_type_error(command): class TestCLI (line 37) | class TestCLI(TrlTestCase): method test_dpo (line 38) | def test_dpo(self): method test_dpo_multiple_loss_types (line 45) | def test_dpo_multiple_loss_types(self): method test_env (line 53) | def test_env(self, mock_stdout): method test_grpo (line 61) | def test_grpo(self): method test_kto (line 68) | def test_kto(self): method test_reward (line 75) | def test_reward(self): method test_rloo (line 82) | def test_rloo(self): method test_sft (line 89) | def test_sft(self): method test_sft_config_file (line 96) | def test_sft_config_file(self): method test_vllm_serve_config_file (line 122) | def test_vllm_serve_config_file(self): FILE: tests/test_cli_utils.py class MyDataclass (line 29) | class MyDataclass: class InvalidDataclass (line 35) | class InvalidDataclass: class TestTrlParser (line 39) | class TestTrlParser(TrlTestCase): method test_init_without_config_field (line 40) | def test_init_without_config_field(self): method test_init_with_config_field (line 45) | def test_init_with_config_field(self): method test_parse_args_and_config_with_valid_config (line 53) | def test_parse_args_and_config_with_valid_config(self, mock_environ, m... method test_parse_args_and_arg_override_config (line 80) | def test_parse_args_and_arg_override_config(self, mock_yaml_load): method test_parse_args_and_config_with_invalid_env (line 98) | def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): method test_parse_args_and_config_without_config (line 109) | def test_parse_args_and_config_without_config(self): method test_set_defaults_with_config (line 124) | def test_set_defaults_with_config(self): method test_parse_args_and_config_with_remaining_strings (line 137) | def test_parse_args_and_config_with_remaining_strings(self): method test_parse_args_and_config_with_remaining_strings_in_config_and_args (line 154) | def test_parse_args_and_config_with_remaining_strings_in_config_and_ar... method test_subparsers_with_config_defaults (line 172) | def test_subparsers_with_config_defaults(self, mock_yaml_load): method test_subparsers_with_config_defaults_and_arg_override (line 198) | def test_subparsers_with_config_defaults_and_arg_override(self, mock_y... method test_subparsers_with_config_defaults_and_arg_override_wrong_name (line 221) | def test_subparsers_with_config_defaults_and_arg_override_wrong_name(s... method test_subparsers_multiple_with_config_defaults (line 243) | def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): class TestGetDataset (line 270) | class TestGetDataset: method test_single_dataset_with_config (line 271) | def test_single_dataset_with_config(self): method test_single_dataset_preference_config (line 279) | def test_single_dataset_preference_config(self): method test_single_dataset_streaming (line 287) | def test_single_dataset_streaming(self): method test_dataset_mixture_basic (line 296) | def test_dataset_mixture_basic(self): method test_dataset_mixture_with_weights (line 315) | def test_dataset_mixture_with_weights(self): method test_dataset_mixture_with_test_split (line 336) | def test_dataset_mixture_with_test_split(self): method test_empty_dataset_mixture_raises_error (line 348) | def test_empty_dataset_mixture_raises_error(self): method test_mixture_multiple_different_configs (line 354) | def test_mixture_multiple_different_configs(self): method test_trlparser_parses_yaml_config_correctly (line 367) | def test_trlparser_parses_yaml_config_correctly(self): method test_trlparser_parses_yaml_and_loads_dataset (line 407) | def test_trlparser_parses_yaml_and_loads_dataset(self): FILE: tests/test_data_utils.py class TestPrepareMultimodalMessages (line 49) | class TestPrepareMultimodalMessages: method test_basic_user_assistant_conversation (line 50) | def test_basic_user_assistant_conversation(self): method test_first_user_message_gets_image (line 72) | def test_first_user_message_gets_image(self): method test_multiple_images (line 100) | def test_multiple_images(self): method test_system_message_transformation (line 127) | def test_system_message_transformation(self): method test_already_prepared_messages_unchanged (line 150) | def test_already_prepared_messages_unchanged(self): method test_mixed_prepared_and_unprepared_messages (line 178) | def test_mixed_prepared_and_unprepared_messages(self): method test_message_with_tool_calling_turns (line 206) | def test_message_with_tool_calling_turns(self): class TestPrepareMultimodalMessagesVLLM (line 250) | class TestPrepareMultimodalMessagesVLLM: method test_single_image_conversion (line 251) | def test_single_image_conversion(self): method test_mixed_content_conversion (line 274) | def test_mixed_content_conversion(self): method test_no_images (line 291) | def test_no_images(self): method test_multiple_messages (line 302) | def test_multiple_messages(self): method test_deepcopy_integrity (line 323) | def test_deepcopy_integrity(self): class TestIsConversational (line 341) | class TestIsConversational(TrlTestCase): method test_conversational (line 469) | def test_conversational(self, example): method test_non_conversational (line 473) | def test_non_conversational(self, example): class TestIsConversationalFromValue (line 477) | class TestIsConversationalFromValue(TrlTestCase): method test_positive_1 (line 478) | def test_positive_1(self): method test_negative_1 (line 487) | def test_negative_1(self): method test_negative_2 (line 496) | def test_negative_2(self): class TestApplyChatTemplate (line 501) | class TestApplyChatTemplate(TrlTestCase): method test_apply_chat_template (line 582) | def test_apply_chat_template(self, tokenizer_id, example): method test_maybe_apply_chat_template (line 609) | def test_maybe_apply_chat_template(self, tokenizer_id, example): method test_apply_chat_template_with_chat_template_kwargs (line 633) | def test_apply_chat_template_with_chat_template_kwargs(self): method test_apply_chat_template_with_tools (line 656) | def test_apply_chat_template_with_tools(self): class TestApplyChatTemplateHarmony (line 688) | class TestApplyChatTemplateHarmony(TrlTestCase): method test_language_modeling (line 689) | def test_language_modeling(self): method test_prompt_only (line 720) | def test_prompt_only(self): method test_prompt_completion (line 750) | def test_prompt_completion(self): method test_preference (line 785) | def test_preference(self): method test_preference_with_implicit_prompt (line 825) | def test_preference_with_implicit_prompt(self): method test_unpaired_preference (line 876) | def test_unpaired_preference(self): class TestUnpairPreferenceDataset (line 914) | class TestUnpairPreferenceDataset(TrlTestCase): method test_unpair_preference_dataset (line 931) | def test_unpair_preference_dataset(self): method test_unpair_preference_dataset_dict (line 938) | def test_unpair_preference_dataset_dict(self): method test_maybe_unpair_preference_dataset (line 946) | def test_maybe_unpair_preference_dataset(self): method test_maybe_unpair_preference_dataset_dict (line 953) | def test_maybe_unpair_preference_dataset_dict(self): method test_maybe_unpair_preference_dataset_already_paired (line 961) | def test_maybe_unpair_preference_dataset_already_paired(self): method test_maybe_unpair_preference_dataset_dict_already_paired (line 968) | def test_maybe_unpair_preference_dataset_dict_already_paired(self): class TestExtractPrompt (line 976) | class TestExtractPrompt(TrlTestCase): method test_extract_prompt_conversational (line 1011) | def test_extract_prompt_conversational(self): method test_maybe_extract_prompt_conversational (line 1018) | def test_maybe_extract_prompt_conversational(self): method test_maybe_extract_prompt_conversational_already_explicit (line 1025) | def test_maybe_extract_prompt_conversational_already_explicit(self): method test_extract_prompt_standard (line 1032) | def test_extract_prompt_standard(self): method test_maybe_extract_prompt_standard (line 1039) | def test_maybe_extract_prompt_standard(self): method test_maybe_extract_prompt_standard_already_explicit (line 1046) | def test_maybe_extract_prompt_standard_already_explicit(self): class TestPackDatasetWrapped (line 1052) | class TestPackDatasetWrapped(TrlTestCase): method test_with_dataset (line 1053) | def test_with_dataset(self): method test_with_iterable_dataset (line 1070) | def test_with_iterable_dataset(self): class TestPackDatasetBfd (line 1089) | class TestPackDatasetBfd(TrlTestCase): method test_with_dataset (line 1090) | def test_with_dataset(self): method test_with_iterable_dataset (line 1109) | def test_with_iterable_dataset(self): method test_with_overlong_0 (line 1126) | def test_with_overlong_0(self): method test_with_overlong_two_coluns (line 1139) | def test_with_overlong_two_coluns(self): method test_with_non_power_of_2 (line 1154) | def test_with_non_power_of_2(self): method test_default_no_split (line 1167) | def test_default_no_split(self): method test_with_empty_sequences (line 1182) | def test_with_empty_sequences(self): class TestTruncateExamples (line 1196) | class TestTruncateExamples(TrlTestCase): method test_with_dataset (line 1197) | def test_with_dataset(self): method test_with_iterable_dataset (line 1214) | def test_with_iterable_dataset(self): method test_with_extra_column (line 1232) | def test_with_extra_column(self): method test_with_keep_end (line 1248) | def test_with_keep_end(self): method test_with_keep_end_and_zero_max_length (line 1261) | def test_with_keep_end_and_zero_max_length(self): class TestMaybeConvertToChatML (line 1275) | class TestMaybeConvertToChatML(TrlTestCase): method test_with_conversations_key (line 1276) | def test_with_conversations_key(self): method test_without_conversations_key (line 1292) | def test_without_conversations_key(self): method test_not_conversional (line 1304) | def test_not_conversional(self): method test_already_chatml (line 1309) | def test_already_chatml(self): FILE: tests/test_dpo_trainer.py class TestDataCollatorForPreference (line 42) | class TestDataCollatorForPreference(TrlTestCase): method test_padding_and_masks (line 43) | def test_padding_and_masks(self): method test_optional_reference_logps (line 81) | def test_optional_reference_logps(self): method test_with_pad_to_multiple_of (line 114) | def test_with_pad_to_multiple_of(self): class TestDataCollatorForVisionPreference (line 135) | class TestDataCollatorForVisionPreference(TrlTestCase): method test_mm_token_type_ids_shape (line 141) | def test_mm_token_type_ids_shape(self): class TestDPOTrainer (line 167) | class TestDPOTrainer(TrlTestCase): method test_train (line 176) | def test_train(self, model_id): method test_train_gpt_oss (line 203) | def test_train_gpt_oss(self): method test_train_model (line 231) | def test_train_model(self): method test_train_loss_types (line 282) | def test_train_loss_types(self, loss_type): method test_train_multi_loss_types (line 317) | def test_train_multi_loss_types(self): method test_train_with_wpo (line 348) | def test_train_with_wpo(self): method test_train_with_ld (line 379) | def test_train_with_ld(self): method test_train_with_f_divergence (line 414) | def test_train_with_f_divergence(self, f_divergence_type): method test_train_with_explicit_ref_model (line 445) | def test_train_with_explicit_ref_model(self): method test_training_with_sync_ref_model (line 483) | def test_training_with_sync_ref_model(self): method test_train_model_dtype (line 517) | def test_train_model_dtype(self): method test_train_dense_with_peft_config_lora (line 553) | def test_train_dense_with_peft_config_lora(self): method test_train_moe_with_peft_config (line 594) | def test_train_moe_with_peft_config(self): method test_train_peft_model (line 635) | def test_train_peft_model(self): method test_train_with_peft_config_and_gradient_checkpointing (line 679) | def test_train_with_peft_config_and_gradient_checkpointing(self): method test_train_with_liger (line 721) | def test_train_with_liger(self): method test_train_with_iterable_dataset (line 750) | def test_train_with_iterable_dataset(self): method test_train_padding_free (line 781) | def test_train_padding_free(self): method test_train_with_chat_template_kwargs (line 812) | def test_train_with_chat_template_kwargs(self): method test_train_toolcall_data (line 866) | def test_train_toolcall_data(self): method test_train_with_eval (line 894) | def test_train_with_eval(self): method test_train_with_multiple_eval_dataset (line 913) | def test_train_with_multiple_eval_dataset(self): method test_train_with_compute_metrics (line 932) | def test_train_with_compute_metrics(self): method test_train_with_gradient_checkpointing (line 963) | def test_train_with_gradient_checkpointing(self): method test_tag_added (line 992) | def test_tag_added(self): method test_tag_added_peft (line 1006) | def test_tag_added_peft(self): method test_train_vlm (line 1054) | def test_train_vlm(self, model_id): method test_train_vlm_multi_image (line 1106) | def test_train_vlm_multi_image(self, model_id): method test_train_vlm_gemma_3n (line 1143) | def test_train_vlm_gemma_3n(self): method test_train_vlm_text_only_data (line 1186) | def test_train_vlm_text_only_data(self, model_id, dataset_config): method test_train_vlm_with_max_length (line 1216) | def test_train_vlm_with_max_length(self): method test_peft_with_quantization (line 1237) | def test_peft_with_quantization(self): method test_train_vlm_keep_end_raises (line 1292) | def test_train_vlm_keep_end_raises(self): FILE: tests/test_grpo_trainer.py function multiply_tool (line 62) | def multiply_tool(a: int, b: int) -> int: function async_multiply_tool (line 76) | async def async_multiply_tool(a: int, b: int) -> int: class TestGetHighEntropyMask (line 90) | class TestGetHighEntropyMask(TrlTestCase): method get_high_entropy_mask (line 91) | def get_high_entropy_mask(self, entropies, mask, threshold): method test_compute_entropy_mask_0 (line 109) | def test_compute_entropy_mask_0(self): method test_compute_entropy_mask_1 (line 121) | def test_compute_entropy_mask_1(self): method test_compute_entropy_mask_lower_threshold (line 129) | def test_compute_entropy_mask_lower_threshold(self): method test_compute_entropy_threshold_0 (line 137) | def test_compute_entropy_threshold_0(self): method test_compute_entropy_threshold_1 (line 145) | def test_compute_entropy_threshold_1(self): method test_compute_entropy_all_masked (line 153) | def test_compute_entropy_all_masked(self): class TestGRPORolloutDispatch (line 162) | class TestGRPORolloutDispatch: method _make_trainer (line 163) | def _make_trainer(self): method test_generate_prefers_rollout_func (line 202) | def test_generate_prefers_rollout_func(self): method test_generate_rollout_func_syncs_vllm_weights_when_needed (line 220) | def test_generate_rollout_func_syncs_vllm_weights_when_needed(self): method test_generate_rollout_func_raises_when_required_keys_are_missing (line 233) | def test_generate_rollout_func_raises_when_required_keys_are_missing(s... class TestGRPOTrainer (line 241) | class TestGRPOTrainer(TrlTestCase): method test_init_minimal (line 242) | def test_init_minimal(self): method test_training (line 252) | def test_training(self, config_name): method test_training_loss_types (line 282) | def test_training_loss_types(self, loss_type): method test_training_with_eval (line 314) | def test_training_with_eval(self): method test_training_with_num_generations_eval (line 337) | def test_training_with_num_generations_eval(self): method test_training_eval_on_start (line 364) | def test_training_eval_on_start(self): method test_training_multiple_iterations (line 388) | def test_training_multiple_iterations(self): method test_training_peft_config (line 419) | def test_training_peft_config(self): method test_training_peft_model (line 455) | def test_training_peft_model(self): method test_training_peft_with_gradient_checkpointing (line 495) | def test_training_peft_with_gradient_checkpointing(self): method test_training_different_reward_model (line 531) | def test_training_different_reward_model(self): method test_training_reward_func_standard (line 570) | def test_training_reward_func_standard(self): method test_training_reward_func_conversational (line 604) | def test_training_reward_func_conversational(self): method test_training_multiple_reward_funcs (line 639) | def test_training_multiple_reward_funcs(self): method test_training_sync_and_async_reward_funcs (line 677) | def test_training_sync_and_async_reward_funcs(self): method test_training_multiple_reward_funcs_with_None_output (line 718) | def test_training_multiple_reward_funcs_with_None_output(self): method test_training_multiple_reward_funcs_with_weights (line 762) | def test_training_multiple_reward_funcs_with_weights(self): method test_training_multiple_mixed_reward_funcs (line 806) | def test_training_multiple_mixed_reward_funcs(self): method test_training_reward_func_additional_column (line 840) | def test_training_reward_func_additional_column(self): method test_training_with_sync_ref_model (line 880) | def test_training_with_sync_ref_model(self): method test_training_beta_non_zero (line 916) | def test_training_beta_non_zero(self): method test_training_with_pad_to_multiple_of (line 945) | def test_training_with_pad_to_multiple_of(self): method test_get_off_policy_mask (line 975) | def test_get_off_policy_mask(self): method test_get_off_policy_mask_padding (line 1002) | def test_get_off_policy_mask_padding(self): method test_training_with_off_policy_mask (line 1041) | def test_training_with_off_policy_mask(self): method test_training_with_off_policy_mask_with_liger (line 1072) | def test_training_with_off_policy_mask_with_liger(self): method test_compute_liger_loss_passes_vllm_is_ratio (line 1103) | def test_compute_liger_loss_passes_vllm_is_ratio(self): method test_training_with_bias_correction_kl (line 1153) | def test_training_with_bias_correction_kl(self): method test_training_with_cast_lm_head_to_fp32 (line 1188) | def test_training_with_cast_lm_head_to_fp32(self, model_name): method test_training_with_entropy_filter (line 1217) | def test_training_with_entropy_filter(self): method test_training_vllm_and_peft (line 1249) | def test_training_vllm_and_peft(self): method test_training_vllm_structured_outputs (line 1296) | def test_training_vllm_structured_outputs(self): method test_training_vllm_importance_sampling_correction (line 1330) | def test_training_vllm_importance_sampling_correction(self): method test_training_with_additional_generation_kwargs (line 1363) | def test_training_with_additional_generation_kwargs(self): method test_training_vllm_with_additional_generation_kwargs (line 1400) | def test_training_vllm_with_additional_generation_kwargs(self): method test_training_normalize_then_sum_aggregation (line 1436) | def test_training_normalize_then_sum_aggregation(self): method test_training_scale_rewards (line 1475) | def test_training_scale_rewards(self, scale_rewards): method test_training_with_mask_truncated_completions (line 1506) | def test_training_with_mask_truncated_completions(self, mock_generate): method test_training_with_mask_truncated_completions_all_masked (line 1555) | def test_training_with_mask_truncated_completions_all_masked(self): method test_warning_raised_all_rewards_none (line 1593) | def test_warning_raised_all_rewards_none(self, caplog): method test_training_num_generations_larger_than_batch_size (line 1622) | def test_training_num_generations_larger_than_batch_size(self): method test_training_delta_clipping (line 1652) | def test_training_delta_clipping(self): method test_training_multiple_dataloader_workers (line 1682) | def test_training_multiple_dataloader_workers(self): method test_training_with_generation_kwargs (line 1723) | def test_training_with_generation_kwargs(self): method test_training_with_reward_func_accessing_trainer_state (line 1754) | def test_training_with_reward_func_accessing_trainer_state(self): method test_training_reward_func_with_log_extra (line 1779) | def test_training_reward_func_with_log_extra(self): method test_training_reward_func_with_log_metric (line 1805) | def test_training_reward_func_with_log_metric(self): method test_prepare_input_called_with_correct_data (line 1832) | def test_prepare_input_called_with_correct_data(self): method test_training_vlm (line 1901) | def test_training_vlm(self, model_id): method test_training_vlm_with_pad_to_multiple_of (line 1946) | def test_training_vlm_with_pad_to_multiple_of(self): method test_training_vlm_beta_non_zero (line 1989) | def test_training_vlm_beta_non_zero(self, model_id): method test_training_vlm_peft (line 2036) | def test_training_vlm_peft(self, model_id): method test_training_vlm_and_importance_sampling (line 2082) | def test_training_vlm_and_importance_sampling(self, model_id): method test_training_vlm_and_liger (line 2136) | def test_training_vlm_and_liger(self, model_id): method test_training_vlm_and_vllm (line 2185) | def test_training_vlm_and_vllm(self, model_id) -> None: method test_training_vlm_multi_image (line 2226) | def test_training_vlm_multi_image(self, model_id): method test_training_sequence_importance_sampling (line 2261) | def test_training_sequence_importance_sampling(self): method test_training_with_chat_template_kwargs (line 2292) | def test_training_with_chat_template_kwargs(self): method test_training_with_tools (line 2329) | def test_training_with_tools(self, tools: list[Callable]): method test_training_with_environment_factory (line 2418) | def test_training_with_environment_factory(self): method test_training_with_malformed_tool_calls (line 2513) | def test_training_with_malformed_tool_calls(self): method test_mismatched_reward_processing_classes_length (line 2562) | def test_mismatched_reward_processing_classes_length(self): method test_correct_reward_processing_classes_list (line 2588) | def test_correct_reward_processing_classes_list(self): method test_single_reward_model_with_single_processing_class (line 2619) | def test_single_reward_model_with_single_processing_class(self): class TestGRPOTrainerSlow (line 2647) | class TestGRPOTrainerSlow(TrlTestCase): method setup_method (line 2648) | def setup_method(self): method teardown_method (line 2653) | def teardown_method(self): method test_training_with_liger_grpo_kernel (line 2666) | def test_training_with_liger_grpo_kernel(self, model_name): method test_training_with_liger_grpo_kernel_and_peft (line 2712) | def test_training_with_liger_grpo_kernel_and_peft(self, model_name): method test_liger_grpo_kernel_importance_sampling (line 2773) | def test_liger_grpo_kernel_importance_sampling(self): method test_training_with_transformers_paged (line 2820) | def test_training_with_transformers_paged(self, model_name): method test_vlm_training (line 2867) | def test_vlm_training(self, model_name): method test_vlm_processor_vllm_colocate_mode (line 2994) | def test_vlm_processor_vllm_colocate_mode(self): method test_training_vllm (line 3142) | def test_training_vllm(self): FILE: tests/test_model_utils.py class TestDisableGradientCheckpointing (line 20) | class TestDisableGradientCheckpointing: method test_when_disabled (line 21) | def test_when_disabled(self): method test_when_enabled (line 28) | def test_when_enabled(self): FILE: tests/test_reward_trainer.py class TestDataCollatorForPreference (line 34) | class TestDataCollatorForPreference(TrlTestCase): method test_basic_padding (line 35) | def test_basic_padding(self): method test_pad_to_multiple_of (line 50) | def test_pad_to_multiple_of(self): method test_single_example (line 67) | def test_single_example(self): method test_different_pad_token_id (line 77) | def test_different_pad_token_id(self): method test_collate_with_margin (line 94) | def test_collate_with_margin(self): class TestRewardTrainer (line 110) | class TestRewardTrainer(TrlTestCase): method test_raises_error_when_model_num_labels_not_one (line 111) | def test_raises_error_when_model_num_labels_not_one(self): method test_train (line 135) | def test_train(self, model_id): method test_train_dataset_types (line 166) | def test_train_dataset_types(self, config_name): method test_train_model (line 192) | def test_train_model(self): method test_train_from_sequence_classification_model (line 221) | def test_train_from_sequence_classification_model(self): method test_train_model_dtype (line 247) | def test_train_model_dtype(self): method test_train_dense_with_peft_config (line 285) | def test_train_dense_with_peft_config(self): method test_train_moe_with_peft_config (line 322) | def test_train_moe_with_peft_config(self): method test_train_peft_model (line 359) | def test_train_peft_model(self): method test_train_with_peft_config_and_gradient_checkpointing (line 403) | def test_train_with_peft_config_and_gradient_checkpointing(self): method test_train_with_peft_config_and_gradient_checkpointing_reentrant (line 441) | def test_train_with_peft_config_and_gradient_checkpointing_reentrant(s... method test_train_with_pretokenized_data (line 489) | def test_train_with_pretokenized_data(self, chosen_column, rejected_co... method test_train_with_iterable_dataset (line 531) | def test_train_with_iterable_dataset(self): method test_train_with_chat_template_kwargs (line 559) | def test_train_with_chat_template_kwargs(self): method test_train_with_set_chat_template_from_model (line 610) | def test_train_with_set_chat_template_from_model(self): method test_train_with_set_chat_template_from_path (line 642) | def test_train_with_set_chat_template_from_path(self, lazy_shared_data... method test_train_toolcall_data (line 688) | def test_train_toolcall_data(self): method test_train_toolcall_data_as_json (line 714) | def test_train_toolcall_data_as_json(self): method test_train_with_eval (line 748) | def test_train_with_eval(self): method test_train_with_multiple_eval_dataset (line 767) | def test_train_with_multiple_eval_dataset(self): method test_train_with_compute_metrics (line 786) | def test_train_with_compute_metrics(self): method test_train_with_gradient_checkpointing (line 817) | def test_train_with_gradient_checkpointing(self): method test_train_with_gradient_checkpointing_reentrant (line 844) | def test_train_with_gradient_checkpointing_reentrant(self, use_reentra... method test_tag_added (line 875) | def test_tag_added(self): method test_tag_added_peft (line 889) | def test_tag_added_peft(self): method test_train_with_margin (line 903) | def test_train_with_margin(self): method test_train_with_center_rewards_coefficient (line 935) | def test_train_with_center_rewards_coefficient(self): FILE: tests/test_rewards.py class TestThinkFormatReward (line 22) | class TestThinkFormatReward(TrlTestCase): method test_valid_format (line 23) | def test_valid_format(self): method test_invalid_format (line 36) | def test_invalid_format(self): method test_mixed_format (line 53) | def test_mixed_format(self): class TestSoftOverlongPunishmentReward (line 66) | class TestSoftOverlongPunishmentReward: method test_soft_overlong_punishment_short_completion (line 67) | def test_soft_overlong_punishment_short_completion(self): method test_soft_overlong_punishment_long_completion (line 75) | def test_soft_overlong_punishment_long_completion(self): method test_soft_overlong_punishment_intermediate_completion (line 83) | def test_soft_overlong_punishment_intermediate_completion(self): class TestAccuracyReward (line 91) | class TestAccuracyReward: method test_accuracy_reward_correct_answer (line 93) | def test_accuracy_reward_correct_answer(self): method test_accuracy_reward_wrong_answer (line 102) | def test_accuracy_reward_wrong_answer(self): method test_accuracy_reward_wrong_answer_no_latex (line 110) | def test_accuracy_reward_wrong_answer_no_latex(self): method test_accuracy_reward_unparsable_gold (line 118) | def test_accuracy_reward_unparsable_gold(self): method test_accuracy_reward_in_worker_thread (line 133) | def test_accuracy_reward_in_worker_thread(self): class TestReasoningAccuracyReward (line 154) | class TestReasoningAccuracyReward: method test_correct_answer_yields_unit_reward (line 156) | def test_correct_answer_yields_unit_reward(self): method test_correct_answer_with_custom_tags_yields_unit_reward (line 167) | def test_correct_answer_with_custom_tags_yields_unit_reward(self): method test_incorrect_answer_yields_zero_reward (line 178) | def test_incorrect_answer_yields_zero_reward(self): method test_correct_answer_in_reasoning_yields_zero_reward (line 185) | def test_correct_answer_in_reasoning_yields_zero_reward(self): method test_incomplete_reasoning_yields_zero_reward (line 196) | def test_incomplete_reasoning_yields_zero_reward(self): method test_unparsable_gold_solution_yields_none_reward (line 207) | def test_unparsable_gold_solution_yields_none_reward(self): FILE: tests/test_rich_progress_callback.py class DummyModel (line 25) | class DummyModel(nn.Module): method __init__ (line 26) | def __init__(self): method forward (line 30) | def forward(self, x): class TestRichProgressCallback (line 35) | class TestRichProgressCallback(TrlTestCase): method setup_method (line 36) | def setup_method(self): method test_rich_progress_callback_logging (line 41) | def test_rich_progress_callback_logging(self): FILE: tests/test_rloo_trainer.py class TestRLOOTrainer (line 39) | class TestRLOOTrainer(TrlTestCase): method test_init_minimal (line 40) | def test_init_minimal(self): method test_training (line 50) | def test_training(self, config_name): method test_training_with_eval (line 79) | def test_training_with_eval(self): method test_training_with_num_generations_eval (line 102) | def test_training_with_num_generations_eval(self): method test_training_multiple_iterations (line 126) | def test_training_multiple_iterations(self): method test_training_peft_config (line 157) | def test_training_peft_config(self): method test_training_peft_model (line 193) | def test_training_peft_model(self): method test_training_peft_with_gradient_checkpointing (line 233) | def test_training_peft_with_gradient_checkpointing(self): method test_training_different_reward_model (line 269) | def test_training_different_reward_model(self): method test_training_reward_func_standard (line 308) | def test_training_reward_func_standard(self): method test_training_reward_func_conversational (line 342) | def test_training_reward_func_conversational(self): method test_training_multiple_reward_funcs (line 377) | def test_training_multiple_reward_funcs(self): method test_training_sync_and_async_reward_funcs (line 415) | def test_training_sync_and_async_reward_funcs(self): method test_training_multiple_reward_funcs_with_None_output (line 456) | def test_training_multiple_reward_funcs_with_None_output(self): method test_training_multiple_reward_funcs_with_weights (line 500) | def test_training_multiple_reward_funcs_with_weights(self): method test_training_multiple_mixed_reward_funcs (line 544) | def test_training_multiple_mixed_reward_funcs(self): method test_training_reward_func_additional_column (line 578) | def test_training_reward_func_additional_column(self): method test_training_with_sync_ref_model (line 618) | def test_training_with_sync_ref_model(self): method test_training_beta_zero (line 654) | def test_training_beta_zero(self): method test_training_with_pad_to_multiple_of (line 683) | def test_training_with_pad_to_multiple_of(self): method test_training_vllm_and_peft (line 716) | def test_training_vllm_and_peft(self): method test_training_vllm_structured_outputs (line 763) | def test_training_vllm_structured_outputs(self): method test_training_with_additional_generation_kwargs (line 795) | def test_training_with_additional_generation_kwargs(self): method test_training_vllm_with_additional_generation_kwargs (line 832) | def test_training_vllm_with_additional_generation_kwargs(self): method test_training_with_normalized_advantages (line 868) | def test_training_with_normalized_advantages(self): method test_training_with_clipped_rewards (line 898) | def test_training_with_clipped_rewards(self): method test_training_with_mask_truncated_completions (line 929) | def test_training_with_mask_truncated_completions(self, mock_generate): method test_training_with_mask_truncated_completions_all_masked (line 978) | def test_training_with_mask_truncated_completions_all_masked(self): method test_warning_raised_all_rewards_none (line 1016) | def test_warning_raised_all_rewards_none(self, caplog): method test_training_num_generations_larger_than_batch_size (line 1045) | def test_training_num_generations_larger_than_batch_size(self): method test_training_multiple_dataloader_workers (line 1075) | def test_training_multiple_dataloader_workers(self): method test_training_with_generation_kwargs (line 1116) | def test_training_with_generation_kwargs(self): method test_training_with_reward_func_accessing_trainer_state (line 1147) | def test_training_with_reward_func_accessing_trainer_state(self): method test_training_reward_func_with_log_extra (line 1172) | def test_training_reward_func_with_log_extra(self): method test_training_reward_func_with_log_metric (line 1198) | def test_training_reward_func_with_log_metric(self): method test_prepare_input_called_with_correct_data (line 1225) | def test_prepare_input_called_with_correct_data(self): method test_training_vlm (line 1294) | def test_training_vlm(self, model_id): method test_training_vlm_with_pad_to_multiple_of (line 1338) | def test_training_vlm_with_pad_to_multiple_of(self): method test_training_vlm_beta_non_zero (line 1381) | def test_training_vlm_beta_non_zero(self, model_id): method test_training_vlm_peft (line 1428) | def test_training_vlm_peft(self, model_id): method test_training_vlm_and_vllm (line 1477) | def test_training_vlm_and_vllm(self, model_id) -> None: method test_training_vlm_multi_image (line 1518) | def test_training_vlm_multi_image(self, model_id): method test_training_with_chat_template_kwargs (line 1550) | def test_training_with_chat_template_kwargs(self): method test_mismatched_reward_processing_classes_length (line 1581) | def test_mismatched_reward_processing_classes_length(self): method test_correct_reward_processing_classes_list (line 1607) | def test_correct_reward_processing_classes_list(self): method test_single_reward_model_with_single_processing_class (line 1638) | def test_single_reward_model_with_single_processing_class(self): FILE: tests/test_sft_trainer.py class TestDFTLoss (line 61) | class TestDFTLoss(TrlTestCase): method test_dft_loss (line 62) | def test_dft_loss(self): class TestDataCollatorForLanguageModeling (line 84) | class TestDataCollatorForLanguageModeling(TrlTestCase): method test_basic_padding (line 85) | def test_basic_padding(self): method test_completion_mask (line 97) | def test_completion_mask(self): method test_completion_only_loss_disabled (line 112) | def test_completion_only_loss_disabled(self): method test_padding_free_mode (line 128) | def test_padding_free_mode(self): method test_padding_free_with_completion_mask (line 140) | def test_padding_free_with_completion_mask(self): method test_packing (line 155) | def test_packing(self): method test_pad_to_multiple_of (line 172) | def test_pad_to_multiple_of(self): method test_pad_to_multiple_of_and_padding_free (line 184) | def test_pad_to_multiple_of_and_padding_free(self): method test_custom_position_ids_but_no_padding_free (line 196) | def test_custom_position_ids_but_no_padding_free(self): method test_single_example (line 208) | def test_single_example(self): method test_different_pad_token_id (line 220) | def test_different_pad_token_id(self): method test_assistant_masks (line 232) | def test_assistant_masks(self): method test_single_example_single_doc (line 246) | def test_single_example_single_doc(self): method test_single_example_multiple_docs (line 252) | def test_single_example_multiple_docs(self): method test_multiple_examples (line 259) | def test_multiple_examples(self): class TestSFTTrainer (line 267) | class TestSFTTrainer(TrlTestCase): method test_init_with_training_arguments (line 268) | def test_init_with_training_arguments(self): method test_train (line 289) | def test_train(self, model_id): method test_train_gpt_oss (line 312) | def test_train_gpt_oss(self): method test_train_model (line 336) | def test_train_model(self): method test_train_dft_loss (line 364) | def test_train_dft_loss(self): method test_train_moe_model_with_aux_loss (line 398) | def test_train_moe_model_with_aux_loss(self): method test_train_with_formatting_func (line 426) | def test_train_with_formatting_func(self): method test_train_model_dtype (line 458) | def test_train_model_dtype(self): method test_train_dense_with_peft_config_lora (line 494) | def test_train_dense_with_peft_config_lora(self): method test_train_with_peft_config_prompt_tuning (line 539) | def test_train_with_peft_config_prompt_tuning(self, peft_type): method test_train_moe_with_peft_config (line 597) | def test_train_moe_with_peft_config(self): method test_train_peft_model (line 634) | def test_train_peft_model(self): method test_train_with_peft_config_and_gradient_checkpointing (line 674) | def test_train_with_peft_config_and_gradient_checkpointing(self): method test_train_with_peft_config_and_gradient_checkpointing_reentrant (line 712) | def test_train_with_peft_config_and_gradient_checkpointing_reentrant(s... method test_train_with_liger (line 754) | def test_train_with_liger(self): method test_compute_loss_skip_logits_on_eval_without_metrics_with_liger (line 780) | def test_compute_loss_skip_logits_on_eval_without_metrics_with_liger(s... method test_predict_does_not_skip_logits_with_liger (line 822) | def test_predict_does_not_skip_logits_with_liger(self): method test_train_with_non_chatml_conversational_data (line 854) | def test_train_with_non_chatml_conversational_data(self): method test_train_with_pretokenized_data (line 884) | def test_train_with_pretokenized_data(self): method test_train_with_iterable_dataset (line 914) | def test_train_with_iterable_dataset(self): method test_train_padding_free (line 940) | def test_train_padding_free(self): method test_train_packing (line 973) | def test_train_packing(self, packing_strategy): method test_eval_packing (line 1001) | def test_eval_packing(self): method test_only_train_packing (line 1035) | def test_only_train_packing(self): method test_train_with_chat_template_kwargs (line 1068) | def test_train_with_chat_template_kwargs(self): method test_train_assistant_only (line 1118) | def test_train_assistant_only(self): method test_train_completion_only (line 1142) | def test_train_completion_only(self): method test_train_completion_only_harmony (line 1166) | def test_train_completion_only_harmony(self): method test_train_assistant_only_and_completion_only (line 1190) | def test_train_assistant_only_and_completion_only(self): method test_train_assistant_only_iterable_dataset (line 1224) | def test_train_assistant_only_iterable_dataset(self): method test_train_with_set_chat_template_from_model (line 1250) | def test_train_with_set_chat_template_from_model(self): method test_train_with_set_chat_template_from_path (line 1275) | def test_train_with_set_chat_template_from_path(self, lazy_shared_data... method test_train_toolcall_data (line 1314) | def test_train_toolcall_data(self): method test_train_toolcall_data_as_json (line 1338) | def test_train_toolcall_data_as_json(self): method test_train_with_eval (line 1371) | def test_train_with_eval(self): method test_train_with_multiple_eval_dataset (line 1390) | def test_train_with_multiple_eval_dataset(self): method test_train_with_compute_metrics (line 1409) | def test_train_with_compute_metrics(self): method test_train_with_gradient_checkpointing (line 1440) | def test_train_with_gradient_checkpointing(self): method test_train_with_gradient_checkpointing_reentrant (line 1465) | def test_train_with_gradient_checkpointing_reentrant(self, use_reentra... method test_tag_added (line 1494) | def test_tag_added(self): method test_tag_added_peft (line 1508) | def test_tag_added_peft(self): method test_train_vlm (line 1556) | def test_train_vlm(self, model_id): method test_train_vlm_multi_image (line 1607) | def test_train_vlm_multi_image(self, model_id): method test_train_vlm_prompt_completion (line 1649) | def test_train_vlm_prompt_completion(self, model_id): method test_train_vlm_gemma_3n (line 1685) | def test_train_vlm_gemma_3n(self): method test_train_vlm_text_only_data (line 1728) | def test_train_vlm_text_only_data(self, model_id, dataset_config): method test_prompt_tuning (line 1758) | def test_prompt_tuning(self): method test_peft_with_quantization (line 1791) | def test_peft_with_quantization(self): method test_prompt_tuning_peft_model (line 1846) | def test_prompt_tuning_peft_model(self): class TestSFTTrainerSlow (line 1879) | class TestSFTTrainerSlow(TrlTestCase): method setup_method (line 1880) | def setup_method(self): method teardown_method (line 1892) | def teardown_method(self): method test_sft_trainer_transformers_mp (line 1905) | def test_sft_trainer_transformers_mp(self, model_name, packing): method test_sft_trainer_transformers_mp_gc_device_map (line 1949) | def test_sft_trainer_transformers_mp_gc_device_map( method test_sft_trainer_transformers_mp_gc_peft_qlora (line 1997) | def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, p... method test_sft_trainer_with_chat_format_qlora (line 2046) | def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): method test_sft_trainer_with_liger (line 2093) | def test_sft_trainer_with_liger(self, model_name, packing): method test_train_offloading (line 2144) | def test_train_offloading(self, model_name, packing): FILE: tests/test_skills.py class TestGetTrlSkillsDir (line 23) | class TestGetTrlSkillsDir: method test_returns_path_object (line 26) | def test_returns_path_object(self): method test_directory_exists (line 31) | def test_directory_exists(self): method test_is_directory (line 36) | def test_is_directory(self): method test_contains_skills_module (line 41) | def test_contains_skills_module(self): class TestListSkills (line 47) | class TestListSkills: method test_returns_list (line 50) | def test_returns_list(self): method test_contains_trl_training (line 55) | def test_contains_trl_training(self): method test_skills_are_sorted (line 60) | def test_skills_are_sorted(self): method test_with_custom_directory (line 65) | def test_with_custom_directory(self, tmp_path): method test_empty_directory (line 77) | def test_empty_directory(self, tmp_path): method test_nonexistent_directory (line 82) | def test_nonexistent_directory(self, tmp_path): method test_ignores_files (line 88) | def test_ignores_files(self, tmp_path): method test_requires_skill_md (line 97) | def test_requires_skill_md(self, tmp_path): class TestInstallSkill (line 108) | class TestInstallSkill: method test_basic_installation (line 111) | def test_basic_installation(self, tmp_path): method test_creates_target_directory (line 121) | def test_creates_target_directory(self, tmp_path): method test_skill_not_found (line 130) | def test_skill_not_found(self, tmp_path): method test_skill_already_exists_without_force (line 137) | def test_skill_already_exists_without_force(self, tmp_path): method test_force_overwrites_existing (line 148) | def test_force_overwrites_existing(self, tmp_path): method test_force_overwrites_symlink (line 166) | def test_force_overwrites_symlink(self, tmp_path): method test_skill_not_directory (line 182) | def test_skill_not_directory(self, tmp_path): method test_preserves_directory_structure (line 194) | def test_preserves_directory_structure(self, tmp_path): method test_install_to_same_directory_fails (line 212) | def test_install_to_same_directory_fails(self, tmp_path): class TestUninstallSkill (line 227) | class TestUninstallSkill: method test_basic_uninstallation (line 230) | def test_basic_uninstallation(self, tmp_path): method test_skill_not_installed (line 244) | def test_skill_not_installed(self, tmp_path): method test_uninstall_from_nonexistent_directory (line 252) | def test_uninstall_from_nonexistent_directory(self, tmp_path): method test_uninstall_removes_all_contents (line 259) | def test_uninstall_removes_all_contents(self, tmp_path): method test_uninstall_doesnt_affect_other_skills (line 280) | def test_uninstall_doesnt_affect_other_skills(self, tmp_path): class TestIntegration (line 303) | class TestIntegration: method test_full_workflow (line 306) | def test_full_workflow(self, tmp_path): method test_install_uninstall_cycle (line 336) | def test_install_uninstall_cycle(self, tmp_path): method test_force_reinstall_workflow (line 354) | def test_force_reinstall_workflow(self, tmp_path): class TestEdgeCases (line 376) | class TestEdgeCases: method test_skill_with_special_characters_in_name (line 379) | def test_skill_with_special_characters_in_name(self, tmp_path): method test_empty_skill_directory (line 397) | def test_empty_skill_directory(self, tmp_path): method test_skill_with_hidden_files (line 414) | def test_skill_with_hidden_files(self, tmp_path): method test_list_skills_with_symlinks (line 429) | def test_list_skills_with_symlinks(self, tmp_path): class TestListAgentNames (line 448) | class TestListAgentNames: method test_returns_list (line 451) | def test_returns_list(self): method test_contains_expected_agents (line 456) | def test_contains_expected_agents(self): method test_agents_are_sorted (line 463) | def test_agents_are_sorted(self): class TestResolveTargetPath (line 469) | class TestResolveTargetPath: method test_resolve_agent_name_project_scope (line 472) | def test_resolve_agent_name_project_scope(self): method test_resolve_agent_name_global_scope (line 477) | def test_resolve_agent_name_global_scope(self): method test_resolve_custom_path_string (line 482) | def test_resolve_custom_path_string(self): method test_resolve_custom_path_object (line 487) | def test_resolve_custom_path_object(self): method test_resolve_path_with_tilde (line 493) | def test_resolve_path_with_tilde(self): method test_all_predefined_agents (line 499) | def test_all_predefined_agents(self): method test_invalid_scope_for_predefined_agent (line 507) | def test_invalid_scope_for_predefined_agent(self): class TestHighLevelAPI (line 513) | class TestHighLevelAPI: method test_list_skills_with_target_string (line 516) | def test_list_skills_with_target_string(self, tmp_path): method test_list_skills_with_target_path (line 525) | def test_list_skills_with_target_path(self, tmp_path): method test_list_skills_without_target (line 533) | def test_list_skills_without_target(self): method test_install_skill_with_target_string (line 539) | def test_install_skill_with_target_string(self, tmp_path): method test_install_skill_with_target_path (line 545) | def test_install_skill_with_target_path(self, tmp_path): method test_install_skill_with_force (line 551) | def test_install_skill_with_force(self, tmp_path): method test_uninstall_skill_with_target_string (line 558) | def test_uninstall_skill_with_target_string(self, tmp_path): method test_uninstall_skill_with_target_path (line 565) | def test_uninstall_skill_with_target_path(self, tmp_path): method test_install_with_custom_source (line 572) | def test_install_with_custom_source(self, tmp_path): FILE: tests/test_skills_cli.py class TestCLICommands (line 23) | class TestCLICommands: method test_cmd_list_without_target (line 26) | def test_cmd_list_without_target(self, capsys): method test_cmd_list_with_target (line 38) | def test_cmd_list_with_target(self, tmp_path, capsys): method test_cmd_list_empty_target (line 51) | def test_cmd_list_empty_target(self, tmp_path, capsys): method test_cmd_install_single_skill (line 61) | def test_cmd_install_single_skill(self, tmp_path, capsys): method test_cmd_install_all_skills (line 73) | def test_cmd_install_all_skills(self, tmp_path, capsys): method test_cmd_install_no_skill_or_all (line 85) | def test_cmd_install_no_skill_or_all(self, capsys): method test_cmd_install_both_skill_and_all (line 95) | def test_cmd_install_both_skill_and_all(self, capsys): method test_cmd_install_nonexistent_skill (line 105) | def test_cmd_install_nonexistent_skill(self, tmp_path, capsys): method test_cmd_install_already_exists (line 116) | def test_cmd_install_already_exists(self, tmp_path, capsys): method test_cmd_install_with_force (line 130) | def test_cmd_install_with_force(self, tmp_path, capsys): method test_cmd_uninstall_success (line 144) | def test_cmd_uninstall_success(self, tmp_path, capsys): method test_cmd_uninstall_not_installed (line 159) | def test_cmd_uninstall_not_installed(self, tmp_path, capsys): method test_cmd_install_creates_target_directory (line 170) | def test_cmd_install_creates_target_directory(self, tmp_path, capsys): method test_cmd_uninstall_invalid_target (line 187) | def test_cmd_uninstall_invalid_target(self, capsys): class TestCLIArgumentParsing (line 198) | class TestCLIArgumentParsing: method test_add_skills_subcommands_creates_parsers (line 201) | def test_add_skills_subcommands_creates_parsers(self): method test_list_command_optional_target (line 222) | def test_list_command_optional_target(self): method test_install_command_requires_target (line 236) | def test_install_command_requires_target(self): method test_scope_choices (line 246) | def test_scope_choices(self): method test_install_all_flag (line 263) | def test_install_all_flag(self): method test_install_force_flag (line 273) | def test_install_force_flag(self): method test_default_scope_is_project (line 282) | def test_default_scope_is_project(self): FILE: tests/test_utils.py class TestUseAdapter (line 55) | class TestUseAdapter(TrlTestCase): method test_disables_on_none (line 56) | def test_disables_on_none(self): method test_restores_previous_adapter (line 69) | def test_restores_previous_adapter(self): method test_with_multiple_adapters (line 85) | def test_with_multiple_adapters(self): class TestPad (line 107) | class TestPad(TrlTestCase): method test_pad_1_dim_left (line 108) | def test_pad_1_dim_left(self): method test_pad_1_dim_right (line 115) | def test_pad_1_dim_right(self): method test_pad_2_dim_left (line 122) | def test_pad_2_dim_left(self): method test_pad_2_dim_right (line 134) | def test_pad_2_dim_right(self): method test_pad_2_dim_right_multidim (line 146) | def test_pad_2_dim_right_multidim(self): method test_pad_to_multiple_of_1 (line 158) | def test_pad_to_multiple_of_1(self): method test_pad_to_multiple_of_2 (line 166) | def test_pad_to_multiple_of_2(self): method test_pad_to_multiple_of_side_left (line 174) | def test_pad_to_multiple_of_side_left(self): method test_pad_to_multiple_of_no_extra_padding (line 182) | def test_pad_to_multiple_of_no_extra_padding(self): class TestHashModule (line 191) | class TestHashModule(TrlTestCase): method test_hash_module_deterministic_across_order (line 192) | def test_hash_module_deterministic_across_order(self): method test_hash_module_changes_with_value (line 209) | def test_hash_module_changes_with_value(self): method test_hash_module_includes_dtype (line 217) | def test_hash_module_includes_dtype(self): method test_hash_module_tiny_model_twice (line 225) | def test_hash_module_tiny_model_twice(self): method test_hash_module_tiny_model_change_layer (line 231) | def test_hash_module_tiny_model_change_layer(self): class TestGetPEFTConfig (line 242) | class TestGetPEFTConfig(TrlTestCase): method test_create_peft_config_use_peft_false (line 243) | def test_create_peft_config_use_peft_false(self): method test_create_peft_config_use_peft_true (line 249) | def test_create_peft_config_use_peft_true(self): class TestNanStd (line 275) | class TestNanStd(TrlTestCase): method test_nanstd_ignores_nans (line 276) | def test_nanstd_ignores_nans(self): method test_nanstd_dim_and_keepdim (line 281) | def test_nanstd_dim_and_keepdim(self): method test_nanstd_all_nan (line 287) | def test_nanstd_all_nan(self): class TestGenerateModelCard (line 293) | class TestGenerateModelCard(TrlTestCase): method test_full (line 294) | def test_full(self): method test_val_none (line 321) | def test_val_none(self): class TestFlushLeft (line 342) | class TestFlushLeft(TrlTestCase): method test_basic_case (line 343) | def test_basic_case(self): method test_single_row (line 357) | def test_single_row(self): method test_no_shift_needed (line 368) | def test_no_shift_needed(self): method test_no_tensors (line 379) | def test_no_tensors(self): class TestFlushRight (line 386) | class TestFlushRight(TrlTestCase): method test_basic_case (line 387) | def test_basic_case(self): method test_single_row (line 401) | def test_single_row(self): method test_no_shift_needed (line 412) | def test_no_shift_needed(self): method test_no_tensors (line 423) | def test_no_tensors(self): class TestRepeatRandomSampler (line 430) | class TestRepeatRandomSampler(TrlTestCase): method test_sampler (line 431) | def test_sampler(self): method test_sampler_no_shuffle (line 443) | def test_sampler_no_shuffle(self): method test_sampler_no_repeat (line 450) | def test_sampler_no_repeat(self): method test_sampler_with_batch_size (line 460) | def test_sampler_with_batch_size(self): method test_sampler_with_batch_size_and_drop (line 472) | def test_sampler_with_batch_size_and_drop(self): method test_sampler_with_mini_repeat_count_and_batch_size_1 (line 487) | def test_sampler_with_mini_repeat_count_and_batch_size_1(self): method test_sampler_with_mini_repeat_count_and_batch_size_2 (line 504) | def test_sampler_with_mini_repeat_count_and_batch_size_2(self): method test_sampler_with_mini_repeat_count_and_batch_size_3 (line 523) | def test_sampler_with_mini_repeat_count_and_batch_size_3(self): class TestEntropyFromLogits (line 542) | class TestEntropyFromLogits(TrlTestCase): method test_entropy_from_logits_2_dims (line 546) | def test_entropy_from_logits_2_dims(self, dtype, chunk_size, shape): class TestPrintPromptCompletionsSample (line 559) | class TestPrintPromptCompletionsSample(TrlTestCase): method test_print_output (line 561) | def test_print_output(self, mock_stdout): method test_num_samples (line 588) | def test_num_samples(self, mock_stdout): method test_print_messages (line 623) | def test_print_messages(self, mock_stdout): method test_print_messages_with_tools (line 672) | def test_print_messages_with_tools(self, mock_stdout): class TestSelectiveLogSoftmax (line 711) | class TestSelectiveLogSoftmax(TrlTestCase): method test_selective_log_softmax (line 713) | def test_selective_log_softmax(self, dtype): method test_selective_log_softmax_multi_index (line 733) | def test_selective_log_softmax_multi_index(self, dtype, k): class TestShuffleSequenceDict (line 753) | class TestShuffleSequenceDict(TrlTestCase): method test_shuffle_preserves_shape (line 754) | def test_shuffle_preserves_shape(self): method test_shuffle_consistent_across_tensors (line 764) | def test_shuffle_consistent_across_tensors(self): method test_none_tensor_remains_none (line 786) | def test_none_tensor_remains_none(self): method test_shuffle_with_list (line 795) | def test_shuffle_with_list(self): class TestSplitTensorDict (line 818) | class TestSplitTensorDict(TrlTestCase): method test_split_equal_chunks (line 819) | def test_split_equal_chunks(self): method test_with_none_tensor (line 833) | def test_with_none_tensor(self): method test_with_scalar (line 845) | def test_with_scalar(self): class TestSplitPixelValuesByGrid (line 858) | class TestSplitPixelValuesByGrid(TrlTestCase): method test_split_correctly_0 (line 859) | def test_split_correctly_0(self): method test_split_correctly_1 (line 875) | def test_split_correctly_1(self): method test_missing_keys (line 891) | def test_missing_keys(self): method test_mismatched_length (line 896) | def test_mismatched_length(self): method test_multi_images (line 905) | def test_multi_images(self): class TestUnsplitPixelValuesByGrid (line 922) | class TestUnsplitPixelValuesByGrid(TrlTestCase): method test_unsplit_correctly (line 923) | def test_unsplit_correctly(self): method test_no_op_if_not_list (line 936) | def test_no_op_if_not_list(self): class TestForwardMaskedLogits (line 943) | class TestForwardMaskedLogits: method test_llm (line 965) | def test_llm(self, model_id): method test_vlm (line 1015) | def test_vlm(self, model_id): FILE: tests/test_vllm_client_server.py class TestChunkList (line 48) | class TestChunkList(TrlTestCase): method test_even_split (line 49) | def test_even_split(self): method test_uneven_split (line 52) | def test_uneven_split(self): method test_more_chunks_than_elements (line 55) | def test_more_chunks_than_elements(self): method test_n_equals_len (line 58) | def test_n_equals_len(self): method test_n_is_1 (line 61) | def test_n_is_1(self): method test_single_element_list (line 64) | def test_single_element_list(self): method test_any_dtype (line 67) | def test_any_dtype(self): class TestExtractLogprobs (line 74) | class TestExtractLogprobs(TrlTestCase): method test_extract_logprobs_sorts_by_rank_and_replaces_nan (line 75) | def test_extract_logprobs_sorts_by_rank_and_replaces_nan(self): method test_extract_logprobs_returns_none_token_ids_when_logprobs_missing (line 118) | def test_extract_logprobs_returns_none_token_ids_when_logprobs_missing... class TestVLLMClientServer (line 130) | class TestVLLMClientServer(TrlTestCase): method setup_class (line 134) | def setup_class(cls): method test_generate (line 149) | def test_generate(self): method test_generate_with_logprobs_none (line 169) | def test_generate_with_logprobs_none(self): method test_chat (line 177) | def test_chat(self): method test_chat_with_logprobs_none (line 197) | def test_chat_with_logprobs_none(self): method test_chat_with_tools (line 205) | def test_chat_with_tools(self): method test_generate_with_token_ids (line 227) | def test_generate_with_token_ids(self): method test_generate_with_params (line 252) | def test_generate_with_params(self): method test_update_model_params (line 272) | def test_update_model_params(self): method test_reset_prefix_cache (line 276) | def test_reset_prefix_cache(self): method test_logprobs_match_with_non_default_sampling (line 281) | def test_logprobs_match_with_non_default_sampling(self): method teardown_class (line 362) | def teardown_class(cls): class TestVLLMClientServerBaseURL (line 375) | class TestVLLMClientServerBaseURL(TrlTestCase): method setup_class (line 379) | def setup_class(cls): method test_generate (line 394) | def test_generate(self): method test_generate_with_logprobs_none (line 414) | def test_generate_with_logprobs_none(self): method test_chat (line 422) | def test_chat(self): method test_chat_with_logprobs_none (line 442) | def test_chat_with_logprobs_none(self): method test_chat_with_tools (line 450) | def test_chat_with_tools(self): method test_generate_with_token_ids (line 472) | def test_generate_with_token_ids(self): method test_generate_with_params (line 497) | def test_generate_with_params(self): method test_update_model_params (line 517) | def test_update_model_params(self): method test_reset_prefix_cache (line 521) | def test_reset_prefix_cache(self): method teardown_class (line 526) | def teardown_class(cls): class TestVLLMClientServerTP (line 538) | class TestVLLMClientServerTP(TrlTestCase): method setup_class (line 542) | def setup_class(cls): method test_generate (line 560) | def test_generate(self): method test_generate_with_logprobs_none (line 580) | def test_generate_with_logprobs_none(self): method test_chat (line 588) | def test_chat(self): method test_chat_with_logprobs_none (line 608) | def test_chat_with_logprobs_none(self): method test_chat_with_tools (line 616) | def test_chat_with_tools(self): method test_generate_with_token_ids (line 638) | def test_generate_with_token_ids(self): method test_generate_with_params (line 663) | def test_generate_with_params(self): method test_update_model_params (line 683) | def test_update_model_params(self): method test_reset_prefix_cache (line 687) | def test_reset_prefix_cache(self): method teardown_class (line 692) | def teardown_class(cls): class TestVLLMClientServerDP (line 708) | class TestVLLMClientServerDP(TrlTestCase): method setup_class (line 712) | def setup_class(cls): method test_generate (line 730) | def test_generate(self): method test_generate_with_logprobs_none (line 750) | def test_generate_with_logprobs_none(self): method test_chat (line 758) | def test_chat(self): method test_chat_with_logprobs_none (line 778) | def test_chat_with_logprobs_none(self): method test_chat_with_tools (line 786) | def test_chat_with_tools(self): method test_generate_with_token_ids (line 808) | def test_generate_with_token_ids(self): method test_generate_with_params (line 833) | def test_generate_with_params(self): method test_update_model_params (line 853) | def test_update_model_params(self): method test_reset_prefix_cache (line 857) | def test_reset_prefix_cache(self): method teardown_class (line 862) | def teardown_class(cls): class TestVLLMClientServerDeviceParameter (line 874) | class TestVLLMClientServerDeviceParameter(TrlTestCase): method setup_class (line 880) | def setup_class(cls): method test_init_communicator_with_device_int (line 891) | def test_init_communicator_with_device_int(self): method test_init_communicator_with_device_string (line 908) | def test_init_communicator_with_device_string(self): method test_init_communicator_with_torch_device (line 921) | def test_init_communicator_with_torch_device(self): method teardown_class (line 938) | def teardown_class(cls): class TestVLLMClientServerVLM (line 947) | class TestVLLMClientServerVLM(TrlTestCase): method setup_class (line 951) | def setup_class(cls): method test_generate_with_token_ids_and_image (line 960) | def test_generate_with_token_ids_and_image(self): method test_generate_with_token_ids_mixed_images (line 1000) | def test_generate_with_token_ids_mixed_images(self): method teardown_class (line 1035) | def teardown_class(cls): FILE: tests/testing_utils.py function is_bitsandbytes_multi_backend_available (line 73) | def is_bitsandbytes_multi_backend_available() -> bool: function is_ampere_or_newer (line 88) | def is_ampere_or_newer(device_index=0): class TrlTestCase (line 100) | class TrlTestCase: method set_tmp_dir (line 102) | def set_tmp_dir(self, tmp_path): function ignore_warnings (line 106) | def ignore_warnings(message: str = None, category: type[Warning] = Warni... function kill_process (line 129) | def kill_process(process): FILE: trl/_compat.py function _is_package_version_below (line 30) | def _is_package_version_below(package_name: str, version_threshold: str)... function _is_package_version_at_least (line 54) | def _is_package_version_at_least(package_name: str, version_threshold: s... function _patch_vllm_logging (line 78) | def _patch_vllm_logging() -> None: function _patch_vllm_disabled_tqdm (line 86) | def _patch_vllm_disabled_tqdm() -> None: function _patch_vllm_cached_tokenizer (line 110) | def _patch_vllm_cached_tokenizer() -> None: function _patch_transformers_hybrid_cache (line 170) | def _patch_transformers_hybrid_cache() -> None: function _patch_transformers_parallelism_config (line 214) | def _patch_transformers_parallelism_config() -> None: FILE: trl/_lazy_module.py class _LazyModule (line 22) | class _LazyModule(ModuleType): method __init__ (line 29) | def __init__(self, name, module_file, import_structure, module_spec=No... method __dir__ (line 46) | def __dir__(self): method __getattr__ (line 55) | def __getattr__(self, name: str) -> Any: method _get_module (line 69) | def _get_module(self, module_name: str): method __reduce__ (line 78) | def __reduce__(self): FILE: trl/chat_template_utils.py function clone_chat_template (line 18) | def clone_chat_template( function add_response_schema (line 429) | def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTok... function is_chat_template_prefix_preserving (line 472) | def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -... function get_training_chat_template (line 610) | def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | ... function _validate_tool_calls (line 671) | def _validate_tool_calls(tool_calls: list | None) -> None: function parse_response (line 709) | def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: FILE: trl/cli/accelerate_config.py function resolve_accelerate_config_argument (line 19) | def resolve_accelerate_config_argument(launch_args: list[str]) -> list[s... FILE: trl/cli/accelerate_launcher.py function launch_training_script (line 22) | def launch_training_script( FILE: trl/cli/commands/__init__.py function get_commands (line 22) | def get_commands() -> list[Command]: FILE: trl/cli/commands/base.py class CommandContext (line 21) | class CommandContext: method argv_after (line 26) | def argv_after(self, token: str) -> list[str]: class Command (line 41) | class Command(ABC): method __init__ (line 52) | def __init__(self, name: str, help_text: str): method register (line 57) | def register(self, subparsers) -> None: method run (line 61) | def run(self, args: Namespace, context: CommandContext) -> int: FILE: trl/cli/commands/env.py class EnvCommand (line 20) | class EnvCommand(Command): method __init__ (line 23) | def __init__(self): method register (line 26) | def register(self, subparsers) -> None: method run (line 29) | def run(self, args: Namespace, context: CommandContext) -> int: FILE: trl/cli/commands/skills.py class SkillsCommand (line 21) | class SkillsCommand(Command): method __init__ (line 24) | def __init__(self): method register (line 28) | def register(self, subparsers) -> None: method run (line 33) | def run(self, args: Namespace, context: CommandContext) -> int: FILE: trl/cli/commands/training.py function _subtract_subsequence (line 21) | def _subtract_subsequence(lst: list[str], subseq: list[str]) -> list[str]: class TrainingCommand (line 34) | class TrainingCommand(Command): method __init__ (line 45) | def __init__(self, name: str): method register (line 48) | def register(self, subparsers) -> None: method run (line 51) | def run(self, args: Namespace, context: CommandContext) -> int: FILE: trl/cli/commands/vllm_serve.py class VllmServeCommand (line 20) | class VllmServeCommand(Command): method __init__ (line 23) | def __init__(self): method register (line 26) | def register(self, subparsers) -> None: method run (line 29) | def run(self, args: Namespace, context: CommandContext) -> int: FILE: trl/cli/main.py function _build_parser (line 22) | def _build_parser(commands: list[Command]) -> ArgumentParser: function main (line 32) | def main(argv: list[str] | None = None) -> int: FILE: trl/data_utils.py function prepare_multimodal_messages (line 32) | def prepare_multimodal_messages(messages: list[dict[str, Any]], images: ... function prepare_multimodal_messages_vllm (line 126) | def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> ... function is_conversational (line 159) | def is_conversational(example: dict[str, Any]) -> bool: function apply_chat_template (line 200) | def apply_chat_template( function maybe_apply_chat_template (line 333) | def maybe_apply_chat_template( function _unpair_row (line 397) | def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list... function unpair_preference_dataset (line 408) | def unpair_preference_dataset( function maybe_unpair_preference_dataset (line 451) | def maybe_unpair_preference_dataset( function extract_prompt (line 502) | def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]: function maybe_extract_prompt (line 589) | def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]: function _get_dataset_format (line 615) | def _get_dataset_format(dataset: DatasetType) -> dict[str, Any]: function _check_if_columns_can_be_packed (line 627) | def _check_if_columns_can_be_packed(columns: list[pa.Array]): class _SegmentTree (line 639) | class _SegmentTree: method __init__ (line 647) | def __init__(self, maxval: int): method add (line 653) | def add(self, val): method remove (line 663) | def remove(self, val): method search (line 673) | def search(self, val): function _pack_bfd (line 684) | def _pack_bfd( function _pack_wrapped (line 774) | def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table: function pack_dataset (line 790) | def pack_dataset( function truncate_dataset (line 880) | def truncate_dataset( function is_conversational_from_value (line 944) | def is_conversational_from_value(example: dict[str, Any]) -> bool: function maybe_convert_to_chatml (line 984) | def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]: FILE: trl/experimental/async_grpo/async_grpo_config.py class AsyncGRPOConfig (line 21) | class AsyncGRPOConfig(_BaseConfig): method __post_init__ (line 201) | def __post_init__(self): FILE: trl/experimental/async_grpo/async_grpo_trainer.py class _SupportsReset (line 47) | class _SupportsReset(Protocol): method reset (line 48) | def reset(self, **kwargs) -> str | None: ... class RolloutWorkerProtocol (line 54) | class RolloutWorkerProtocol(Protocol): method start (line 57) | def start(self) -> None: ... method stop (line 58) | def stop(self) -> None: ... method pause (line 59) | def pause(self) -> None: ... method resume (line 60) | def resume(self) -> None: ... method send_weights (line 61) | def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -... method update_model_version (line 62) | def update_model_version(self, version: int) -> None: ... class StepIntervalCallback (line 65) | class StepIntervalCallback(TrainerCallback): method __init__ (line 70) | def __init__(self, fn, every_n_steps: int): method on_step_end (line 74) | def on_step_end(self, _args, state, _control, **_kwargs): class RolloutQueueDataset (line 79) | class RolloutQueueDataset(torch.utils.data.IterableDataset): method __init__ (line 80) | def __init__(self, rollout_queue, model_version_fn, max_staleness=3, t... method __iter__ (line 86) | def __iter__(self): class _EmptyIterableDataset (line 115) | class _EmptyIterableDataset(torch.utils.data.IterableDataset): method __iter__ (line 118) | def __iter__(self): class DataCollatorForRollout (line 123) | class DataCollatorForRollout(DataCollatorMixin): method torch_call (line 127) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: class AsyncGRPOTrainer (line 168) | class AsyncGRPOTrainer(_BaseTrainer): method __init__ (line 270) | def __init__( method get_train_dataloader (line 391) | def get_train_dataloader(self) -> DataLoader: method _set_signature_columns_if_needed (line 414) | def _set_signature_columns_if_needed(self): method compute_loss (line 430) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method log (line 543) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _streaming_iter (line 552) | def _streaming_iter(self): method _sync_weight (line 560) | def _sync_weight(self): method _inner_training_loop (line 591) | def _inner_training_loop(self, *args, **kwargs): FILE: trl/experimental/async_grpo/async_rollout_worker.py class RolloutGroup (line 47) | class RolloutGroup: class RolloutSample (line 64) | class RolloutSample: class AsyncRolloutWorker (line 75) | class AsyncRolloutWorker: method __init__ (line 83) | def __init__( method _wait_for_server_ready_sync (line 178) | def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_i... method _init_weight_transfer (line 201) | def _init_weight_transfer(self) -> None: method update_model_version (line 231) | def update_model_version(self, model_version: int): method _run_loops (line 234) | async def _run_loops(self, stop_event: asyncio.Event) -> None: method start (line 245) | def start(self) -> None: method stop (line 249) | def stop(self) -> None: method _run (line 257) | def _run(self) -> None: method pause (line 270) | def pause(self) -> None: method resume (line 275) | def resume(self) -> None: method send_weights (line 280) | def send_weights(self, iterator) -> None: method _generate_loop (line 303) | async def _generate_loop(self, stop_event: asyncio.Event) -> None: method _compute_rollout_metrics (line 422) | def _compute_rollout_metrics(self, samples: list[RolloutSample], scori... method _score_loop (line 442) | async def _score_loop(self, stop_event: asyncio.Event) -> None: method _repeat_iterator (line 497) | def _repeat_iterator(self) -> Iterator[tuple[int, dict[str, Any]]]: method _generate_one (line 509) | async def _generate_one( method _build_messages_suffix_ids (line 547) | def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -... method _execute_tool_calls (line 572) | def _execute_tool_calls( method _generate_one_turn (line 591) | async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[lis... method _score_group (line 614) | async def _score_group(self, group: RolloutGroup) -> list[RolloutSample]: method _post (line 692) | async def _post(self, path: str, payload: dict, timeout: float, max_re... FILE: trl/experimental/bco/bco_config.py class BCOConfig (line 22) | class BCOConfig(_BaseConfig): FILE: trl/experimental/bco/bco_trainer.py function get_global_statistics (line 89) | def get_global_statistics( class RunningMoments (line 110) | class RunningMoments: method update (line 123) | def update(self, xs: torch.Tensor) -> tuple[float, float]: method save_to_json (line 150) | def save_to_json(self, json_path: str): method load_from_json (line 160) | def load_from_json(cls, accelerator: Accelerator, json_path: str): function _tokenize (line 168) | def _tokenize( function _process_tokens (line 239) | def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = ... class BCOTrainer (line 349) | class BCOTrainer(_BaseTrainer): method __init__ (line 410) | def __init__( method match_underlying_distribution (line 803) | def match_underlying_distribution(self): method _get_chosen_prob (line 806) | def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> to... method _vectorize_prompt (line 835) | def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mas... method _get_prompt_embeddings (line 853) | def _get_prompt_embeddings( method _get_sample_prompt_embeddings (line 875) | def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size:... method _save_optimizer_and_scheduler (line 907) | def _save_optimizer_and_scheduler(self, output_dir): method _load_optimizer_and_scheduler (line 918) | def _load_optimizer_and_scheduler(self, checkpoint): method null_ref_context (line 936) | def null_ref_context(self): method get_train_dataloader (line 949) | def get_train_dataloader(self) -> DataLoader: method get_eval_dataloader (line 983) | def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> ... method compute_reference_log_probs (line 1029) | def compute_reference_log_probs(self, padded_batch: dict) -> dict: method get_batch_logps (line 1072) | def get_batch_logps( method forward (line 1119) | def forward( method _get_udm_weight (line 1167) | def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> t... method bco_loss (line 1176) | def bco_loss( method get_batch_loss_metrics (line 1232) | def get_batch_loss_metrics( method compute_loss (line 1326) | def compute_loss( method store_metrics (line 1350) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal... method _get_train_sampler (line 1354) | def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.... method generate_from_model_and_ref (line 1361) | def generate_from_model_and_ref(self, model, batch: dict[str, torch.Lo... method prediction_step (line 1408) | def prediction_step( method evaluation_loop (line 1446) | def evaluation_loop( method log (line 1506) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 1542) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/bema_for_ref_model/callback.py class CallbackHandlerWithRefModel (line 28) | class CallbackHandlerWithRefModel(CallbackHandler): method __init__ (line 33) | def __init__(self, callbacks, model, ref_model, processing_class, opti... method call_event (line 38) | def call_event(self, event, args, state, control, **kwargs): class BEMACallback (line 59) | class BEMACallback(_BEMACallback): method __init__ (line 128) | def __init__( method on_step_end (line 158) | def on_step_end( method _update_model_with_bema_weights (line 202) | def _update_model_with_bema_weights(self, model, bema_state_dict, is_p... FILE: trl/experimental/bema_for_ref_model/dpo_trainer.py class DPOTrainer (line 19) | class DPOTrainer(_DPOTrainer): method __init__ (line 20) | def __init__(self, *args, **kwargs): FILE: trl/experimental/cpo/cpo_config.py class CPOConfig (line 22) | class CPOConfig(_BaseConfig): method __post_init__ (line 176) | def __post_init__(self): FILE: trl/experimental/cpo/cpo_trainer.py class CPOTrainer (line 74) | class CPOTrainer(_BaseTrainer): method __init__ (line 129) | def __init__( method build_tokenized_answer (line 389) | def build_tokenized_answer(self, prompt, answer): method tokenize_row (line 438) | def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | N... method concatenated_inputs (line 571) | def concatenated_inputs( method cpo_loss (line 635) | def cpo_loss( method get_batch_logps (line 709) | def get_batch_logps( method concatenated_forward (line 749) | def concatenated_forward( method get_batch_loss_metrics (line 822) | def get_batch_loss_metrics( method compute_loss (line 876) | def compute_loss( method generate_from_model (line 897) | def generate_from_model(self, model, batch: dict[str, torch.LongTensor... method prediction_step (line 920) | def prediction_step( method store_metrics (line 957) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal... method evaluation_loop (line 961) | def evaluation_loop( method log (line 1012) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _shift_right (line 1030) | def _shift_right(self, input_ids): method _save_checkpoint (line 1054) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/dppo/dppo_config.py class DPPOConfig (line 22) | class DPPOConfig(GRPOConfig): method __post_init__ (line 93) | def __post_init__(self): FILE: trl/experimental/dppo/dppo_trainer.py function _strip_padding (line 66) | def _strip_padding(tensor: torch.Tensor, mask: torch.Tensor) -> list[list]: class DPPOTrainer (line 71) | class DPPOTrainer(GRPOTrainer): method __init__ (line 190) | def __init__( method _tokenize_prompts (line 234) | def _tokenize_prompts(self, prompts: list): method _generate_single_turn (line 274) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields): method _tool_call_loop (line 418) | def _tool_call_loop( method _generate (line 619) | def _generate(self, prompts: list): method _get_per_token_logps_with_topk (line 752) | def _get_per_token_logps_with_topk( method _generate_and_score_completions (line 841) | def _generate_and_score_completions( method _compute_divergence_mask (line 1187) | def _compute_divergence_mask( method _compute_loss (line 1270) | def _compute_loss(self, model, inputs): FILE: trl/experimental/gfpo/gfpo_config.py class GFPOConfig (line 21) | class GFPOConfig(_GRPOConfig): method __post_init__ (line 29) | def __post_init__(self): FILE: trl/experimental/gfpo/gfpo_trainer.py class GFPOTrainer (line 33) | class GFPOTrainer(_GRPOTrainer): method __init__ (line 34) | def __init__( method _generate_and_score_completions (line 69) | def _generate_and_score_completions(self, inputs): FILE: trl/experimental/gkd/gkd_config.py class GKDConfig (line 22) | class GKDConfig(SFTConfig): method __post_init__ (line 104) | def __post_init__(self): FILE: trl/experimental/gkd/gkd_trainer.py class GKDTrainer (line 53) | class GKDTrainer(SFTTrainer): method __init__ (line 110) | def __init__( method generalized_jsd_loss (line 222) | def generalized_jsd_loss( method compute_loss (line 292) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method generate_on_policy_outputs (line 394) | def generate_on_policy_outputs(model, inputs, generation_config, pad_t... method training_step (line 416) | def training_step( FILE: trl/experimental/gold/gold_config.py class GOLDConfig (line 23) | class GOLDConfig(SFTConfig): method __post_init__ (line 385) | def __post_init__(self): FILE: trl/experimental/gold/gold_trainer.py function print_prompt_completions_sample_uld (line 94) | def print_prompt_completions_sample_uld( function build_teacher_inputs_from_texts (line 173) | def build_teacher_inputs_from_texts( class ULDLoss (line 236) | class ULDLoss(nn.Module): method __init__ (line 241) | def __init__(self, config: GOLDConfig, student_tokenizer=None, teacher... method __call__ (line 270) | def __call__( method _initialize_vocabulary_mapping (line 304) | def _initialize_vocabulary_mapping(self): method _compute_distillation_loss (line 336) | def _compute_distillation_loss( method _build_alignment_groups_from_ids (line 438) | def _build_alignment_groups_from_ids(self, student_token_ids, teacher_... method _merge_probabilities_with_alignment_groups (line 534) | def _merge_probabilities_with_alignment_groups(self, probs, alignment_... method _compute_hybrid_uld_loss (line 611) | def _compute_hybrid_uld_loss(self, student_aligned, teacher_aligned): method _compute_jsd_loss_for_matched_tokens (line 708) | def _compute_jsd_loss_for_matched_tokens(self, student_logits, teacher... method _get_start_and_size_answers (line 737) | def _get_start_and_size_answers(self, answer_tensors): class GOLDVLLMSyncCallback (line 754) | class GOLDVLLMSyncCallback(TrainerCallback): method __init__ (line 757) | def __init__(self, trainer): method on_step_end (line 760) | def on_step_end(self, args, state: TrainerState, control: TrainerContr... class GOLDTrainer (line 774) | class GOLDTrainer(SFTTrainer): method __init__ (line 789) | def __init__( method _set_signature_columns_if_needed (line 1047) | def _set_signature_columns_if_needed(self): method _get_train_sampler (line 1065) | def _get_train_sampler(self, dataset=None): method get_train_dataloader (line 1077) | def get_train_dataloader(self): method _prepare_inputs (line 1118) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A... method _decode_completion_texts_from_labels (line 1131) | def _decode_completion_texts_from_labels(self, slice_inputs: dict[str,... method _ensure_original_text_fields (line 1151) | def _ensure_original_text_fields( method _build_sequence_batch (line 1177) | def _build_sequence_batch( method _fill_buffer (line 1197) | def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any]... method _generate_on_policy_for_slices (line 1231) | def _generate_on_policy_for_slices( method _deduplicate_prompts (line 1303) | def _deduplicate_prompts( method _generate_vllm_server_global (line 1323) | def _generate_vllm_server_global( method _generate_vllm_colocate (line 1378) | def _generate_vllm_colocate( method _generate_non_vllm_for_slices (line 1445) | def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.T... method _process_completions_to_buffer (line 1472) | def _process_completions_to_buffer( method _prepare_dataset (line 1564) | def _prepare_dataset( method _prepare_dataset_with_original_text (line 1584) | def _prepare_dataset_with_original_text( method generalized_jsd_loss (line 1817) | def generalized_jsd_loss( method compute_loss (line 1896) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method generate_on_policy_outputs (line 2082) | def generate_on_policy_outputs(self, model, inputs, generation_config,... method _sync_fsdp_params_to_vllm (line 2163) | def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "... method _move_model_to_vllm (line 2190) | def _move_model_to_vllm(self): method _wake_vllm_if_needed (line 2260) | def _wake_vllm_if_needed(self): method _get_liger_zero3_lm_head_gather_ctx (line 2265) | def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): method training_step (line 2287) | def training_step( method log (line 2325) | def log(self, logs: dict[str, float], start_time: float | None = None)... FILE: trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py class GRPOWithReplayBufferConfig (line 21) | class GRPOWithReplayBufferConfig(GRPOConfig): FILE: trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py class ReplayBuffer (line 28) | class ReplayBuffer: method __init__ (line 33) | def __init__(self, max_size: int): method add (line 37) | def add(self, scores: list[float], data: list[dict]): method sample (line 46) | def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]]: class GRPOWithReplayBufferTrainer (line 60) | class GRPOWithReplayBufferTrainer(GRPOTrainer): method __init__ (line 61) | def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **k... method _generate_and_score_completions (line 65) | def _generate_and_score_completions( method slice_group_data (line 427) | def slice_group_data( method update_replay_buffer (line 449) | def update_replay_buffer( method sample_from_replay_buffer (line 528) | def sample_from_replay_buffer( method update_with_replay_buffer (line 574) | def update_with_replay_buffer( FILE: trl/experimental/gspo_token/grpo_trainer.py class GRPOTrainer (line 21) | class GRPOTrainer(_GRPOTrainer): method _compute_loss (line 22) | def _compute_loss(self, model, inputs): FILE: trl/experimental/judges/judges.py function _ensure_llm_blender_importable (line 57) | def _ensure_llm_blender_importable() -> None: class BaseJudge (line 77) | class BaseJudge(ABC): method judge (line 83) | def judge(self, prompts: list[str], completions: list[str], shuffle_or... class BaseRankJudge (line 87) | class BaseRankJudge(ABC): method judge (line 107) | def judge(self, prompts: list[str], completions: list[list[str]], shuf... class BasePairwiseJudge (line 128) | class BasePairwiseJudge(BaseJudge): method judge (line 134) | def judge(self, prompts: list[str], completions: list[list[str]], shuf... class BaseBinaryJudge (line 160) | class BaseBinaryJudge(BaseJudge): method judge (line 166) | def judge( class PairRMJudge (line 200) | class PairRMJudge(BasePairwiseJudge): method __init__ (line 227) | def __init__(self): method judge (line 242) | def judge( class HfPairwiseJudge (line 309) | class HfPairwiseJudge(BasePairwiseJudge): method __init__ (line 327) | def __init__( method judge (line 336) | def judge(self, prompts: list[str], completions: list[list[str]], shuf... class OpenAIPairwiseJudge (line 365) | class OpenAIPairwiseJudge(BasePairwiseJudge): method __init__ (line 383) | def __init__( method judge (line 397) | def judge(self, prompts: list[str], completions: list[list[str]], shuf... class AllTrueJudge (line 440) | class AllTrueJudge(BaseBinaryJudge): method __init__ (line 454) | def __init__(self, judges: list[BaseBinaryJudge]): method judge (line 457) | def judge( FILE: trl/experimental/kto/kto_config.py class KTOConfig (line 22) | class KTOConfig(_BaseConfig): method __post_init__ (line 145) | def __post_init__(self): FILE: trl/experimental/kto/kto_trainer.py function _get_kl_dataset (line 85) | def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]: function _tokenize (line 96) | def _tokenize( function _process_tokens (line 156) | def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = ... class KTOTrainer (line 248) | class KTOTrainer(_BaseTrainer): method __init__ (line 309) | def __init__( method null_ref_context (line 743) | def null_ref_context(self): method get_train_dataloader (line 756) | def get_train_dataloader(self) -> DataLoader: method get_eval_dataloader (line 800) | def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> ... method compute_reference_log_probs (line 855) | def compute_reference_log_probs(self, padded_batch: dict) -> dict: method get_batch_logps (line 899) | def get_batch_logps( method forward (line 939) | def forward( method kto_loss (line 985) | def kto_loss( method _compute_kl_logps (line 1062) | def _compute_kl_logps(self, model, batch): method _compute_loss_liger (line 1081) | def _compute_loss_liger(self, model, batch): method get_batch_loss_metrics (line 1174) | def get_batch_loss_metrics( method compute_loss (line 1290) | def compute_loss( method store_metrics (line 1314) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal... method _get_train_sampler (line 1318) | def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.... method generate_from_model_and_ref (line 1325) | def generate_from_model_and_ref(self, model, batch: dict[str, torch.Lo... method prediction_step (line 1373) | def prediction_step( method evaluation_loop (line 1411) | def evaluation_loop( method log (line 1471) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 1507) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/merge_model_callback.py function upload_model_to_hf (line 35) | def upload_model_to_hf(folder_path: str, repo_id: str): class MergeConfig (line 48) | class MergeConfig: method __init__ (line 82) | def __init__(self, method: str = "linear"): method create_merge_config_linear (line 114) | def create_merge_config_linear(self) -> "MergeConfiguration": method create_merge_config_ties (line 133) | def create_merge_config_ties(self) -> "MergeConfiguration": method create_merge_config_dare_ties (line 177) | def create_merge_config_dare_ties(self) -> "MergeConfiguration": method create_merge_config_slerp (line 221) | def create_merge_config_slerp(self) -> "MergeConfiguration": method create (line 260) | def create(self) -> "MergeConfiguration": function merge_models (line 271) | def merge_models(config: "MergeConfiguration", out_path: str): class MergeModelCallback (line 294) | class MergeModelCallback(TrainerCallback): method __init__ (line 319) | def __init__( method _merge_and_maybe_push (line 333) | def _merge_and_maybe_push(self, output_dir, global_step, model): method on_save (line 346) | def on_save(self, args, state, control, model=None, **kwargs): method on_train_end (line 350) | def on_train_end(self, args, state, control, model=None, **kwargs): FILE: trl/experimental/minillm/minillm_config.py class MiniLLMConfig (line 24) | class MiniLLMConfig(GRPOConfig): method __post_init__ (line 87) | def __post_init__(self): FILE: trl/experimental/minillm/minillm_trainer.py function dummy_reward_func (line 43) | def dummy_reward_func(completions: list, **kwargs): class MiniLLMTrainer (line 48) | class MiniLLMTrainer(GRPOTrainer): method __init__ (line 166) | def __init__( method _single_step_decomposition_loss (line 245) | def _single_step_decomposition_loss( method _compute_advantage (line 292) | def _compute_advantage( method compute_loss (line 349) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... FILE: trl/experimental/nash_md/nash_md_config.py class NashMDConfig (line 21) | class NashMDConfig(OnlineDPOConfig): method __post_init__ (line 43) | def __post_init__(self): FILE: trl/experimental/nash_md/nash_md_trainer.py class GeometricMixtureWrapper (line 50) | class GeometricMixtureWrapper(GenerationMixin): method __init__ (line 66) | def __init__(self, model, ref_model, generation_config, mixture_coef=0... method __call__ (line 78) | def __call__(self, *args, **kwargs): method forward (line 82) | def forward(self, *args, **kwargs): method prepare_inputs_for_generation (line 93) | def prepare_inputs_for_generation(self, *args, **kwargs): method _validate_model_class (line 101) | def _validate_model_class(self): method _validate_model_kwargs (line 104) | def _validate_model_kwargs(self, model_kwargs): class NashMDTrainer (line 108) | class NashMDTrainer(OnlineDPOTrainer): method __init__ (line 170) | def __init__( method mixture_coef (line 236) | def mixture_coef(self): method _generate_completions (line 243) | def _generate_completions(self, model, prompts): method _process_completions (line 298) | def _process_completions(self, model_output, mixture_output, prompts): method _compute_rewards (line 325) | def _compute_rewards(self, model_data, mixture_data, context_length): method _compute_judge (line 343) | def _compute_judge(self, model_data, mixture_data, context_length): method _compute_logprobs (line 377) | def _compute_logprobs(self, model, model_data, context_length): method _compute_losses (line 402) | def _compute_losses( method _log_statistics (line 422) | def _log_statistics( method training_step (line 481) | def training_step( FILE: trl/experimental/online_dpo/online_dpo_config.py class OnlineDPOConfig (line 23) | class OnlineDPOConfig(_BaseConfig): method __post_init__ (line 386) | def __post_init__(self): FILE: trl/experimental/online_dpo/online_dpo_trainer.py class OnlineDPOTrainer (line 104) | class OnlineDPOTrainer(_BaseTrainer): method __init__ (line 182) | def __init__( method beta (line 587) | def beta(self): method tokenize_row (line 595) | def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrai... method _enable_gradient_checkpointing (line 610) | def _enable_gradient_checkpointing(self, model: PreTrainedModel, args:... method _generate_vllm (line 625) | def _generate_vllm(self, prompts, images=None): method _generate_vllm_server (line 655) | def _generate_vllm_server(self, prompts, images=None): method _generate_vllm_colocate (line 731) | def _generate_vllm_colocate(self, prompts, images=None): method _sync_fsdp2_params_to_vllm (line 772) | def _sync_fsdp2_params_to_vllm(self, module: nn.Module): method _move_model_to_vllm (line 795) | def _move_model_to_vllm(self): method _sync_fsdp1_params_to_vllm (line 871) | def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = ... method _fix_param_name_to_vllm (line 898) | def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | No... method process_vision_row (line 906) | def process_vision_row( method _generate (line 933) | def _generate(self, model, prompts, images=None): method _calculate_rewards_from_functions (line 1088) | def _calculate_rewards_from_functions(self, prompts, completions, comp... method _forward (line 1134) | def _forward(self, model, prompt_ids, prompt_mask, completion_ids, com... method training_step (line 1172) | def training_step( method _maybe_log_save_evaluate (line 1394) | def _maybe_log_save_evaluate( method _save_checkpoint (line 1440) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/openenv/utils.py function _build_base_generation_kwargs (line 29) | def _build_base_generation_kwargs( function _build_colocate_sampling_params (line 60) | def _build_colocate_sampling_params( function _build_server_generation_kwargs (line 80) | def _build_server_generation_kwargs( function generate_rollout_completions (line 88) | def generate_rollout_completions( function _generate_rollout_completions_server (line 116) | def _generate_rollout_completions_server( function _generate_rollout_completions_colocate (line 156) | def _generate_rollout_completions_colocate( FILE: trl/experimental/orpo/orpo_config.py class ORPOConfig (line 22) | class ORPOConfig(_BaseConfig): FILE: trl/experimental/orpo/orpo_trainer.py function log1mexp (line 78) | def log1mexp(x: torch.FloatTensor) -> torch.FloatTensor: class ORPOTrainer (line 85) | class ORPOTrainer(_BaseTrainer): method __init__ (line 138) | def __init__( method build_tokenized_answer (line 374) | def build_tokenized_answer(self, prompt, answer): method tokenize_row (line 423) | def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | N... method concatenated_inputs (line 566) | def concatenated_inputs( method odds_ratio_loss (line 630) | def odds_ratio_loss( method get_batch_logps (line 665) | def get_batch_logps( method concatenated_forward (line 705) | def concatenated_forward( method get_batch_loss_metrics (line 784) | def get_batch_loss_metrics( method compute_loss (line 839) | def compute_loss( method generate_from_model (line 863) | def generate_from_model(self, model, batch: dict[str, torch.LongTensor... method prediction_step (line 886) | def prediction_step( method store_metrics (line 928) | def store_metrics(self, metrics: dict[str, float], train_eval: Literal... method evaluation_loop (line 932) | def evaluation_loop( method log (line 983) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _shift_right (line 1001) | def _shift_right(self, input_ids): method _save_checkpoint (line 1025) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/papo/papo_config.py class PAPOConfig (line 22) | class PAPOConfig(GRPOConfig): method __post_init__ (line 60) | def __post_init__(self): FILE: trl/experimental/papo/papo_trainer.py class PAPOTrainer (line 27) | class PAPOTrainer(GRPOTrainer): method __init__ (line 114) | def __init__( method _mask_image (line 154) | def _mask_image(self, pixel_values: torch.Tensor, mask_ratio: float = ... method _compute_loss (line 214) | def _compute_loss(self, model, inputs): FILE: trl/experimental/ppo/modeling_value_head.py class PreTrainedModelWrapper (line 52) | class PreTrainedModelWrapper(nn.Module): method __init__ (line 79) | def __init__( method from_pretrained (line 107) | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, *... method _get_checkpoint_from_hub (line 339) | def _get_checkpoint_from_hub( method _get_current_device (line 391) | def _get_current_device(cls): method _split_kwargs (line 409) | def _split_kwargs(cls, kwargs): method add_and_load_reward_modeling_adapter (line 439) | def add_and_load_reward_modeling_adapter( method push_to_hub (line 509) | def push_to_hub(self, *args, **kwargs): method save_pretrained (line 523) | def save_pretrained(self, *args, **kwargs): method state_dict (line 550) | def state_dict(self, *args, **kwargs): method post_init (line 556) | def post_init(self, *args, **kwargs): method compute_reward_score (line 563) | def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): class ValueHead (line 594) | class ValueHead(nn.Module): method __init__ (line 599) | def __init__(self, config, **kwargs): method forward (line 622) | def forward(self, hidden_states): class AutoModelForCausalLMWithValueHead (line 634) | class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): method __init__ (line 665) | def __init__(self, pretrained_model, **kwargs): method _init_weights (line 681) | def _init_weights(self, **kwargs): method forward (line 703) | def forward( method generate (line 758) | def generate(self, *args, **kwargs): method state_dict (line 772) | def state_dict(self, *args, **kwargs): method push_to_hub (line 788) | def push_to_hub(self, *args, **kwargs): method post_init (line 793) | def post_init(self, state_dict): class AutoModelForSeq2SeqLMWithValueHead (line 838) | class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): method __init__ (line 861) | def __init__(self, pretrained_model, **kwargs): method _has_lm_head (line 873) | def _has_lm_head(self): method post_init (line 880) | def post_init(self, state_dict): method state_dict (line 934) | def state_dict(self, *args, **kwargs): method push_to_hub (line 950) | def push_to_hub(self, *args, **kwargs): method _init_weights (line 955) | def _init_weights(self, **kwargs): method forward (line 969) | def forward( method generate (line 1003) | def generate(self, *args, **kwargs): FILE: trl/experimental/ppo/ppo_config.py class PPOConfig (line 22) | class PPOConfig(_BaseConfig): FILE: trl/experimental/ppo/ppo_trainer.py function generate (line 84) | def generate( function batch_generation (line 124) | def batch_generation( function exact_div (line 156) | def exact_div(a, b, custom_error_message=""): function print_rich_table (line 163) | def print_rich_table(df: pd.DataFrame) -> None: function truncate_response (line 177) | def truncate_response(stop_token_id: int, pad_token_id: int, responses: ... function forward (line 200) | def forward( class OnlineTrainerState (line 233) | class OnlineTrainerState(TrainerState): function masked_mean (line 246) | def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | N... function masked_var (line 254) | def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool ... function masked_whiten (line 273) | def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: ... class PolicyAndValueWrapper (line 284) | class PolicyAndValueWrapper(nn.Module): method __init__ (line 285) | def __init__(self, policy, value_model) -> None: method gradient_checkpointing_enable (line 292) | def gradient_checkpointing_enable(self, **kwargs): method gradient_checkpointing_disable (line 296) | def gradient_checkpointing_disable(self): method forward (line 300) | def forward(self, **kwargs): class PPOTrainer (line 306) | class PPOTrainer(_BaseTrainer): method __init__ (line 358) | def __init__( method get_train_dataloader (line 577) | def get_train_dataloader(self) -> DataLoader: method get_eval_dataloader (line 580) | def get_eval_dataloader(self) -> DataLoader: method null_ref_context (line 584) | def null_ref_context(self): method save_model (line 597) | def save_model(self, output_dir: str | None = None, _internal_call: bo... method train (line 611) | def train(self): method generate_completions (line 957) | def generate_completions(self, sampling: bool = False): method _save_checkpoint (line 1030) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/prm/prm_config.py class PRMConfig (line 21) | class PRMConfig(_BaseConfig): FILE: trl/experimental/prm/prm_trainer.py function compute_accuracy (line 52) | def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: class PRMTrainer (line 97) | class PRMTrainer(_BaseTrainer): method __init__ (line 150) | def __init__( method tokenize_row (line 259) | def tokenize_row( method _save_checkpoint (line 351) | def _save_checkpoint(self, model, trial): FILE: trl/experimental/utils.py class DPODataCollatorWithPadding (line 46) | class DPODataCollatorWithPadding: method __call__ (line 60) | def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: class DataCollatorForChatML (line 127) | class DataCollatorForChatML: method __post_init__ (line 138) | def __post_init__(self): method __call__ (line 145) | def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.... function truncate_right (line 260) | def truncate_right( function add_bos_token_if_needed (line 292) | def add_bos_token_if_needed( function add_eos_token_if_needed (line 314) | def add_eos_token_if_needed( function first_true_indices (line 326) | def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.T... function get_reward (line 349) | def get_reward( function prepare_model_for_kbit_training (line 399) | def prepare_model_for_kbit_training(model, use_gradient_checkpointing=Tr... function enable_gradient_checkpointing (line 436) | def enable_gradient_checkpointing( function prepare_peft_model (line 465) | def prepare_peft_model( function pad_to_length (line 520) | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int | fl... function empty_cache (line 535) | def empty_cache() -> None: function peft_module_casting_to_bf16 (line 553) | def peft_module_casting_to_bf16(model): function create_reference_model (line 571) | def create_reference_model( FILE: trl/experimental/winrate_callback.py function _generate_completions (line 42) | def _generate_completions( function _win_rate_completions_df (line 82) | def _win_rate_completions_df( class WinRateCallback (line 92) | class WinRateCallback(TrainerCallback): method __init__ (line 134) | def __init__( method on_train_begin (line 158) | def on_train_begin(self, args: TrainingArguments, state: TrainerState,... method on_evaluate (line 226) | def on_evaluate(self, args: TrainingArguments, state: TrainerState, co... FILE: trl/experimental/xpo/xpo_config.py class XPOConfig (line 21) | class XPOConfig(OnlineDPOConfig): method __post_init__ (line 41) | def __post_init__(self): FILE: trl/experimental/xpo/xpo_trainer.py class XPOTrainer (line 49) | class XPOTrainer(OnlineDPOTrainer): method __init__ (line 109) | def __init__( method alpha (line 180) | def alpha(self): method _generate_completions (line 187) | def _generate_completions(self, prompts, model): method _process_completions (line 227) | def _process_completions(self, model_output, ref_output, prompts): method _compute_rewards (line 254) | def _compute_rewards(self, model_data, ref_data, context_length): method _compute_judge (line 272) | def _compute_judge(self, model_data, ref_data, context_length): method _compute_logprobs (line 307) | def _compute_logprobs(self, model, model_data, ref_data, context_length): method _compute_losses (line 339) | def _compute_losses( method _log_statistics (line 379) | def _log_statistics( method training_step (line 462) | def training_step( FILE: trl/extras/profiling.py class ProfilingContext (line 30) | class ProfilingContext: method __init__ (line 75) | def __init__( method __enter__ (line 90) | def __enter__(self): method __exit__ (line 95) | def __exit__(self, exc_type, exc_val, exc_tb): method _log_metrics (line 102) | def _log_metrics(self, duration: float) -> None: function profiling_context (line 125) | def profiling_context(trainer: Trainer, name: str) -> ProfilingContext: function profiling_decorator (line 167) | def profiling_decorator(func: Callable) -> Callable: FILE: trl/generation/vllm_client.py function pil_to_base64 (line 51) | def pil_to_base64(image): class VLLMClient (line 58) | class VLLMClient: method __init__ (line 122) | def __init__( method check_server (line 168) | def check_server(self, total_timeout: float = 0.0, retry_interval: flo... method generate (line 204) | def generate( method chat (line 302) | def chat( method init_communicator (line 416) | def init_communicator(self, device: torch.device | str | int = 0): method update_named_param (line 489) | def update_named_param(self, name: str, weights: torch.Tensor): method update_model_params (line 514) | def update_model_params(self, model: nn.Module): method reset_prefix_cache (line 526) | def reset_prefix_cache(self): method close_communicator (line 535) | def close_communicator(self): FILE: trl/generation/vllm_generation.py function empty_cache (line 44) | def empty_cache() -> None: function extract_logprobs (line 62) | def extract_logprobs(all_outputs: list["RequestOutput"]): class VLLMGeneration (line 103) | class VLLMGeneration: method __init__ (line 216) | def __init__( method _init_vllm (line 284) | def _init_vllm(self): method _fix_param_name_to_vllm (line 369) | def _fix_param_name_to_vllm(self, name: str, extra_prefixes: list[str]... method _sync_fsdp1_params_to_vllm (line 377) | def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = ... method _sync_fsdp2_params_to_vllm (line 406) | def _sync_fsdp2_params_to_vllm(self, module: nn.Module): method sync_weights (line 432) | def sync_weights(self): method generate (line 521) | def generate( FILE: trl/import_utils.py function _is_package_available (line 29) | def _is_package_available(pkg_name: str, return_version: bool = False) -... function is_deepspeed_available (line 60) | def is_deepspeed_available() -> bool: function is_fastapi_available (line 64) | def is_fastapi_available() -> bool: function is_jmespath_available (line 68) | def is_jmespath_available() -> bool: function is_joblib_available (line 72) | def is_joblib_available() -> bool: function is_liger_kernel_available (line 76) | def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSIO... function is_llm_blender_available (line 81) | def is_llm_blender_available() -> bool: function is_math_verify_available (line 85) | def is_math_verify_available() -> bool: function is_mergekit_available (line 89) | def is_mergekit_available() -> bool: function is_pydantic_available (line 93) | def is_pydantic_available() -> bool: function is_requests_available (line 97) | def is_requests_available() -> bool: function is_unsloth_available (line 101) | def is_unsloth_available() -> bool: function is_uvicorn_available (line 105) | def is_uvicorn_available() -> bool: function is_vllm_available (line 109) | def is_vllm_available(min_version: str | None = None) -> bool: function is_vllm_ascend_available (line 123) | def is_vllm_ascend_available() -> bool: function is_weave_available (line 127) | def is_weave_available() -> bool: class TRLExperimentalWarning (line 131) | class TRLExperimentalWarning(UserWarning): function suppress_warning (line 138) | def suppress_warning(category): function suppress_experimental_warning (line 144) | def suppress_experimental_warning(): FILE: trl/models/activation_offloading.py function _get_unique_tensor_key (line 49) | def _get_unique_tensor_key(tensor: torch.Tensor) -> tuple: class OffloadActivations (line 79) | class OffloadActivations(saved_tensors_hooks): method __init__ (line 120) | def __init__( method update_model_params (line 529) | def update_model_params(self, model: nn.Module): class NoOpManager (line 562) | class NoOpManager(saved_tensors_hooks): method __init__ (line 571) | def __init__(self) -> None: function get_act_offloading_ctx_manager (line 578) | def get_act_offloading_ctx_manager( FILE: trl/models/utils.py function remove_hooks (line 47) | def remove_hooks(model: "DeepSpeedEngine") -> None: function get_all_parameters (line 70) | def get_all_parameters(sub_module, recurse=False): function iter_params (line 74) | def iter_params(module, recurse=False): function add_hooks (line 78) | def add_hooks(model: "DeepSpeedEngine") -> None: function _unwrap_model_for_generation (line 98) | def _unwrap_model_for_generation( function _override_model_generation_config (line 145) | def _override_model_generation_config(model, generation_kwargs=None): function unwrap_model_for_generation (line 187) | def unwrap_model_for_generation( function prepare_deepspeed (line 225) | def prepare_deepspeed(model: "Module", accelerator: "Accelerator"): function prepare_fsdp (line 265) | def prepare_fsdp(model, accelerator: Accelerator) -> FSDP | FSDPModule: class _ForwardRedirection (line 313) | class _ForwardRedirection: method __call__ (line 323) | def __call__( method on_after_inner_forward (line 357) | def on_after_inner_forward(self, wrapper_module: nn.Module, original_m... method on_after_outer_forward (line 360) | def on_after_outer_forward(self, wrapper_module: nn.Module, original_m... function disable_gradient_checkpointing (line 365) | def disable_gradient_checkpointing(model: PreTrainedModel, gradient_chec... function create_reference_model (line 385) | def create_reference_model( FILE: trl/rewards/accuracy_rewards.py function accuracy_reward (line 26) | def accuracy_reward(completions: list[list[dict[str, str]]], solution: l... function reasoning_accuracy_reward (line 97) | def reasoning_accuracy_reward( FILE: trl/rewards/format_rewards.py function think_format_reward (line 18) | def think_format_reward(completions: list[list[dict[str, str]]], **kwarg... FILE: trl/rewards/other_rewards.py function get_soft_overlong_punishment (line 18) | def get_soft_overlong_punishment(max_completion_len: int, soft_punish_ca... FILE: trl/scripts/_hf_argparser.py function string_to_bool (line 40) | def string_to_bool(v): function make_choice_type_function (line 53) | def make_choice_type_function(choices: list) -> Callable[[str], Any]: function HfArg (line 68) | def HfArg( class HfArgumentParser (line 115) | class HfArgumentParser(ArgumentParser): method __init__ (line 132) | def __init__(self, dataclass_types: DataClassType | Iterable[DataClass... method _parse_dataclass_field (line 150) | def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.... method _add_dataclass_arguments (line 255) | def _add_dataclass_arguments(self, dtype: DataClassType): method parse_args_into_dataclasses (line 276) | def parse_args_into_dataclasses( method parse_dict (line 362) | def parse_dict(self, args: dict[str, Any], allow_extra_keys: bool = Fa... method parse_json_file (line 390) | def parse_json_file(self, json_file: str | os.PathLike, allow_extra_ke... method parse_yaml_file (line 412) | def parse_yaml_file(self, yaml_file: str | os.PathLike, allow_extra_ke... FILE: trl/scripts/dpo.py function main (line 69) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 149) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/env.py function print_env (line 26) | def print_env(): FILE: trl/scripts/grpo.py class GRPOScriptArguments (line 38) | class GRPOScriptArguments(ScriptArguments): function main (line 72) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 171) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/kto.py function main (line 75) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 141) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/reward.py function main (line 32) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 80) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/rloo.py class RLOOScriptArguments (line 38) | class RLOOScriptArguments(ScriptArguments): function main (line 72) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 155) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/sft.py function main (line 71) | def main(script_args, training_args, model_args, dataset_args): function make_parser (line 147) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/scripts/utils.py class DatasetConfig (line 40) | class DatasetConfig: class DatasetMixtureConfig (line 74) | class DatasetMixtureConfig: method __post_init__ (line 129) | def __post_init__(self): class ScriptArguments (line 138) | class ScriptArguments: function init_zero_verbose (line 197) | def init_zero_verbose(): class TrlParser (line 226) | class TrlParser(HfArgumentParser): method __init__ (line 274) | def __init__( method parse_args_and_config (line 295) | def parse_args_and_config( method set_defaults_with_config (line 351) | def set_defaults_with_config(self, **kwargs) -> list[str]: function get_git_commit_hash (line 381) | def get_git_commit_hash(package_name): function get_dataset (line 404) | def get_dataset(mixture_config: DatasetMixtureConfig) -> "DatasetDict": FILE: trl/scripts/vllm_serve.py class WeightSyncWorkerExtension (line 34) | class WeightSyncWorkerExtension: method init_communicator (line 48) | def init_communicator(self, host: str, port: int, world_size: int, cli... method update_named_param (line 115) | def update_named_param(self, name: str, dtype: str, shape: Sequence[in... method close_communicator (line 149) | def close_communicator(self) -> None: class ScriptArguments (line 163) | class ScriptArguments: function llm_worker (line 310) | def llm_worker( function chunk_list (line 364) | def chunk_list(lst: list, n: int) -> list[list]: function main (line 384) | def main(script_args: ScriptArguments): function make_parser (line 888) | def make_parser(subparsers: argparse._SubParsersAction | None = None, pr... FILE: trl/skills/cli.py function add_skills_subcommands (line 26) | def add_skills_subcommands(subparsers: argparse._SubParsersAction) -> None: function cmd_install (line 90) | def cmd_install(args): function cmd_uninstall (line 147) | def cmd_uninstall(args): function cmd_list (line 162) | def cmd_list(args): FILE: trl/skills/skills.py function list_agent_names (line 50) | def list_agent_names() -> list[str]: function _get_trl_skills_dir (line 60) | def _get_trl_skills_dir() -> Path: function resolve_target_path (line 72) | def resolve_target_path(target: str | Path, scope: str = "project") -> P... function _list_skills_in_dir (line 117) | def _list_skills_in_dir(skills_dir: Path) -> list[str]: function list_skills (line 138) | def list_skills(target: str | Path | None = None, scope: str = "project"... function _install_skill_to_dir (line 178) | def _install_skill_to_dir( function install_skill (line 244) | def install_skill( function _uninstall_skill_from_dir (line 294) | def _uninstall_skill_from_dir(skill_name: str, target_dir: Path) -> bool: function uninstall_skill (line 326) | def uninstall_skill(skill_name: str, target: str | Path, scope: str = "p... FILE: trl/trainer/base_config.py class _BaseConfig (line 21) | class _BaseConfig(TrainingArguments): method __post_init__ (line 104) | def __post_init__(self): FILE: trl/trainer/base_trainer.py class _BaseTrainer (line 26) | class _BaseTrainer(Trainer): method create_model_card (line 32) | def create_model_card( FILE: trl/trainer/callbacks.py function _generate_completions (line 62) | def _generate_completions( class SyncRefModelCallback (line 102) | class SyncRefModelCallback(TrainerCallback): method __init__ (line 107) | def __init__( method _sync_target_model (line 116) | def _sync_target_model(model, target_model, alpha): method sync_target_model (line 121) | def sync_target_model(model, target_model, alpha): method on_step_end (line 134) | def on_step_end(self, args, state, control, **kwargs): class RichProgressCallback (line 143) | class RichProgressCallback(TrainerCallback): method __init__ (line 148) | def __init__(self): method on_train_begin (line 161) | def on_train_begin(self, args, state, control, **kwargs): method on_step_end (line 174) | def on_step_end(self, args, state, control, **kwargs): method on_prediction_step (line 181) | def on_prediction_step(self, args, state, control, eval_dataloader=Non... method on_evaluate (line 190) | def on_evaluate(self, args, state, control, **kwargs): method on_predict (line 198) | def on_predict(self, args, state, control, **kwargs): method on_log (line 206) | def on_log(self, args, state, control, logs=None, **kwargs): method on_train_end (line 239) | def on_train_end(self, args, state, control, **kwargs): class LogCompletionsCallback (line 254) | class LogCompletionsCallback(TrainerCallback): method __init__ (line 278) | def __init__( method on_step_end (line 299) | def on_step_end(self, args, state, control, **kwargs): class WeaveCallback (line 346) | class WeaveCallback(TrainerCallback): method __init__ (line 410) | def __init__( method _initialize_weave (line 438) | def _initialize_weave(self): method is_evaluation_mode (line 473) | def is_evaluation_mode(self) -> bool: method on_train_begin (line 477) | def on_train_begin(self, args, state, control, **kwargs): method on_evaluate (line 481) | def on_evaluate(self, args, state, control, **kwargs): class BEMACallback (line 575) | class BEMACallback(TrainerCallback): method __init__ (line 637) | def __init__( method _unwrap_model (line 666) | def _unwrap_model(model): method on_train_begin (line 688) | def on_train_begin( method _ema_beta (line 709) | def _ema_beta(self, step: int) -> float: method _bema_alpha (line 714) | def _bema_alpha(self, step: int) -> float: method _update_bema_weights (line 718) | def _update_bema_weights(self, step: int): method on_step_end (line 731) | def on_step_end( method on_train_end (line 754) | def on_train_end(self, args: TrainingArguments, state: TrainerState, c... FILE: trl/trainer/dpo_config.py class DPOConfig (line 22) | class DPOConfig(_BaseConfig): method __post_init__ (line 310) | def __post_init__(self): FILE: trl/trainer/dpo_trainer.py function get_dataset_column_names (line 89) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list... class DataCollatorForPreference (line 94) | class DataCollatorForPreference(DataCollatorMixin): method torch_call (line 154) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: class DataCollatorForVisionPreference (line 215) | class DataCollatorForVisionPreference(DataCollatorMixin): method torch_call (line 302) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: class DPOTrainer (line 406) | class DPOTrainer(_BaseTrainer): method __init__ (line 501) | def __init__( method _prepare_dataset (line 823) | def _prepare_dataset( method _set_signature_columns_if_needed (line 951) | def _set_signature_columns_if_needed(self): method _precompute_ref_logps (line 975) | def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_siz... method _truncate_inputs (line 1017) | def _truncate_inputs( method compute_ref_log_probs (line 1051) | def compute_ref_log_probs(self, inputs): method _compute_loss_liger (line 1103) | def _compute_loss_liger(self, model, inputs, return_outputs): method _compute_loss (line 1181) | def _compute_loss(self, model, inputs, return_outputs): method compute_loss (line 1494) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method training_step (line 1501) | def training_step(self, *args, **kwargs): method log (line 1505) | def log(self, logs: dict[str, float], start_time: float | None = None)... method prediction_step (line 1520) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_... method _save_checkpoint (line 1532) | def _save_checkpoint(self, model, trial): FILE: trl/trainer/grpo_config.py class GRPOConfig (line 22) | class GRPOConfig(_BaseConfig): method __post_init__ (line 869) | def __post_init__(self): FILE: trl/trainer/grpo_trainer.py class _SupportsReset (line 123) | class _SupportsReset(Protocol): method reset (line 124) | def reset(self, **kwargs) -> str | None: ... class GRPOTrainer (line 130) | class GRPOTrainer(_BaseTrainer): method __init__ (line 268) | def __init__( method _set_signature_columns_if_needed (line 852) | def _set_signature_columns_if_needed(self): method get_train_dataloader (line 869) | def get_train_dataloader(self): method _get_train_sampler (line 878) | def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: method _get_eval_sampler (line 914) | def _get_eval_sampler(self, eval_dataset) -> Sampler: method _get_last_hidden_state (line 923) | def _get_last_hidden_state( method get_high_entropy_mask (line 967) | def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.T... method _get_per_token_logps_and_entropies (line 1006) | def _get_per_token_logps_and_entropies( method training_step (line 1081) | def training_step(self, model, inputs, num_items_in_batch): method _prepare_inputs (line 1093) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A... method _log_completion_extra (line 1124) | def _log_completion_extra(self, column: str, values: list): method _log_metric (line 1136) | def _log_metric(self, name: str, value: float): method _calculate_rewards (line 1150) | def _calculate_rewards(self, inputs, prompts, completions, completion_... method _tokenize_prompts (line 1241) | def _tokenize_prompts(self, prompts: list): method _generate_single_turn (line 1284) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields): method _get_tool_suffix_ids (line 1380) | def _get_tool_suffix_ids(self, tool_messages): method _tool_call_loop (line 1401) | def _tool_call_loop(self, prompts, prompt_ids, completion_ids, complet... method _generate (line 1577) | def _generate(self, prompts: list): method _generate_and_score_completions (line 1692) | def _generate_and_score_completions( method compute_liger_loss (line 2113) | def compute_liger_loss(self, unwrapped_model, inputs): method compute_loss (line 2161) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method get_off_policy_mask (line 2172) | def get_off_policy_mask( method get_gamma_weights (line 2195) | def get_gamma_weights( method _compute_loss (line 2243) | def _compute_loss(self, model, inputs): method prediction_step (line 2442) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_... method log (line 2450) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 2519) | def _save_checkpoint(self, model, trial): FILE: trl/trainer/kto_config.py class KTOConfig (line 26) | class KTOConfig(_KTOConfig): method __post_init__ (line 27) | def __post_init__(self): FILE: trl/trainer/kto_trainer.py class KTOTrainer (line 26) | class KTOTrainer(_KTOTrainer): method __init__ (line 27) | def __init__(self, *args, **kwargs): FILE: trl/trainer/model_config.py class ModelConfig (line 19) | class ModelConfig: method __post_init__ (line 183) | def __post_init__(self): FILE: trl/trainer/reward_config.py class RewardConfig (line 22) | class RewardConfig(_BaseConfig): FILE: trl/trainer/reward_trainer.py function _suppress_seqcls_cross_arch_keys (line 75) | def _suppress_seqcls_cross_arch_keys(logger: logging.Logger): function _ignore_seqcls_cross_arch_keys (line 96) | def _ignore_seqcls_cross_arch_keys(): function suppress_seqcls_warning (line 121) | def suppress_seqcls_warning(): function get_dataset_column_names (line 134) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list... class DataCollatorForPreference (line 139) | class DataCollatorForPreference(DataCollatorMixin): method torch_call (line 200) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: class RewardTrainer (line 229) | class RewardTrainer(_BaseTrainer): method __init__ (line 321) | def __init__( method _prepare_dataset (line 537) | def _prepare_dataset( method _set_signature_columns_if_needed (line 636) | def _set_signature_columns_if_needed(self): method compute_loss (line 643) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method training_step (line 685) | def training_step(self, *args, **kwargs): method log (line 689) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 703) | def _save_checkpoint(self, model, trial): FILE: trl/trainer/rloo_config.py class RLOOConfig (line 22) | class RLOOConfig(_BaseConfig): method __post_init__ (line 547) | def __post_init__(self): FILE: trl/trainer/rloo_trainer.py class RLOOTrainer (line 104) | class RLOOTrainer(_BaseTrainer): method __init__ (line 220) | def __init__( method _set_signature_columns_if_needed (line 596) | def _set_signature_columns_if_needed(self): method get_train_dataloader (line 613) | def get_train_dataloader(self): method _get_train_sampler (line 622) | def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: method _get_eval_sampler (line 658) | def _get_eval_sampler(self, eval_dataset) -> Sampler: method _get_per_token_logps_and_entropies (line 667) | def _get_per_token_logps_and_entropies( method training_step (line 742) | def training_step(self, model, inputs, num_items_in_batch): method _prepare_inputs (line 754) | def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | A... method _log_completion_extra (line 785) | def _log_completion_extra(self, column: str, values: list): method _log_metric (line 797) | def _log_metric(self, name: str, value: float): method _calculate_rewards (line 811) | def _calculate_rewards(self, inputs, prompts, completions, completion_... method _tokenize_prompts (line 900) | def _tokenize_prompts(self, prompts: list): method _generate_single_turn (line 941) | def _generate_single_turn(self, prompt_ids, images, multimodal_fields): method _generate (line 1031) | def _generate(self, prompts: list): method _generate_and_score_completions (line 1080) | def _generate_and_score_completions( method compute_loss (line 1370) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method _compute_loss (line 1375) | def _compute_loss(self, model, inputs): method prediction_step (line 1435) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_... method log (line 1443) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 1504) | def _save_checkpoint(self, model, trial): FILE: trl/trainer/sft_config.py class SFTConfig (line 23) | class SFTConfig(_BaseConfig): method __post_init__ (line 269) | def __post_init__(self): FILE: trl/trainer/sft_trainer.py function get_dataset_column_names (line 86) | def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list... class DataCollatorForLanguageModeling (line 91) | class DataCollatorForLanguageModeling(DataCollatorMixin): method torch_call (line 164) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: method get_position_ids_from_packed_seq_lengths (line 231) | def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[l... class DataCollatorForVisionLanguageModeling (line 259) | class DataCollatorForVisionLanguageModeling(DataCollatorMixin): method torch_call (line 344) | def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: method _collate_language_modeling (line 356) | def _collate_language_modeling(self, examples: list[dict[str, Any]]) -... method _collate_prompt_completion (line 395) | def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -... function dft_loss (line 494) | def dft_loss(outputs, labels, num_items_in_batch=None): class SFTTrainer (line 511) | class SFTTrainer(_BaseTrainer): method __init__ (line 609) | def __init__( method _prepare_dataset (line 973) | def _prepare_dataset( method _set_signature_columns_if_needed (line 1198) | def _set_signature_columns_if_needed(self): method compute_loss (line 1209) | def compute_loss(self, model, inputs, return_outputs=False, num_items_... method prediction_step (line 1339) | def prediction_step(self, model, inputs, prediction_loss_only, ignore_... method training_step (line 1345) | def training_step(self, *args, **kwargs): method log (line 1349) | def log(self, logs: dict[str, float], start_time: float | None = None)... method _save_checkpoint (line 1363) | def _save_checkpoint(self, model, trial): FILE: trl/trainer/utils.py function _is_port_free (line 73) | def _is_port_free(port: int, host: str = "127.0.0.1") -> bool: function _find_free_port (line 83) | def _find_free_port() -> int: function ensure_master_addr_port (line 93) | def ensure_master_addr_port(addr: str | None = None, port: int | None = ... function pad (line 114) | def pad( function disable_dropout_in_model (line 180) | def disable_dropout_in_model(model: torch.nn.Module) -> None: function get_quantization_config (line 186) | def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConf... function get_kbit_device_map (line 205) | def get_kbit_device_map() -> dict[str, int] | None: function get_peft_config (line 212) | def get_peft_config(model_args: ModelConfig) -> "PeftConfig | None": function prepare_deepspeed (line 238) | def prepare_deepspeed( function generate_model_card (line 296) | def generate_model_card( function get_comet_experiment_url (line 378) | def get_comet_experiment_url() -> str | None: function get_trackio_space_url (line 391) | def get_trackio_space_url() -> str | None: function log_table_to_comet_experiment (line 412) | def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: function flush_left (line 430) | def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tens... function flush_right (line 495) | def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Ten... function selective_log_softmax (line 525) | def selective_log_softmax(logits, index) -> torch.Tensor: function entropy_from_logits (line 572) | def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> ... function print_prompt_completions_sample (line 609) | def print_prompt_completions_sample( class RepeatSampler (line 723) | class RepeatSampler(Sampler): method __init__ (line 772) | def __init__( method __iter__ (line 794) | def __iter__(self): method __len__ (line 815) | def __len__(self) -> int: function nanstd (line 820) | def nanstd(tensor: torch.Tensor, dim: int | tuple[int, ...] | None = Non... function split_tensor_dict (line 856) | def split_tensor_dict( function shuffle_sequence_dict (line 891) | def shuffle_sequence_dict(seq_dict: dict[str, Sequence | None]) -> dict[... function nanmin (line 923) | def nanmin(tensor: torch.Tensor) -> torch.Tensor: function nanmax (line 938) | def nanmax(tensor: torch.Tensor) -> torch.Tensor: function identity (line 953) | def identity(x): function split_pixel_values_by_grid (line 958) | def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[s... function unsplit_pixel_values_by_grid (line 979) | def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[to... function remove_none_values (line 1000) | def remove_none_values(example: TListOrMapping) -> TListOrMapping: function create_model_from_path (line 1032) | def create_model_from_path( function hash_module (line 1071) | def hash_module(module: torch.nn.Module) -> str: function get_config_model_id (line 1082) | def get_config_model_id(config: PretrainedConfig) -> str: class CausalLMOutputWithPastAndFlatLogits (line 1098) | class CausalLMOutputWithPastAndFlatLogits(CausalLMOutputWithPast): function forward_masked_logits (line 1102) | def forward_masked_logits( function use_adapter (line 1158) | def use_adapter(model: "PeftModel", adapter_name: str | None): function start_event_loop_in_daemon (line 1197) | def start_event_loop_in_daemon( function shutdown_event_loop_in_daemon (line 1228) | def shutdown_event_loop_in_daemon(