SYMBOL INDEX (2153 symbols across 176 files) FILE: benchmarks/big_model_inference/big_model_inference.py function parse_args (line 43) | def parse_args(): function main (line 79) | def main(): FILE: benchmarks/big_model_inference/measures_util.py class PeakCPUMemory (line 28) | class PeakCPUMemory: method __init__ (line 29) | def __init__(self): method peak_monitor (line 33) | def peak_monitor(self): method start (line 43) | def start(self): method stop (line 49) | def stop(self): function start_measure (line 58) | def start_measure(): function end_measure (line 77) | def end_measure(start_measures): function log_measures (line 96) | def log_measures(measures, description): FILE: benchmarks/fp8/ms_amp/ddp.py function train_baseline (line 36) | def train_baseline(opt_level="O2"): function train_integration (line 75) | def train_integration(opt_level="O2"): FILE: benchmarks/fp8/ms_amp/distrib_deepspeed.py function train_baseline (line 38) | def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): function train_integration (line 103) | def train_integration(zero_stage: int = 1, opt_level: str = "O1"): FILE: benchmarks/fp8/ms_amp/fp8_utils.py function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16): function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele... function get_named_parameters (line 94) | def get_named_parameters(model): function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None): FILE: benchmarks/fp8/ms_amp/non_distributed.py function train_baseline (line 35) | def train_baseline(opt_level="O2"): function train_integration (line 69) | def train_integration(opt_level="O2"): FILE: benchmarks/fp8/torchao/ddp.py function evaluate_model (line 38) | def evaluate_model(model, dataloader, metric, accelerator=None): function filter_linear_layers (line 52) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_... function train_baseline (line 63) | def train_baseline(): function train_integration (line 109) | def train_integration(): FILE: benchmarks/fp8/torchao/distrib_deepspeed.py function filter_linear_layers (line 40) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_... function train_baseline (line 51) | def train_baseline(zero_stage: int = 1): function train_integration (line 140) | def train_integration(zero_stage: int = 1): FILE: benchmarks/fp8/torchao/fp8_utils.py function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16): function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele... function get_named_parameters (line 94) | def get_named_parameters(model): function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None): FILE: benchmarks/fp8/torchao/fsdp.py function filter_linear_layers (line 44) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_... function evaluate_model (line 55) | def evaluate_model(model, dataloader, metric, accelerator=None): function train_baseline (line 69) | def train_baseline(): function train_integration (line 119) | def train_integration(): FILE: benchmarks/fp8/torchao/non_distributed.py function evaluate_model (line 37) | def evaluate_model(model, dataloader, metric, accelerator=None): function filter_linear_layers (line 51) | def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_... function train_baseline (line 62) | def train_baseline(): function train_integration (line 102) | def train_integration(): FILE: benchmarks/fp8/transformer_engine/ddp.py function train_baseline (line 39) | def train_baseline(): function train_integration (line 92) | def train_integration(): FILE: benchmarks/fp8/transformer_engine/distrib_deepspeed.py function train_baseline (line 41) | def train_baseline(zero_stage: int = 1): function train_integration (line 126) | def train_integration(zero_stage: int = 1): FILE: benchmarks/fp8/transformer_engine/fp8_utils.py function get_dataloaders (line 17) | def get_dataloaders(model_name: str, batch_size: int = 16): function get_training_utilities (line 65) | def get_training_utilities(model_name: str, batch_size: int = 16, accele... function get_named_parameters (line 94) | def get_named_parameters(model): function evaluate_model (line 105) | def evaluate_model(model, dataloader, metric, accelerator=None): FILE: benchmarks/fp8/transformer_engine/fsdp.py function train_baseline (line 47) | def train_baseline(): function train_integration (line 104) | def train_integration(): FILE: benchmarks/fp8/transformer_engine/non_distributed.py function train_baseline (line 38) | def train_baseline(): function train_integration (line 83) | def train_integration(): FILE: benchmarks/fsdp2/main.py function train (line 33) | def train( function evaluate (line 52) | def evaluate(args, config: dict, init_fn: Callable, run_name: str) -> to... function main (line 67) | def main(): FILE: benchmarks/fsdp2/measure_utils.py class MemoryTracker (line 27) | class MemoryTracker: method __init__ (line 28) | def __init__( method _monitor (line 69) | def _monitor(self): method start (line 84) | def start(self): method stop (line 99) | def stop(self): method peak_allocated_memory (line 125) | def peak_allocated_memory(self): method peak_reserved_memory (line 129) | def peak_reserved_memory(self): FILE: benchmarks/fsdp2/utils.py function get_named_parameters (line 36) | def get_named_parameters(model: torch.nn.Module, drop_refs: bool = False... function replace_optimizer_params (line 59) | def replace_optimizer_params(optimizer: torch.optim.Optimizer): function swap_back_optimizer_params (line 81) | def swap_back_optimizer_params( function parse_args (line 106) | def parse_args(): function prepare_dataloader (line 143) | def prepare_dataloader(tokenizer, args, accelerator: Accelerator) -> Dat... function get_model (line 191) | def get_model(model_name: str): function get_tokenizer (line 198) | def get_tokenizer(model_name: str): function prepare_torch (line 204) | def prepare_torch( function prepare_accelerate (line 262) | def prepare_accelerate( FILE: benchmarks/fsdp2/visualize.py function parse_args (line 21) | def parse_args(): function filter_data (line 39) | def filter_data(data, memory_threshold, filter_partition, key): function compare_memory_usage (line 54) | def compare_memory_usage(data, labels, memory_threshold, filter_partition): FILE: examples/alst_ulysses_sequence_parallelism/sp-alst.py function convert (line 62) | def convert(ex): function collate_fn (line 69) | def collate_fn(batch): function collate_fn (line 92) | def collate_fn(batch): FILE: examples/by_feature/automatic_gradient_accumulation.py function get_dataloaders (line 54) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 120) | def training_function(config, args): function main (line 223) | def main(): FILE: examples/by_feature/checkpointing.py function get_dataloaders (line 55) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 124) | def training_function(config, args): function main (line 285) | def main(): FILE: examples/by_feature/cross_validation.py function get_fold_dataloaders (line 62) | def get_fold_dataloaders( function training_function (line 139) | def training_function(config, args): function main (line 260) | def main(): FILE: examples/by_feature/ddp_comm_hook.py function get_dataloaders (line 50) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 119) | def training_function(config, args): function main (line 198) | def main(): FILE: examples/by_feature/deepspeed_with_config_support.py function parse_args (line 65) | def parse_args(): function evaluate (line 245) | def evaluate(args, model, eval_dataloader, accelerator, eval_dataset): function main (line 264) | def main(): FILE: examples/by_feature/early_stopping.py function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): class EarlyStoppingCallback (line 116) | class EarlyStoppingCallback: method __init__ (line 119) | def __init__(self, min_delta=0, patience=5): method check_early_stopping (line 125) | def check_early_stopping(self, eval_loss): function training_function (line 140) | def training_function(config, args): function main (line 228) | def main(): FILE: examples/by_feature/fsdp_with_peak_mem_tracking.py function b2mb (line 62) | def b2mb(x): class TorchTracemalloc (line 68) | class TorchTracemalloc: method __enter__ (line 69) | def __enter__(self): method cpu_mem_used (line 92) | def cpu_mem_used(self): method peak_monitor_func (line 96) | def peak_monitor_func(self): method __exit__ (line 108) | def __exit__(self, *exc): function training_function (line 140) | def training_function(config, args): function main (line 405) | def main(): FILE: examples/by_feature/gradient_accumulation.py function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 118) | def training_function(config, args): function main (line 203) | def main(): FILE: examples/by_feature/gradient_accumulation_for_autoregressive_models.py function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, max_... function training_function (line 144) | def training_function(config, args): function main (line 306) | def main(): FILE: examples/by_feature/local_sgd.py function get_dataloaders (line 52) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 121) | def training_function(config, args): function main (line 208) | def main(): FILE: examples/by_feature/megatron_lm_gpt_pretraining.py function parse_args (line 69) | def parse_args(): function main (line 247) | def main(): FILE: examples/by_feature/memory.py function get_dataloaders (line 55) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 124) | def training_function(config, args): function main (line 216) | def main(): FILE: examples/by_feature/multi_process_metrics.py function get_dataloaders (line 56) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 125) | def training_function(config, args): function main (line 220) | def main(): FILE: examples/by_feature/profiler.py function get_dataloaders (line 50) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 119) | def training_function(config, args): function main (line 210) | def main(): FILE: examples/by_feature/schedule_free.py function get_dataloaders (line 57) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 132) | def training_function(config, args): function main (line 208) | def main(): FILE: examples/by_feature/tracking.py function get_dataloaders (line 54) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 123) | def training_function(config, args): function main (line 242) | def main(): FILE: examples/complete_cv_example.py function extract_label (line 48) | def extract_label(fname): class PetsDataset (line 53) | class PetsDataset(Dataset): method __init__ (line 54) | def __init__(self, file_names, image_transform=None, label_to_id=None): method __len__ (line 59) | def __len__(self): method __getitem__ (line 62) | def __getitem__(self, idx): function training_function (line 74) | def training_function(config, args): function main (line 279) | def main(): FILE: examples/complete_nlp_example.py function training_function (line 50) | def training_function(config, args): function main (line 272) | def main(): FILE: examples/cv_example.py function extract_label (line 48) | def extract_label(fname): class PetsDataset (line 53) | class PetsDataset(Dataset): method __init__ (line 54) | def __init__(self, file_names, image_transform=None, label_to_id=None): method __len__ (line 59) | def __len__(self): method __getitem__ (line 62) | def __getitem__(self, idx): function training_function (line 74) | def training_function(config, args): function main (line 184) | def main(): FILE: examples/finetune_lm_tpu.py function format_dolly (line 36) | def format_dolly(example, tokenizer): function train (line 54) | def train(model_id, dataset): FILE: examples/inference/distributed/distributed_image_generation.py function get_batches (line 44) | def get_batches(items, batch_size): function main (line 57) | def main( FILE: examples/inference/distributed/distributed_speech_generation.py function load_pokemon_data (line 51) | def load_pokemon_data(split: str, max_text_length: int): class ExistsFilter (line 69) | class ExistsFilter: method __init__ (line 70) | def __init__(self, output_dir: Union[pathlib.Path, str]): method __call__ (line 75) | def __call__(self, x): function preprocess_fn (line 79) | def preprocess_fn(sample, tokenizer, max_text_length: int): function collate_fn (line 91) | def collate_fn(examples, tokenizer): function create_dataloader (line 129) | def create_dataloader(dataset, batch_size, distributed_state, tokenizer): function save_results (line 152) | def save_results(output_queue: queue.Queue, output_dir: pathlib.Path, sa... function main (line 181) | def main( FILE: examples/inference/distributed/florence2.py function main (line 44) | def main( FILE: examples/inference/distributed/llava_next_video.py function save_results (line 44) | def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): function get_batches (line 64) | def get_batches(processed_videos, batch_size): function read_video_pyav (line 77) | def read_video_pyav(container, indices): function get_video_paths (line 98) | def get_video_paths(video_dir): function process_videos (line 111) | def process_videos(video_paths, processor, prompt, frames_per_video): function main (line 138) | def main( FILE: examples/multigpu_remote_launcher.py function launch_train (line 23) | def launch_train(*args): FILE: examples/nlp_example.py function get_dataloaders (line 47) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16): function training_function (line 113) | def training_function(config, args): function main (line 192) | def main(): FILE: examples/torch_native_parallelism/fsdp2_fp8.py function parse_args (line 37) | def parse_args(): function main (line 48) | def main(): FILE: examples/torch_native_parallelism/nd_parallel.py function parse_args (line 42) | def parse_args(): function forward (line 57) | def forward(model, batch, optimizer, accelerator: Accelerator): function train (line 83) | def train(args): FILE: examples/torch_native_parallelism/nd_parallel_trainer.py function parse_args (line 26) | def parse_args(): function main (line 36) | def main(): FILE: examples/torch_native_parallelism/utils.py function get_dataset (line 29) | def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Acc... function get_model_flops_per_token (line 94) | def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int)... function create_collate_fn (line 118) | def create_collate_fn(): class PerformanceTracker (line 129) | class PerformanceTracker: method __init__ (line 132) | def __init__(self, warmup_steps: int = 10): method reset (line 136) | def reset(self): method step (line 143) | def step(self, batch_tokens: int, model_flops_per_token: float | None ... method get_print_message (line 185) | def get_print_message(self, metrics: dict, with_memory: bool = False) ... function setup_tokenizer (line 196) | def setup_tokenizer(model_id: str) -> AutoTokenizer: function gpu_memory_usage_all (line 204) | def gpu_memory_usage_all(device=0): FILE: manim_animations/big_model_inference/stage_1.py class Stage1 (line 18) | class Stage1(Scene): method construct (line 19) | def construct(self): FILE: manim_animations/big_model_inference/stage_2.py class Stage2 (line 17) | class Stage2(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/big_model_inference/stage_3.py class Stage3 (line 17) | class Stage3(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/big_model_inference/stage_4.py class Stage4 (line 17) | class Stage4(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/big_model_inference/stage_5.py class Stage5 (line 17) | class Stage5(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/dataloaders/stage_0.py class Stage0 (line 18) | class Stage0(Scene): method construct (line 19) | def construct(self): FILE: manim_animations/dataloaders/stage_1.py class Stage01 (line 17) | class Stage01(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/dataloaders/stage_2.py class Stage2 (line 18) | class Stage2(Scene): method construct (line 19) | def construct(self): FILE: manim_animations/dataloaders/stage_3.py class Stage3 (line 17) | class Stage3(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/dataloaders/stage_4.py class Stage4 (line 17) | class Stage4(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/dataloaders/stage_5.py class Stage5 (line 17) | class Stage5(Scene): method construct (line 18) | def construct(self): FILE: manim_animations/dataloaders/stage_6.py class Stage6 (line 18) | class Stage6(Scene): method construct (line 19) | def construct(self): FILE: manim_animations/dataloaders/stage_7.py class Stage7 (line 17) | class Stage7(Scene): method construct (line 18) | def construct(self): FILE: src/accelerate/accelerator.py class Accelerator (line 184) | class Accelerator: method __init__ (line 279) | def __init__( method deepspeed_plugin (line 639) | def deepspeed_plugin(self): method use_distributed (line 651) | def use_distributed(self): method multi_device (line 658) | def multi_device(self): method distributed_type (line 671) | def distributed_type(self): method num_processes (line 675) | def num_processes(self): method process_index (line 679) | def process_index(self): method local_process_index (line 683) | def local_process_index(self): method device (line 687) | def device(self): method split_batches (line 691) | def split_batches(self): method dispatch_batches (line 695) | def dispatch_batches(self): method even_batches (line 699) | def even_batches(self): method even_batches (line 703) | def even_batches(self, value: bool): method use_seedable_sampler (line 707) | def use_seedable_sampler(self): method non_blocking (line 711) | def non_blocking(self): method use_stateful_dataloader (line 715) | def use_stateful_dataloader(self): method project_dir (line 721) | def project_dir(self): method logging_dir (line 725) | def logging_dir(self): method save_iteration (line 729) | def save_iteration(self): method is_main_process (line 733) | def is_main_process(self): method is_local_main_process (line 738) | def is_local_main_process(self): method is_last_process (line 743) | def is_last_process(self): method mixed_precision (line 747) | def mixed_precision(self): method is_fsdp2 (line 751) | def is_fsdp2(self): method is_composable_parallelism_enabled (line 755) | def is_composable_parallelism_enabled(self): method parallelism_config (line 759) | def parallelism_config(self) -> Union[ParallelismConfig, None]: method torch_device_mesh (line 763) | def torch_device_mesh(self): method should_save_model (line 767) | def should_save_model(self): method tensor_parallel_rank (line 783) | def tensor_parallel_rank(self) -> int: method pipeline_parallel_rank (line 795) | def pipeline_parallel_rank(self) -> int: method context_parallel_rank (line 802) | def context_parallel_rank(self) -> int: method data_parallel_rank (line 809) | def data_parallel_rank(self) -> int: method data_parallel_shard_rank (line 821) | def data_parallel_shard_rank(self) -> int: method split_between_processes (line 833) | def split_between_processes(self, inputs: list | tuple | dict | torch.... method on_main_process (line 874) | def on_main_process(self, function: Callable[..., Any] | None = None): method on_local_main_process (line 913) | def on_local_main_process(self, function: Callable[..., Any] | None = ... method on_last_process (line 955) | def on_last_process(self, function: Callable[..., Any]): method on_process (line 994) | def on_process(self, function: Callable[..., Any] | None = None, proce... method on_local_process (line 1039) | def on_local_process(self, function: Callable[..., Any] | None = None,... method main_process_first (line 1088) | def main_process_first(self): method local_main_process_first (line 1110) | def local_main_process_first(self): method no_sync (line 1132) | def no_sync(self, model): method trigger_sync_in_backward (line 1182) | def trigger_sync_in_backward(model): method _do_sync (line 1229) | def _do_sync(self): method sync_gradients (line 1239) | def sync_gradients(self): method sync_gradients (line 1243) | def sync_gradients(self, sync_gradients): method gradient_accumulation_steps (line 1247) | def gradient_accumulation_steps(self): method gradient_accumulation_steps (line 1251) | def gradient_accumulation_steps(self, gradient_accumulation_steps): method accumulate (line 1255) | def accumulate(self, *models): method join_uneven_inputs (line 1300) | def join_uneven_inputs(self, joinables, even_batches=None): method print (line 1382) | def print(self, *args, **kwargs): method _prepare_one (line 1397) | def _prepare_one(self, obj, first_pass=False, device_placement=None): method prepare (line 1414) | def prepare(self, *args, device_placement=None): method _prepare_tp (line 1580) | def _prepare_tp(self, *args): method _prepare_cp (line 1658) | def _prepare_cp(self, *args): method _prepare_fsdp2 (line 1673) | def _prepare_fsdp2(self, *args): method prepare_model (line 1765) | def prepare_model( method _prepare_ao (line 2059) | def _prepare_ao(self, *args): method _prepare_te (line 2087) | def _prepare_te(self, *args): method _prepare_deepspeed (line 2123) | def _prepare_deepspeed(self, *args): method deepspeed_ulysses_dl_adapter (line 2475) | def deepspeed_ulysses_dl_adapter(self, dl, model): method _prepare_megatron_lm (line 2495) | def _prepare_megatron_lm(self, *args): method _prepare_device_mesh (line 2598) | def _prepare_device_mesh(self): method _prepare_msamp (line 2608) | def _prepare_msamp(self, *args, device_placement): method prepare_data_loader (line 2663) | def prepare_data_loader( method prepare_optimizer (line 2722) | def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_p... method prepare_scheduler (line 2766) | def prepare_scheduler(self, scheduler: LRScheduler): method backward (line 2807) | def backward(self, loss, **kwargs): method set_trigger (line 2841) | def set_trigger(self): method check_trigger (line 2867) | def check_trigger(self): method unscale_gradients (line 2900) | def unscale_gradients(self, optimizer=None): method clip_grad_norm_ (line 2935) | def clip_grad_norm_(self, parameters, max_norm, norm_type=2): method clip_grad_value_ (line 2998) | def clip_grad_value_(self, parameters, clip_value): method gather (line 3025) | def gather(self, tensor): method gather_for_metrics (line 3057) | def gather_for_metrics(self, input_data, use_gather_object=False): method reduce (line 3130) | def reduce(self, tensor, reduction="sum", scale=1.0): method pad_across_processes (line 3166) | def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=F... method unwrap_model (line 3201) | def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_tor... method wait_for_everyone (line 3235) | def wait_for_everyone(self): method init_trackers (line 3260) | def init_trackers(self, project_name: str, config: dict | None = None,... method get_tracker (line 3310) | def get_tracker(self, name: str, unwrap: bool = False): method log (line 3343) | def log(self, values: dict, step: int | None = None, log_kwargs: dict ... method end_training (line 3372) | def end_training(self): method save (line 3393) | def save(self, obj, f, safe_serialization=False): method save_model (line 3423) | def save_model( method register_save_state_pre_hook (line 3536) | def register_save_state_pre_hook(self, hook: Callable[..., None]) -> h... method save_state (line 3568) | def save_state(self, output_dir: str | None = None, safe_serialization... method register_load_state_pre_hook (line 3703) | def register_load_state_pre_hook(self, hook: Callable[..., None]) -> h... method load_state (line 3734) | def load_state(self, input_dir: str | None = None, load_kwargs: dict |... method free_memory (line 3886) | def free_memory(self, *objects): method clear (line 3915) | def clear(self, *objects): method _get_named_parameters (line 3933) | def _get_named_parameters(self, *args, drop_refs=False): method _get_devices (line 3969) | def _get_devices(self, *args): method get_state_dict (line 3986) | def get_state_dict(self, model, unwrap=True): method register_for_checkpointing (line 4058) | def register_for_checkpointing(self, *objects): method maybe_context_parallel (line 4095) | def maybe_context_parallel( method autocast (line 4162) | def autocast(self, autocast_handler: AutocastKwargs = None): method profile (line 4187) | def profile(self, profile_handler: ProfileKwargs | None = None): method optimizer_step_was_skipped (line 4247) | def optimizer_step_was_skipped(self): method skip_first_batches (line 4257) | def skip_first_batches(self, dataloader, num_batches: int = 0): method __deepcopy__ (line 4289) | def __deepcopy__(self, memo): method verify_device_map (line 4293) | def verify_device_map(self, model: torch.nn.Module) -> bool: method lomo_backward (line 4304) | def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> N... method fp8_backend (line 4329) | def fp8_backend(self) -> FP8BackendType: FILE: src/accelerate/big_modeling.py function init_empty_weights (line 62) | def init_empty_weights(include_buffers: Optional[bool] = None): function init_on_device (line 98) | def init_on_device(device: torch.device, include_buffers: Optional[bool]... function cpu_offload (line 179) | def cpu_offload( function cpu_offload_with_hook (line 225) | def cpu_offload_with_hook( function disk_offload (line 269) | def disk_offload( function dispatch_model (line 315) | def dispatch_model( function load_checkpoint_and_dispatch (line 522) | def load_checkpoint_and_dispatch( function attach_layerwise_casting_hooks (line 663) | def attach_layerwise_casting_hooks( function _attach_layerwise_casting_hooks (line 724) | def _attach_layerwise_casting_hooks( function _attach_context_parallel_hooks (line 762) | def _attach_context_parallel_hooks( FILE: src/accelerate/checkpointing.py function save_accelerator_state (line 63) | def save_accelerator_state( function load_accelerator_state (line 183) | def load_accelerator_state( function save_custom_state (line 321) | def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool... function load_custom_state (line 331) | def load_custom_state(obj, path, index: int = 0): FILE: src/accelerate/commands/accelerate_cli.py function main (line 28) | def main(): FILE: src/accelerate/commands/config/__init__.py function get_config_parser (line 25) | def get_config_parser(subparsers=None): function main (line 39) | def main(): FILE: src/accelerate/commands/config/cluster.py function get_cluster_input (line 59) | def get_cluster_input(): FILE: src/accelerate/commands/config/config.py function get_user_input (line 31) | def get_user_input(): function config_command_parser (line 44) | def config_command_parser(subparsers=None): function config_command (line 66) | def config_command(args): function main (line 82) | def main(): FILE: src/accelerate/commands/config/config_args.py function load_config_from_file (line 43) | def load_config_from_file(config_file): class BaseConfig (line 76) | class BaseConfig: method to_dict (line 83) | def to_dict(self): method process_config (line 103) | def process_config(config_dict): method from_json_file (line 129) | def from_json_file(cls, json_file=None): method to_json_file (line 143) | def to_json_file(self, json_file): method from_yaml_file (line 149) | def from_yaml_file(cls, yaml_file=None): method to_yaml_file (line 162) | def to_yaml_file(self, yaml_file): method __post_init__ (line 166) | def __post_init__(self): class ClusterConfig (line 179) | class ClusterConfig(BaseConfig): method __post_init__ (line 219) | def __post_init__(self): class SageMakerConfig (line 236) | class SageMakerConfig(BaseConfig): FILE: src/accelerate/commands/config/config_utils.py function _ask_field (line 47) | def _ask_field(input_text, convert_value=None, default=None, error_messa... function _ask_options (line 60) | def _ask_options(input_text, options=[], convert_value=None, default=0): function _convert_compute_environment (line 66) | def _convert_compute_environment(value): function _convert_distributed_mode (line 71) | def _convert_distributed_mode(value): function _convert_dynamo_backend (line 90) | def _convert_dynamo_backend(value): function _convert_mixed_precision (line 95) | def _convert_mixed_precision(value): function _convert_sagemaker_distributed_mode (line 100) | def _convert_sagemaker_distributed_mode(value): function _convert_fp8_backend (line 105) | def _convert_fp8_backend(value): function _convert_yes_no_to_bool (line 110) | def _convert_yes_no_to_bool(value): class SubcommandHelpFormatter (line 114) | class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter): method _format_usage (line 119) | def _format_usage(self, usage, actions, groups, prefix): FILE: src/accelerate/commands/config/default.py function write_basic_config (line 37) | def write_basic_config(mixed_precision="no", save_location: str = defaul... function default_command_parser (line 142) | def default_command_parser(parser, parents): function default_config_command (line 169) | def default_config_command(args): FILE: src/accelerate/commands/config/sagemaker.py function _create_iam_role_for_sagemaker (line 38) | def _create_iam_role_for_sagemaker(role_name): function _get_iam_role_arn (line 92) | def _get_iam_role_arn(role_name): function get_sagemaker_input (line 97) | def get_sagemaker_input(): FILE: src/accelerate/commands/config/update.py function update_config (line 26) | def update_config(args): function update_command_parser (line 44) | def update_command_parser(parser, parents): function update_config_command (line 61) | def update_config_command(args): FILE: src/accelerate/commands/env.py function env_command_parser (line 39) | def env_command_parser(subparsers=None): function env_command (line 54) | def env_command(args): function main (line 135) | def main() -> int: FILE: src/accelerate/commands/estimate.py function verify_on_hub (line 40) | def verify_on_hub(repo: str, token: Optional[str] = None): function check_has_model (line 50) | def check_has_model(error): function create_empty_model (line 66) | def create_empty_model( function create_ascii_table (line 146) | def create_ascii_table(headers: list, rows: list, title: str): function estimate_command_parser (line 187) | def estimate_command_parser(subparsers=None): function estimate_training_usage (line 224) | def estimate_training_usage(bytes: int, mixed_precision: str, msamp_conf... function gather_data (line 259) | def gather_data(args): function estimate_command (line 294) | def estimate_command(args): function main (line 311) | def main(): FILE: src/accelerate/commands/launch.py function clean_option (line 83) | def clean_option(option): class CustomHelpFormatter (line 91) | class CustomHelpFormatter(argparse.HelpFormatter): method __init__ (line 98) | def __init__(self, *args, **kwargs): method add_argument (line 108) | def add_argument(self, action: argparse.Action): method end_section (line 134) | def end_section(self): function launch_command_parser (line 141) | def launch_command_parser(subparsers=None): function simple_launcher (line 986) | def simple_launcher(args): function multi_gpu_launcher (line 998) | def multi_gpu_launcher(args): function deepspeed_launcher (line 1033) | def deepspeed_launcher(args): function tpu_launcher (line 1086) | def tpu_launcher(args): function tpu_pod_launcher (line 1117) | def tpu_pod_launcher(args): function sagemaker_launcher (line 1176) | def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): function _validate_launch_command (line 1196) | def _validate_launch_command(args): function launch_command (line 1382) | def launch_command(args): function main (line 1408) | def main(): FILE: src/accelerate/commands/menu/cursor.py class CursorInfo (line 29) | class CursorInfo(ctypes.Structure): function hide_cursor (line 34) | def hide_cursor(): function show_cursor (line 46) | def show_cursor(): function hide (line 59) | def hide(): FILE: src/accelerate/commands/menu/helpers.py class Direction (line 30) | class Direction(enum.Enum): function forceWrite (line 35) | def forceWrite(content, end=""): function writeColor (line 40) | def writeColor(content, color, end=""): function reset_cursor (line 44) | def reset_cursor(): function move_cursor (line 48) | def move_cursor(num_lines: int, direction: str): function clear_line (line 52) | def clear_line(): function linebreak (line 57) | def linebreak(): FILE: src/accelerate/commands/menu/input.py function mark (line 23) | def mark(key: str): function mark_multiple (line 37) | def mark_multiple(*keys: list[str]): class KeyHandler (line 51) | class KeyHandler(type): method __new__ (line 56) | def __new__(cls, name, bases, attrs): method handle_input (line 69) | def handle_input(cls): function register (line 82) | def register(cls): FILE: src/accelerate/commands/menu/keymap.py function get_raw_chars (line 63) | def get_raw_chars(): function get_character (line 112) | def get_character(): FILE: src/accelerate/commands/menu/selection_menu.py class BulletMenu (line 37) | class BulletMenu: method __init__ (line 42) | def __init__(self, prompt: Optional[str] = None, choices: list = []): method write_choice (line 51) | def write_choice(self, index, end: str = ""): method print_choice (line 57) | def print_choice(self, index: int): method move_direction (line 66) | def move_direction(self, direction: Direction, num_spaces: int = 1): method move_up (line 83) | def move_up(self): method move_down (line 87) | def move_down(self): method select (line 91) | def select(self): method interrupt (line 96) | def interrupt(self): method select_row (line 101) | def select_row(self): method run (line 116) | def run(self, default_choice: int = 0): FILE: src/accelerate/commands/merge.py function merge_command (line 26) | def merge_command(args): function merge_command_parser (line 32) | def merge_command_parser(subparsers=None): function main (line 62) | def main(): FILE: src/accelerate/commands/test.py function test_command_parser (line 22) | def test_command_parser(subparsers=None): function test_command (line 44) | def test_command(args): function main (line 58) | def main(): FILE: src/accelerate/commands/to_fsdp2.py class ConversionStatus (line 26) | class ConversionStatus(enum.Enum): function _validate_to_fsdp2_args (line 71) | def _validate_to_fsdp2_args(args): function convert_config_to_fsdp2 (line 82) | def convert_config_to_fsdp2(config: dict) -> dict: function to_fsdp2_command_parser (line 126) | def to_fsdp2_command_parser(subparsers=None): function load_config (line 153) | def load_config(config_file: str) -> dict: function to_fsdp2_command (line 162) | def to_fsdp2_command(args): FILE: src/accelerate/commands/tpu.py function tpu_command_parser (line 29) | def tpu_command_parser(subparsers=None): function tpu_command_launcher (line 90) | def tpu_command_launcher(args): function main (line 153) | def main(): FILE: src/accelerate/commands/utils.py class _StoreAction (line 18) | class _StoreAction(argparse.Action): method __init__ (line 23) | def __init__(self, *args, **kwargs): method __call__ (line 33) | def __call__(self, parser, namespace, values, option_string=None): class _StoreConstAction (line 40) | class _StoreConstAction(_StoreAction): method __init__ (line 45) | def __init__(self, option_strings, dest, const, default=None, required... method __call__ (line 56) | def __call__(self, parser, namespace, values, option_string=None): class _StoreTrueAction (line 60) | class _StoreTrueAction(_StoreConstAction): method __init__ (line 65) | def __init__( class CustomArgumentGroup (line 78) | class CustomArgumentGroup(argparse._ArgumentGroup): method _add_action (line 84) | def _add_action(self, action): class CustomArgumentParser (line 105) | class CustomArgumentParser(argparse.ArgumentParser): method add_argument (line 111) | def add_argument(self, *args, **kwargs): method add_argument_group (line 120) | def add_argument_group(self, *args, **kwargs): FILE: src/accelerate/data_loader.py class SeedableRandomSampler (line 73) | class SeedableRandomSampler(RandomSampler): method __init__ (line 84) | def __init__(self, *args, **kwargs): method __iter__ (line 91) | def __iter__(self): method set_epoch (line 105) | def set_epoch(self, epoch: int): class BatchSamplerShard (line 110) | class BatchSamplerShard(BatchSampler): method __init__ (line 145) | def __init__( method total_length (line 172) | def total_length(self): method __len__ (line 175) | def __len__(self): method __iter__ (line 193) | def __iter__(self): method _iter_with_split (line 196) | def _iter_with_split(self): method _iter_with_no_split (line 218) | def _iter_with_no_split(self): class IterableDatasetShard (line 266) | class IterableDatasetShard(IterableDataset): method __init__ (line 299) | def __init__( method set_epoch (line 320) | def set_epoch(self, epoch): method __len__ (line 325) | def __len__(self): method __iter__ (line 332) | def __iter__(self): class DataLoaderStateMixin (line 365) | class DataLoaderStateMixin: method __init_subclass__ (line 386) | def __init_subclass__(cls, **kwargs): method reset (line 390) | def reset(self): method begin (line 394) | def begin(self): method end (line 403) | def end(self): class DataLoaderAdapter (line 408) | class DataLoaderAdapter: method __init__ (line 414) | def __init__(self, dataset, use_stateful_dataloader=False, batch_sampl... method __getattr__ (line 438) | def __getattr__(self, name): method state_dict (line 445) | def state_dict(self): method load_state_dict (line 448) | def load_state_dict(self, state_dict): method __class__ (line 452) | def __class__(self): method __len__ (line 460) | def __len__(self): method adjust_state_dict_for_prefetch (line 463) | def adjust_state_dict_for_prefetch(self): method _update_state_dict (line 488) | def _update_state_dict(self): class DataLoaderShard (line 502) | class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin): method __init__ (line 537) | def __init__( method adjust_state_dict_for_prefetch (line 560) | def adjust_state_dict_for_prefetch(self): method __iter__ (line 568) | def __iter__(self): method __reduce__ (line 604) | def __reduce__(self): method set_epoch (line 613) | def set_epoch(self, epoch: int): method total_batch_size (line 633) | def total_batch_size(self): method total_dataset_length (line 642) | def total_dataset_length(self): method get_sampler (line 648) | def get_sampler(self): method set_sampler (line 651) | def set_sampler(self, sampler): class MpDeviceLoaderWrapper (line 664) | class MpDeviceLoaderWrapper(xpl.MpDeviceLoader): method __init__ (line 681) | def __init__(self, dataloader: DataLoaderShard, device: torch.device): method __iter__ (line 687) | def __iter__(self): method set_epoch (line 693) | def set_epoch(self, epoch: int): method total_batch_size (line 698) | def total_batch_size(self): method total_dataset_length (line 702) | def total_dataset_length(self): method batch_sampler (line 706) | def batch_sampler(self): method dataloader (line 710) | def dataloader(self): class DataLoaderDispatcher (line 714) | class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin): method __init__ (line 741) | def __init__( method _fetch_batches (line 796) | def _fetch_batches(self, iterator): method __iter__ (line 862) | def __iter__(self): method set_epoch (line 938) | def set_epoch(self, epoch: int): method __len__ (line 947) | def __len__(self): method __reduce__ (line 956) | def __reduce__(self): method total_batch_size (line 966) | def total_batch_size(self): method total_dataset_length (line 972) | def total_dataset_length(self): method get_sampler (line 975) | def get_sampler(self): method set_sampler (line 978) | def set_sampler(self, sampler): function get_sampler (line 988) | def get_sampler(dataloader): function prepare_data_loader (line 1006) | def prepare_data_loader( class SkipBatchSampler (line 1322) | class SkipBatchSampler(BatchSampler): method __init__ (line 1328) | def __init__(self, batch_sampler, skip_batches=0): method __iter__ (line 1332) | def __iter__(self): method total_length (line 1338) | def total_length(self): method __len__ (line 1341) | def __len__(self): class SkipDataLoader (line 1345) | class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin): method __init__ (line 1359) | def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=Fa... method __iter__ (line 1364) | def __iter__(self): method __len__ (line 1372) | def __len__(self): method __reduce__ (line 1375) | def __reduce__(self): function skip_first_batches (line 1385) | def skip_first_batches(dataloader, num_batches=0): FILE: src/accelerate/hooks.py function _compiler_disable (line 40) | def _compiler_disable(fn): class ModelHook (line 58) | class ModelHook: method init_hook (line 70) | def init_hook(self, module): method pre_forward (line 79) | def pre_forward(self, module, *args, **kwargs): method post_forward (line 93) | def post_forward(self, module, output): method detach_hook (line 106) | def detach_hook(self, module): class SequentialHook (line 116) | class SequentialHook(ModelHook): method __init__ (line 121) | def __init__(self, *hooks): method init_hook (line 124) | def init_hook(self, module): method pre_forward (line 130) | def pre_forward(self, module, *args, **kwargs): method post_forward (line 136) | def post_forward(self, module, output): method detach_hook (line 141) | def detach_hook(self, module): function add_hook_to_module (line 147) | def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool ... function remove_hook_from_module (line 205) | def remove_hook_from_module(module: nn.Module, recurse=False): class AlignDevicesHook (line 242) | class AlignDevicesHook(ModelHook): method __init__ (line 262) | def __init__( method __repr__ (line 291) | def __repr__(self): method init_hook (line 298) | def init_hook(self, module): method pre_forward (line 346) | def pre_forward(self, module, *args, **kwargs): method post_forward (line 392) | def post_forward(self, module, output): method detach_hook (line 423) | def detach_hook(self, module): function attach_execution_device_hook (line 431) | def attach_execution_device_hook( function attach_align_device_hook (line 479) | def attach_align_device_hook( function remove_hook_from_submodules (line 562) | def remove_hook_from_submodules(module: nn.Module): function attach_align_device_hook_on_blocks (line 574) | def attach_align_device_hook_on_blocks( class CpuOffload (line 708) | class CpuOffload(ModelHook): method __init__ (line 723) | def __init__( method init_hook (line 732) | def init_hook(self, module): method pre_forward (line 736) | def pre_forward(self, module, *args, **kwargs): class UserCpuOffloadHook (line 755) | class UserCpuOffloadHook: method __init__ (line 761) | def __init__(self, model, hook): method offload (line 765) | def offload(self): method remove (line 768) | def remove(self): class LayerwiseCastingHook (line 772) | class LayerwiseCastingHook(ModelHook): method __init__ (line 781) | def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dt... method init_hook (line 786) | def init_hook(self, module: torch.nn.Module): method pre_forward (line 791) | def pre_forward(self, module: torch.nn.Module, *args, **kwargs): method post_forward (line 796) | def post_forward(self, module: torch.nn.Module, output): FILE: src/accelerate/inference.py function generate_device_map (line 31) | def generate_device_map( function find_pippy_batch_size (line 60) | def find_pippy_batch_size(args, kwargs): function build_pipeline (line 75) | def build_pipeline(model, split_points, args, kwargs, num_chunks): function pippy_forward (line 101) | def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs): function prepare_pippy (line 126) | def prepare_pippy( FILE: src/accelerate/launchers.py function test_launch (line 36) | def test_launch(): function notebook_launcher (line 41) | def notebook_launcher( function debug_launcher (line 276) | def debug_launcher(function, args=(), num_processes=2): FILE: src/accelerate/local_sgd.py class LocalSGD (line 19) | class LocalSGD: method __enter__ (line 41) | def __enter__(self): method __exit__ (line 48) | def __exit__(self, type, value, tb): method __init__ (line 54) | def __init__(self, accelerator: Accelerator, model: torch.nn.Module, l... method step (line 88) | def step(self): method _sync_and_avg_model_params (line 99) | def _sync_and_avg_model_params(self): FILE: src/accelerate/logging.py class MultiProcessAdapter (line 23) | class MultiProcessAdapter(logging.LoggerAdapter): method _should_log (line 34) | def _should_log(main_process_only): method process (line 39) | def process(self, msg, kwargs): method log (line 49) | def log(self, level, msg, *args, **kwargs): method warning_once (line 82) | def warning_once(self, *args, **kwargs): function get_logger (line 93) | def get_logger(name: str, log_level: str | None = None): FILE: src/accelerate/optimizer.py function move_to_device (line 28) | def move_to_device(state, device): class AcceleratedOptimizer (line 38) | class AcceleratedOptimizer(torch.optim.Optimizer): method __init__ (line 55) | def __init__(self, optimizer, device_placement=True, scaler=None): method state (line 78) | def state(self): method state (line 82) | def state(self, state): method param_groups (line 86) | def param_groups(self): method param_groups (line 90) | def param_groups(self, param_groups): method defaults (line 94) | def defaults(self): method defaults (line 98) | def defaults(self, defaults): method add_param_group (line 101) | def add_param_group(self, param_group): method load_state_dict (line 104) | def load_state_dict(self, state_dict): method state_dict (line 109) | def state_dict(self): method zero_grad (line 112) | def zero_grad(self, set_to_none=None): method train (line 124) | def train(self): method eval (line 138) | def eval(self): method step (line 145) | def step(self, closure=None): method _switch_parameters (line 183) | def _switch_parameters(self, parameters_map): method step_was_skipped (line 188) | def step_was_skipped(self): method __getstate__ (line 192) | def __getstate__(self): method __setstate__ (line 200) | def __setstate__(self, state): function patch_optimizer_step (line 208) | def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, me... FILE: src/accelerate/parallelism_config.py class ParallelismConfig (line 34) | class ParallelismConfig: method __repr__ (line 85) | def __repr__(self): method to_json (line 100) | def to_json(self): method dp_dim_names (line 114) | def dp_dim_names(self): method non_dp_dim_names (line 124) | def non_dp_dim_names(self): method dp_shard_cp_dim_names (line 136) | def dp_shard_cp_dim_names(self): method dp_cp_dim_names (line 146) | def dp_cp_dim_names(self): method fsdp_dim_names (line 158) | def fsdp_dim_names(self): method total_size (line 167) | def total_size(self): method non_data_parallel_size (line 172) | def non_data_parallel_size(self): method data_parallel_size (line 177) | def data_parallel_size(self): method dp_replicate_enabled (line 182) | def dp_replicate_enabled(self): method dp_shard_enabled (line 187) | def dp_shard_enabled(self): method tp_enabled (line 192) | def tp_enabled(self): method cp_enabled (line 197) | def cp_enabled(self): method sp_enabled (line 202) | def sp_enabled(self): method active_mesh_dims (line 207) | def active_mesh_dims(self): method build_device_mesh (line 211) | def build_device_mesh(self, device_type: str): method get_device_mesh (line 246) | def get_device_mesh(self, device_type: Optional[str] = None): method _get_mesh (line 260) | def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]: method __post_init__ (line 274) | def __post_init__(self): method _set_size (line 350) | def _set_size(self, parallelism: str, size: int): method _validate_accelerator (line 355) | def _validate_accelerator(self, accelerator: "Accelerator"): FILE: src/accelerate/scheduler.py class AcceleratedScheduler (line 25) | class AcceleratedScheduler: method __init__ (line 47) | def __init__(self, scheduler, optimizers, step_with_optimizer: bool = ... method step (line 54) | def step(self, *args, **kwargs): method get_last_lr (line 85) | def get_last_lr(self): method state_dict (line 88) | def state_dict(self): method load_state_dict (line 91) | def load_state_dict(self, state_dict): method get_lr (line 94) | def get_lr(self): method print_lr (line 97) | def print_lr(self, *args, **kwargs): FILE: src/accelerate/state.py function is_initialized (line 78) | def is_initialized() -> bool: function do_nothing (line 87) | def do_nothing(*args, **kwargs): class ThreadLocalSharedDict (line 91) | class ThreadLocalSharedDict(threading.local): method __init__ (line 108) | def __init__(self, thread_local: bool = False): method __get__ (line 111) | def __get__(self, obj, objtype=None): method __set__ (line 114) | def __set__(self, obj, value): class PartialState (line 123) | class PartialState: method __init__ (line 177) | def __init__(self, cpu: bool = False, **kwargs): method __repr__ (line 330) | def __repr__(self) -> str: method _reset_state (line 340) | def _reset_state(): method initialized (line 345) | def initialized(self) -> bool: method use_distributed (line 350) | def use_distributed(self): method is_last_process (line 357) | def is_last_process(self) -> bool: method is_main_process (line 362) | def is_main_process(self) -> bool: method is_local_main_process (line 369) | def is_local_main_process(self) -> bool: method wait_for_everyone (line 377) | def wait_for_everyone(self): method _goes_first (line 416) | def _goes_first(self, is_main: bool): method split_between_processes (line 426) | def split_between_processes(self, inputs: list | tuple | dict | torch.... method main_process_first (line 517) | def main_process_first(self): method local_main_process_first (line 538) | def local_main_process_first(self): method on_main_process (line 558) | def on_main_process(self, function: Callable[..., Any] | None = None): method on_local_main_process (line 588) | def on_local_main_process(self, function: Callable[..., Any] | None = ... method on_last_process (line 619) | def on_last_process(self, function: Callable[..., Any]): method on_process (line 647) | def on_process(self, function: Callable[..., Any] | None = None, proce... method on_local_process (line 680) | def on_local_process(self, function: Callable[..., Any] | None = None,... method print (line 716) | def print(self, *args, **kwargs): method default_device (line 721) | def default_device(self) -> torch.device: method _prepare_backend (line 758) | def _prepare_backend( method set_device (line 822) | def set_device(self): method destroy_process_group (line 848) | def destroy_process_group(self, group=None): method __getattr__ (line 858) | def __getattr__(self, name: str): class AcceleratorState (line 871) | class AcceleratorState: method __init__ (line 902) | def __init__( method initialized (line 1041) | def initialized(self) -> bool: method __repr__ (line 1044) | def __repr__(self): method _check_initialized (line 1050) | def _check_initialized(self, mixed_precision=None, cpu=None): method mixed_precision (line 1064) | def mixed_precision(self): method _reset_state (line 1078) | def _reset_state(reset_partial_state: bool = False): method destroy_process_group (line 1084) | def destroy_process_group(self, group=None): method fork_launched (line 1093) | def fork_launched(self): method use_distributed (line 1097) | def use_distributed(self): method is_fsdp2 (line 1104) | def is_fsdp2(self) -> bool: method is_last_process (line 1108) | def is_last_process(self) -> bool: method is_main_process (line 1113) | def is_main_process(self) -> bool: method is_local_main_process (line 1118) | def is_local_main_process(self) -> bool: method wait_for_everyone (line 1122) | def wait_for_everyone(self): method split_between_processes (line 1126) | def split_between_processes(self, inputs: list | tuple | dict | torch.... method main_process_first (line 1168) | def main_process_first(self): method local_main_process_first (line 1178) | def local_main_process_first(self): method deepspeed_plugin (line 1188) | def deepspeed_plugin(self): method get_deepspeed_plugin (line 1202) | def get_deepspeed_plugin(self, name: str): method select_deepspeed_plugin (line 1209) | def select_deepspeed_plugin(self, name: str | None = None): method print (line 1218) | def print(self, *args, **kwargs): method __getattr__ (line 1221) | def __getattr__(self, name: str): class GradientState (line 1234) | class GradientState: method __init__ (line 1259) | def __init__(self, gradient_accumulation_plugin: GradientAccumulationP... method num_steps (line 1274) | def num_steps(self) -> int: method adjust_scheduler (line 1279) | def adjust_scheduler(self) -> bool: method sync_with_dataloader (line 1284) | def sync_with_dataloader(self) -> bool: method initialized (line 1289) | def initialized(self) -> bool: method end_of_dataloader (line 1294) | def end_of_dataloader(self) -> bool: method remainder (line 1301) | def remainder(self) -> int: method __repr__ (line 1307) | def __repr__(self): method is_xla_gradients_synced (line 1316) | def is_xla_gradients_synced(self): method is_xla_gradients_synced (line 1323) | def is_xla_gradients_synced(self, is_synced): method _set_sync_gradients (line 1327) | def _set_sync_gradients(self, sync_gradients): method _add_dataloader (line 1338) | def _add_dataloader(self, dataloader): method _remove_dataloader (line 1344) | def _remove_dataloader(self, dataloader): method active_dataloader (line 1352) | def active_dataloader(self): method dataloader_references (line 1356) | def dataloader_references(self): method dataloader_references (line 1361) | def dataloader_references(self, references): method in_dataloader (line 1367) | def in_dataloader(self) -> bool: method _reset_state (line 1372) | def _reset_state(): FILE: src/accelerate/test_utils/examples.py function get_function_contents_by_name (line 26) | def get_function_contents_by_name(lines: list[str], name: str): function clean_lines (line 52) | def clean_lines(lines: list[str]): function compare_against_test (line 63) | def compare_against_test( FILE: src/accelerate/test_utils/scripts/external_deps/test_checkpointing.py function get_dataloaders (line 33) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode... function evaluation_loop (line 78) | def evaluation_loop(accelerator, model, eval_dataloader, metric): function training_function (line 106) | def training_function(config, args): function main (line 229) | def main(): FILE: src/accelerate/test_utils/scripts/external_deps/test_ds_alst_ulysses_sp.py function collate_fn (line 61) | def collate_fn(batch): FILE: src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py class NoiseModel (line 40) | class NoiseModel(torch.nn.Module): method __init__ (line 41) | def __init__(self, noise_factor=0.1): method forward (line 45) | def forward(self, loss): function get_dataloaders (line 49) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode... function single_model_training (line 106) | def single_model_training(config, args): function multiple_model_training (line 189) | def multiple_model_training(config, args): function main (line 302) | def main(): FILE: src/accelerate/test_utils/scripts/external_deps/test_metrics.py class ListHandler (line 37) | class ListHandler(logging.Handler): method __init__ (line 38) | def __init__(self, *args, **kwargs): method emit (line 42) | def emit(self, record): function get_basic_setup (line 46) | def get_basic_setup(accelerator, num_samples=82, batch_size=16): function get_dataloader (line 58) | def get_dataloader(accelerator: Accelerator, use_longest=False): function get_mrpc_setup (line 83) | def get_mrpc_setup(dispatch_batches, split_batches): function generate_predictions (line 97) | def generate_predictions(model, dataloader, accelerator): function test_torch_metrics (line 113) | def test_torch_metrics( function test_mrpc (line 123) | def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False): function test_gather_for_metrics_with_non_tensor_objects_iterable_dataset (line 156) | def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset(): function test_gather_for_metrics_with_iterable_dataset (line 188) | def test_gather_for_metrics_with_iterable_dataset(): function test_gather_for_metrics_drop_last (line 224) | def test_gather_for_metrics_drop_last(): function main (line 243) | def main(): function _mp_fn (line 301) | def _mp_fn(index): FILE: src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py function b2mb (line 43) | def b2mb(x): class TorchTracemalloc (line 48) | class TorchTracemalloc: method __enter__ (line 49) | def __enter__(self): method __exit__ (line 85) | def __exit__(self, *exc): function get_dataloaders (line 124) | def get_dataloaders( function training_function (line 182) | def training_function(config, args): function main (line 278) | def main(): FILE: src/accelerate/test_utils/scripts/external_deps/test_performance.py function get_dataloaders (line 37) | def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, mode... function training_function (line 83) | def training_function(config, args): function main (line 248) | def main(): FILE: src/accelerate/test_utils/scripts/external_deps/test_pippy.py function get_model_and_data_for_text (line 34) | def get_model_and_data_for_text(model_name, device, num_processes: int =... function test_bert (line 48) | def test_bert(batch_size: int = 2): function test_gpt2 (line 64) | def test_gpt2(batch_size: int = 2): FILE: src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py function init_torch_dist_then_launch_deepspeed (line 29) | def init_torch_dist_then_launch_deepspeed(): function main (line 54) | def main(): FILE: src/accelerate/test_utils/scripts/test_cli.py function main (line 19) | def main(): FILE: src/accelerate/test_utils/scripts/test_ddp_comm_hook.py class MockModel (line 20) | class MockModel(torch.nn.Module): method __init__ (line 21) | def __init__(self): method forward (line 26) | def forward(self, x, rank): function _run_and_get_grads (line 30) | def _run_and_get_grads(model, rank): function test_ddp_comm_hook (line 39) | def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option): function main (line 60) | def main(): FILE: src/accelerate/test_utils/scripts/test_distributed_data_loop.py class DummyDataset (line 42) | class DummyDataset(Dataset): method __len__ (line 43) | def __len__(self): method __getitem__ (line 46) | def __getitem__(self, index): class DummyIterableDataset (line 65) | class DummyIterableDataset(IterableDataset): method __init__ (line 66) | def __init__(self, data): method __iter__ (line 69) | def __iter__(self): function create_accelerator (line 73) | def create_accelerator(even_batches=True): function create_dataloader (line 80) | def create_dataloader( function verify_dataloader_batch_sizes (line 100) | def verify_dataloader_batch_sizes( function test_default_ensures_even_batch_sizes (line 120) | def test_default_ensures_even_batch_sizes(): function test_can_disable_even_batches (line 142) | def test_can_disable_even_batches(): function test_can_join_uneven_inputs (line 162) | def test_can_join_uneven_inputs(): function test_join_raises_warning_for_non_ddp_distributed (line 186) | def test_join_raises_warning_for_non_ddp_distributed(accelerator): function test_join_can_override_even_batches (line 195) | def test_join_can_override_even_batches(): function test_join_can_override_for_mixed_type_dataloaders (line 214) | def test_join_can_override_for_mixed_type_dataloaders(): function test_join_raises_warning_for_iterable_when_overriding_even_batches (line 236) | def test_join_raises_warning_for_iterable_when_overriding_even_batches(): function test_pickle_accelerator (line 250) | def test_pickle_accelerator(): function test_data_loader (line 260) | def test_data_loader(data_loader, accelerator): function _test_stateful_dataloader_resume (line 278) | def _test_stateful_dataloader_resume(accelerator, iterable): function test_stateful_dataloader (line 315) | def test_stateful_dataloader(accelerator): function _test_stateful_dataloader_save_state_resume (line 326) | def _test_stateful_dataloader_save_state_resume(accelerator, iterable): function test_stateful_dataloader_save_state (line 361) | def test_stateful_dataloader_save_state(accelerator): function main (line 372) | def main(): FILE: src/accelerate/test_utils/scripts/test_merge_weights.py class TinyModel (line 38) | class TinyModel(torch.nn.Module): method __init__ (line 39) | def __init__(self): method forward (line 46) | def forward(self, x): function setup (line 50) | def setup(): function mock_training (line 64) | def mock_training(accelerator, model): function check_weights (line 80) | def check_weights(operation, state_1, state_2): function check_safetensors_weights (line 88) | def check_safetensors_weights(path, model): function check_pytorch_weights (line 96) | def check_pytorch_weights(path, model): function test_merge_weights_safetensors (line 104) | def test_merge_weights_safetensors(model, path): function test_merge_weights_command_safetensors (line 110) | def test_merge_weights_command_safetensors(model, path): function test_merge_weights_pytorch (line 116) | def test_merge_weights_pytorch(model, path): function test_merge_weights_command_pytorch (line 122) | def test_merge_weights_command_pytorch(model, path): FILE: src/accelerate/test_utils/scripts/test_notebook.py function basic_function (line 29) | def basic_function(): function tough_nut_function (line 34) | def tough_nut_function(queue): function bipolar_sleep_function (line 45) | def bipolar_sleep_function(sleep_sec: int): function test_can_initialize (line 56) | def test_can_initialize(): function test_static_rdzv_backend (line 61) | def test_static_rdzv_backend(): function test_c10d_rdzv_backend (line 66) | def test_c10d_rdzv_backend(): function test_fault_tolerant (line 71) | def test_fault_tolerant(max_restarts: int = 3): function test_monitoring (line 86) | def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100): function test_problematic_imports (line 99) | def test_problematic_imports(): function main (line 106) | def main(): FILE: src/accelerate/test_utils/scripts/test_ops.py function create_tensor (line 33) | def create_tensor(state): function test_gather (line 37) | def test_gather(state): function test_gather_object (line 43) | def test_gather_object(state): function test_gather_non_contiguous (line 53) | def test_gather_non_contiguous(state): function test_broadcast (line 65) | def test_broadcast(state): function test_pad_across_processes (line 72) | def test_pad_across_processes(state): function test_reduce_sum (line 85) | def test_reduce_sum(state): function test_reduce_mean (line 95) | def test_reduce_mean(state): function test_op_checker (line 105) | def test_op_checker(state): function test_copy_tensor_to_devices (line 140) | def test_copy_tensor_to_devices(state): function _mp_fn (line 151) | def _mp_fn(index): function main (line 156) | def main(): FILE: src/accelerate/test_utils/scripts/test_script.py function generate_baseline_dataloader (line 57) | def generate_baseline_dataloader(train_set, generator, batch_size, use_s... function print_main (line 72) | def print_main(state): function print_local_main (line 76) | def print_local_main(state): function print_last (line 80) | def print_last(state): function print_on (line 84) | def print_on(state, process_idx): function process_execution_check (line 88) | def process_execution_check(): function init_state_check (line 161) | def init_state_check(): function rng_sync_check (line 169) | def rng_sync_check(): function dl_preparation_check (line 187) | def dl_preparation_check(): function central_dl_preparation_check (line 247) | def central_dl_preparation_check(): function custom_sampler_check (line 312) | def custom_sampler_check(): function check_seedable_sampler (line 358) | def check_seedable_sampler(): function check_seedable_sampler_in_batch_sampler_shard (line 384) | def check_seedable_sampler_in_batch_sampler_shard(): function check_seedable_sampler_with_data_seed (line 403) | def check_seedable_sampler_with_data_seed(): function mock_training (line 431) | def mock_training(length, batch_size, generator, use_seedable_sampler=Fa... function training_check (line 449) | def training_check(use_seedable_sampler=False): function test_split_between_processes_dataset (line 623) | def test_split_between_processes_dataset(datasets_Dataset): function test_split_between_processes_list (line 671) | def test_split_between_processes_list(): function test_split_between_processes_nested_dict (line 704) | def test_split_between_processes_nested_dict(): function test_split_between_processes_tensor (line 742) | def test_split_between_processes_tensor(): function test_split_between_processes_evenly (line 776) | def test_split_between_processes_evenly(): function test_trigger (line 794) | def test_trigger(): function test_reinstantiated_state (line 811) | def test_reinstantiated_state(): function main (line 827) | def main(): FILE: src/accelerate/test_utils/scripts/test_sync.py function check_model_parameters (line 29) | def check_model_parameters(model_a, model_b, did_step, iteration, **kwar... function step_model (line 45) | def step_model(model, input, target, accelerator, do_backward=True): function get_training_setup (line 56) | def get_training_setup(accelerator, sched=False): function test_noop_sync (line 79) | def test_noop_sync(accelerator): function test_distributed_sync (line 113) | def test_distributed_sync(accelerator): function test_distributed_sync_multiple_fwd (line 153) | def test_distributed_sync_multiple_fwd(accelerator): function test_gradient_accumulation (line 207) | def test_gradient_accumulation(split_batches=False, dispatch_batches=Fal... function test_gradient_accumulation_with_opt_and_scheduler (line 248) | def test_gradient_accumulation_with_opt_and_scheduler( function test_dataloader_break (line 306) | def test_dataloader_break(): function main (line 331) | def main(): function _mp_fn (line 407) | def _mp_fn(index): FILE: src/accelerate/test_utils/testing.py function get_backend (line 84) | def get_backend(): function get_launch_command (line 114) | def get_launch_command(**kwargs) -> list: function parse_flag_from_env (line 136) | def parse_flag_from_env(key, default=False): function skip (line 155) | def skip(test_case): function slow (line 160) | def slow(test_case): function require_cpu (line 168) | def require_cpu(test_case): function require_non_cpu (line 175) | def require_non_cpu(test_case): function require_cuda (line 183) | def require_cuda(test_case): function require_cuda_or_hpu (line 191) | def require_cuda_or_hpu(test_case): function require_xpu (line 201) | def require_xpu(test_case): function require_cuda_or_xpu (line 208) | def require_cuda_or_xpu(test_case): function require_non_xpu (line 218) | def require_non_xpu(test_case): function require_non_hpu (line 225) | def require_non_hpu(test_case): function require_fp16 (line 232) | def require_fp16(test_case): function require_fp8 (line 240) | def require_fp8(test_case): function require_fsdp2 (line 258) | def require_fsdp2(test_case): function require_mlu (line 262) | def require_mlu(test_case): function require_sdaa (line 269) | def require_sdaa(test_case): function require_musa (line 276) | def require_musa(test_case): function require_npu (line 283) | def require_npu(test_case): function require_neuron (line 290) | def require_neuron(test_case): function require_mps (line 297) | def require_mps(test_case): function require_huggingface_suite (line 305) | def require_huggingface_suite(test_case): function require_datasets (line 315) | def require_datasets(test_case): function require_transformers (line 322) | def require_transformers(test_case): function require_timm (line 329) | def require_timm(test_case): function require_torchvision (line 336) | def require_torchvision(test_case): function require_triton (line 343) | def require_triton(test_case): function require_schedulefree (line 350) | def require_schedulefree(test_case): function require_bnb (line 357) | def require_bnb(test_case): function require_tpu (line 364) | def require_tpu(test_case): function require_non_torch_xla (line 371) | def require_non_torch_xla(test_case): function require_single_device (line 379) | def require_single_device(test_case): function require_single_gpu (line 389) | def require_single_gpu(test_case): function require_single_xpu (line 397) | def require_single_xpu(test_case): function require_multi_device (line 405) | def require_multi_device(test_case): function require_multi_gpu (line 413) | def require_multi_gpu(test_case): function require_multi_xpu (line 421) | def require_multi_xpu(test_case): function require_multi_gpu_or_xpu (line 429) | def require_multi_gpu_or_xpu(test_case): function require_deepspeed (line 439) | def require_deepspeed(test_case): function require_tp (line 446) | def require_tp(test_case): function require_torch_min_version (line 456) | def require_torch_min_version(test_case=None, version=None): function require_tensorboard (line 466) | def require_tensorboard(test_case): function require_wandb (line 474) | def require_wandb(test_case): function require_trackio (line 481) | def require_trackio(test_case): function require_comet_ml (line 488) | def require_comet_ml(test_case): function require_aim (line 495) | def require_aim(test_case): function require_clearml (line 502) | def require_clearml(test_case): function require_dvclive (line 509) | def require_dvclive(test_case): function require_swanlab (line 516) | def require_swanlab(test_case): function require_pandas (line 523) | def require_pandas(test_case): function require_mlflow (line 530) | def require_mlflow(test_case): function require_pippy (line 537) | def require_pippy(test_case): function require_import_timer (line 545) | def require_import_timer(test_case): function require_transformer_engine (line 553) | def require_transformer_engine(test_case): function require_transformer_engine_mxfp8 (line 561) | def require_transformer_engine_mxfp8(test_case): function require_torchao (line 571) | def require_torchao(test_case): function require_matplotlib (line 578) | def require_matplotlib(test_case): function require_trackers (line 592) | def require_trackers(test_case): function require_torchdata_stateful_dataloader (line 603) | def require_torchdata_stateful_dataloader(test_case): function run_first (line 615) | def run_first(test_case): class TempDirTestCase (line 634) | class TempDirTestCase(unittest.TestCase): method setUpClass (line 647) | def setUpClass(cls): method tearDownClass (line 652) | def tearDownClass(cls): method setUp (line 657) | def setUp(self): class AccelerateTestCase (line 667) | class AccelerateTestCase(unittest.TestCase): method tearDown (line 674) | def tearDown(self): class MockingTestCase (line 680) | class MockingTestCase(unittest.TestCase): method add_mocks (line 698) | def add_mocks(self, mocks: Union[mock.Mock, list[mock.Mock]]): function are_the_same_tensors (line 713) | def are_the_same_tensors(tensor): class _RunOutput (line 724) | class _RunOutput: method __init__ (line 725) | def __init__(self, returncode, stdout, stderr): function _read_stream (line 731) | async def _read_stream(stream, callback): function _stream_subprocess (line 740) | async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, qu... function execute_subprocess_async (line 781) | def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=18... function pytest_xdist_worker_id (line 800) | def pytest_xdist_worker_id(): function get_torch_dist_unique_port (line 810) | def get_torch_dist_unique_port(): class SubprocessCallException (line 822) | class SubprocessCallException(Exception): function run_command (line 826) | def run_command(command: list[str], return_stdout=False, env=None): function path_in_accelerate_package (line 849) | def path_in_accelerate_package(*components: str) -> Path: function assert_exception (line 865) | def assert_exception(exception_class: Exception, msg: Optional[str] = No... function capture_call_output (line 883) | def capture_call_output(func, *args, **kwargs): FILE: src/accelerate/test_utils/training.py class RegressionDataset (line 22) | class RegressionDataset: method __init__ (line 23) | def __init__(self, a=2, b=3, length=64, seed=None): method __len__ (line 29) | def __len__(self): method __getitem__ (line 32) | def __getitem__(self, i): class RegressionModel (line 36) | class RegressionModel(torch.nn.Module): method __init__ (line 37) | def __init__(self, a=0, b=0, double_output=False): method forward (line 43) | def forward(self, x=None): function mocked_dataloaders (line 50) | def mocked_dataloaders(accelerator, batch_size: int = 16): function mocked_dataloaders_for_autoregressive_models (line 90) | def mocked_dataloaders_for_autoregressive_models(accelerator, batch_size... FILE: src/accelerate/tracking.py function on_main_process (line 77) | def on_main_process(function): function get_available_trackers (line 96) | def get_available_trackers(): class GeneralTracker (line 101) | class GeneralTracker: method __init__ (line 120) | def __init__(self, _blank=False): method start (line 142) | def start(self): method store_init_configuration (line 149) | def store_init_configuration(self, values: dict): method log (line 161) | def log(self, values: dict, step: Optional[int], **kwargs): method finish (line 174) | def finish(self): class TensorBoardTracker (line 182) | class TensorBoardTracker(GeneralTracker): method __init__ (line 198) | def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike]... method start (line 205) | def start(self): method tracker (line 218) | def tracker(self): method store_init_configuration (line 222) | def store_init_configuration(self, values: dict): method log (line 246) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method log_images (line 272) | def log_images(self, values: dict, step: Optional[int], **kwargs): method finish (line 289) | def finish(self): class WandBTracker (line 297) | class WandBTracker(GeneralTracker): method __init__ (line 312) | def __init__(self, run_name: str, **kwargs): method start (line 318) | def start(self): method tracker (line 328) | def tracker(self): method store_init_configuration (line 332) | def store_init_configuration(self, values: dict): method log (line 347) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method log_images (line 364) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs): method log_table (line 383) | def log_table( method finish (line 414) | def finish(self): class TrackioTracker (line 422) | class TrackioTracker(GeneralTracker): method __init__ (line 439) | def __init__(self, run_name: str, **kwargs): method start (line 445) | def start(self): method tracker (line 455) | def tracker(self): method store_init_configuration (line 459) | def store_init_configuration(self, values: dict): method log (line 474) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method finish (line 491) | def finish(self): class CometMLTracker (line 499) | class CometMLTracker(GeneralTracker): method __init__ (line 520) | def __init__(self, run_name: str, **kwargs): method start (line 526) | def start(self): method tracker (line 542) | def tracker(self): method store_init_configuration (line 546) | def store_init_configuration(self, values: dict): method log (line 559) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method finish (line 585) | def finish(self): class AimTracker (line 593) | class AimTracker(GeneralTracker): method __init__ (line 607) | def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.... method start (line 614) | def start(self): method tracker (line 625) | def tracker(self): method store_init_configuration (line 629) | def store_init_configuration(self, values: dict): method log (line 640) | def log(self, values: dict, step: Optional[int], **kwargs): method log_images (line 657) | def log_images(self, values: dict, step: Optional[int] = None, kwargs:... method finish (line 689) | def finish(self): class MLflowTracker (line 696) | class MLflowTracker(GeneralTracker): method __init__ (line 727) | def __init__( method start (line 754) | def start(self): method tracker (line 784) | def tracker(self): method store_init_configuration (line 788) | def store_init_configuration(self, values: dict): method log (line 816) | def log(self, values: dict, step: Optional[int]): method log_figure (line 841) | def log_figure(self, figure: Any, artifact_file: str, **save_kwargs): method log_artifacts (line 860) | def log_artifacts(self, local_dir: str, artifact_path: Optional[str] =... method log_artifact (line 877) | def log_artifact(self, local_path: str, artifact_path: Optional[str] =... method finish (line 894) | def finish(self): class ClearMLTracker (line 903) | class ClearMLTracker(GeneralTracker): method __init__ (line 918) | def __init__(self, run_name: Optional[str] = None, **kwargs): method start (line 925) | def start(self): method tracker (line 940) | def tracker(self): method store_init_configuration (line 944) | def store_init_configuration(self, values: dict): method log (line 955) | def log(self, values: dict[str, Union[int, float]], step: Optional[int... method log_images (line 989) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs): method log_table (line 1007) | def log_table( method finish (line 1045) | def finish(self): method _get_title_series (line 1054) | def _get_title_series(name): class DVCLiveTracker (line 1061) | class DVCLiveTracker(GeneralTracker): method __init__ (line 1084) | def __init__(self, run_name: Optional[str] = None, live: Optional[Any]... method start (line 1090) | def start(self): method tracker (line 1096) | def tracker(self): method store_init_configuration (line 1100) | def store_init_configuration(self, values: dict): method log (line 1113) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method finish (line 1142) | def finish(self): class SwanLabTracker (line 1149) | class SwanLabTracker(GeneralTracker): method __init__ (line 1164) | def __init__(self, run_name: str, **kwargs): method start (line 1170) | def start(self): method tracker (line 1181) | def tracker(self): method store_init_configuration (line 1185) | def store_init_configuration(self, values: dict): method log (line 1200) | def log(self, values: dict, step: Optional[int] = None, **kwargs): method log_images (line 1220) | def log_images(self, values: dict, step: Optional[int] = None, **kwargs): method finish (line 1241) | def finish(self): function filter_trackers (line 1262) | def filter_trackers( FILE: src/accelerate/utils/ao.py function find_first_last_linear_layers (line 32) | def find_first_last_linear_layers(model: torch.nn.Module): function filter_linear_layers (line 49) | def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) ... function filter_first_and_last_linear_layers (line 72) | def filter_first_and_last_linear_layers(module, fqn: str) -> bool: function has_ao_layers (line 94) | def has_ao_layers(model: torch.nn.Module): function convert_model_to_fp8_ao (line 104) | def convert_model_to_fp8_ao( FILE: src/accelerate/utils/bnb.py function load_and_quantize_model (line 44) | def load_and_quantize_model( function get_quantized_model_device_map (line 191) | def get_quantized_model_device_map( function replace_with_bnb_layers (line 271) | def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_n... function _replace_with_bnb_layers (line 303) | def _replace_with_bnb_layers( function get_keys_to_not_convert (line 369) | def get_keys_to_not_convert(model): function has_4bit_bnb_layers (line 421) | def has_4bit_bnb_layers(model): function get_parameter_device (line 432) | def get_parameter_device(parameter: nn.Module): function quantize_and_offload_8bit (line 436) | def quantize_and_offload_8bit(model, param, param_name, new_dtype, offlo... FILE: src/accelerate/utils/dataclasses.py class KwargsHandler (line 68) | class KwargsHandler: method to_dict (line 73) | def to_dict(self): method to_kwargs (line 76) | def to_kwargs(self): class EnumWithContains (line 89) | class EnumWithContains(enum.EnumMeta): method __contains__ (line 92) | def __contains__(cls, item): class BaseEnum (line 100) | class BaseEnum(enum.Enum, metaclass=EnumWithContains): method __str__ (line 103) | def __str__(self): method list (line 107) | def list(cls): class AutocastKwargs (line 113) | class AutocastKwargs(KwargsHandler): class DDPCommunicationHookType (line 134) | class DDPCommunicationHookType(BaseEnum): class DistributedDataParallelKwargs (line 155) | class DistributedDataParallelKwargs(KwargsHandler): method to_dict (line 197) | def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_stat... method register_comm_hook (line 200) | def register_comm_hook(self, model): class GradScalerKwargs (line 241) | class GradScalerKwargs(KwargsHandler): class InitProcessGroupKwargs (line 273) | class InitProcessGroupKwargs(KwargsHandler): method __post_init__ (line 296) | def __post_init__(self): class AORecipeKwargs (line 311) | class AORecipeKwargs(KwargsHandler): method __post_init__ (line 337) | def __post_init__(self): class TERecipeKwargs (line 359) | class TERecipeKwargs(KwargsHandler): method __post_init__ (line 406) | def __post_init__(self): class MSAMPRecipeKwargs (line 438) | class MSAMPRecipeKwargs(KwargsHandler): method __post_init__ (line 446) | def __post_init__(self): class FP8RecipeKwargs (line 455) | class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): method __post_init__ (line 463) | def __post_init__(self): class ProfileKwargs (line 484) | class ProfileKwargs(KwargsHandler): method _get_profiler_activity (line 544) | def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.... method build (line 574) | def build(self) -> torch.profiler.profile: class DistributedType (line 600) | class DistributedType(str, enum.Enum): class SageMakerDistributedType (line 639) | class SageMakerDistributedType(str, enum.Enum): class FP8BackendType (line 656) | class FP8BackendType(str, enum.Enum): class ComputeEnvironment (line 673) | class ComputeEnvironment(str, enum.Enum): class DynamoBackend (line 688) | class DynamoBackend(str, BaseEnum): class LoggerType (line 741) | class LoggerType(BaseEnum): class PrecisionType (line 769) | class PrecisionType(str, BaseEnum): class RNGType (line 785) | class RNGType(BaseEnum): class CustomDtype (line 799) | class CustomDtype(enum.Enum): class TensorInformation (line 813) | class TensorInformation: class DataLoaderConfiguration (line 819) | class DataLoaderConfiguration: class ProjectConfiguration (line 914) | class ProjectConfiguration: method set_directories (line 966) | def set_directories(self, project_dir: Optional[str] = None): method __post_init__ (line 972) | def __post_init__(self): class GradientAccumulationPlugin (line 977) | class GradientAccumulationPlugin(KwargsHandler): class TorchDynamoPlugin (line 1029) | class TorchDynamoPlugin(KwargsHandler): method __post_init__ (line 1088) | def __post_init__(self): method to_dict (line 1106) | def to_dict(self): method to_kwargs (line 1111) | def to_kwargs(self): class DeepSpeedPlugin (line 1118) | class DeepSpeedPlugin: method __post_init__ (line 1222) | def __post_init__(self): method fill_match (line 1353) | def fill_match(self, ds_key_long, mismatches=None, must_match=True, **... method is_auto (line 1378) | def is_auto(self, ds_key_long): method get_value (line 1385) | def get_value(self, ds_key_long, default=None): method deepspeed_config_process (line 1388) | def deepspeed_config_process(self, prefix="", mismatches=None, config=... method set_mixed_precision (line 1411) | def set_mixed_precision(self, mixed_precision): method set_deepspeed_weakref (line 1444) | def set_deepspeed_weakref(self): method is_zero3_init_enabled (line 1475) | def is_zero3_init_enabled(self): method zero3_init_context_manager (line 1479) | def zero3_init_context_manager(self, enable=False): method _deepspeed_config_checks (line 1492) | def _deepspeed_config_checks(self): method set_moe_leaf_modules (line 1519) | def set_moe_leaf_modules(self, model): method select (line 1539) | def select(self, _from_accelerator_state: bool = False): method _unselect (line 1550) | def _unselect(self): method _set_selected (line 1553) | def _set_selected(self, value: bool): method selected (line 1560) | def selected(self): method selected (line 1564) | def selected(self, value): class FullyShardedDataParallelPlugin (line 1571) | class FullyShardedDataParallelPlugin: method __post_init__ (line 1800) | def __post_init__(self): method set_state_dict_type (line 1996) | def set_state_dict_type(self, state_dict_type=None): method set_auto_wrap_policy (line 2041) | def set_auto_wrap_policy(self, model): method set_mixed_precision (line 2075) | def set_mixed_precision(self, mixed_precision, buffer_autocast=False, ... method validate_mixed_precision_policy (line 2127) | def validate_mixed_precision_policy(self): method set_cpu_offload (line 2144) | def set_cpu_offload(self): method validate_cpu_offload (line 2159) | def validate_cpu_offload(self): class TorchTensorParallelPlugin (line 2176) | class TorchTensorParallelPlugin: class TorchContextParallelConfig (line 2191) | class TorchContextParallelConfig: method __post_init__ (line 2203) | def __post_init__(self): class DeepSpeedSequenceParallelConfig (line 2219) | class DeepSpeedSequenceParallelConfig: method __post_init__ (line 2239) | def __post_init__(self): class TorchTensorParallelConfig (line 2279) | class TorchTensorParallelConfig: method __post_init__ (line 2286) | def __post_init__(self): class MegatronLMPlugin (line 2301) | class MegatronLMPlugin: method __post_init__ (line 2593) | def __post_init__(self): method set_network_size_args (line 2723) | def set_network_size_args(self, model, batch_data=None): method set_mixed_precision (line 2734) | def set_mixed_precision(self, mixed_precision): method set_training_args (line 2742) | def set_training_args(self, micro_batch_size, dp_degree): method set_optimizer_type (line 2750) | def set_optimizer_type(self, optimizer): method set_scheduler_args (line 2766) | def set_scheduler_args(self, scheduler): method set_tensorboard_logging_options (line 2795) | def set_tensorboard_logging_options(self): function add_model_config_to_megatron_parser (line 2812) | def add_model_config_to_megatron_parser(model_type: str): function parse_bert_config (line 2825) | def parse_bert_config(megatron_lm_plugin, model, batch_data): function parse_gpt2_config (line 2859) | def parse_gpt2_config(megatron_lm_plugin, model, batch_data): function parse_t5_config (line 2891) | def parse_t5_config(megatron_lm_plugin, model, batch_data): function parse_llama_config (line 2922) | def parse_llama_config(megatron_lm_plugin, model, batch_data): function parse_glm4_moe_config (line 2956) | def parse_glm4_moe_config(megatron_lm_plugin, model, batch_data): class BnbQuantizationConfig (line 3040) | class BnbQuantizationConfig: method __post_init__ (line 3119) | def __post_init__(self): function get_module_class_from_name (line 3194) | def get_module_class_from_name(module, name): FILE: src/accelerate/utils/deepspeed.py function map_pytorch_optim_to_deepspeed (line 29) | def map_pytorch_optim_to_deepspeed(optimizer): function get_active_deepspeed_plugin (line 100) | def get_active_deepspeed_plugin(state): class HfDeepSpeedConfig (line 119) | class HfDeepSpeedConfig: method __init__ (line 136) | def __init__(self, config_file_or_dict): method set_stage_and_offload (line 162) | def set_stage_and_offload(self): method find_config_node (line 181) | def find_config_node(self, ds_key_long): method get_value (line 194) | def get_value(self, ds_key_long, default=None): method del_config_sub_tree (line 203) | def del_config_sub_tree(self, ds_key_long, must_exist=False): method is_true (line 226) | def is_true(self, ds_key_long): method is_false (line 235) | def is_false(self, ds_key_long): method is_zero2 (line 243) | def is_zero2(self): method is_zero3 (line 246) | def is_zero3(self): method is_offload (line 249) | def is_offload(self): class DeepSpeedEngineWrapper (line 253) | class DeepSpeedEngineWrapper: method __init__ (line 261) | def __init__(self, engine): method backward (line 264) | def backward(self, loss, sync_gradients=True, **kwargs): method get_global_grad_norm (line 286) | def get_global_grad_norm(self): class DeepSpeedOptimizerWrapper (line 295) | class DeepSpeedOptimizerWrapper(AcceleratedOptimizer): method __init__ (line 304) | def __init__(self, optimizer): method zero_grad (line 308) | def zero_grad(self, set_to_none=None): method step (line 311) | def step(self): method step_was_skipped (line 315) | def step_was_skipped(self): class DeepSpeedSchedulerWrapper (line 322) | class DeepSpeedSchedulerWrapper(AcceleratedScheduler): method __init__ (line 332) | def __init__(self, scheduler, optimizers): method step (line 335) | def step(self): class DummyOptim (line 339) | class DummyOptim: method __init__ (line 355) | def __init__(self, params, lr=0.001, weight_decay=0, **kwargs): class DummyScheduler (line 362) | class DummyScheduler: method __init__ (line 380) | def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0... FILE: src/accelerate/utils/environment.py function convert_dict_to_env_variables (line 34) | def convert_dict_to_env_variables(current_env: dict): function str_to_bool (line 59) | def str_to_bool(value, to_bool: bool = False) -> Union[int, bool]: function get_int_from_env (line 74) | def get_int_from_env(env_keys, default): function parse_flag_from_env (line 83) | def parse_flag_from_env(key, default=False): function parse_choice_from_env (line 89) | def parse_choice_from_env(key, default="no"): function are_libraries_initialized (line 94) | def are_libraries_initialized(*library_names: str) -> list[str]: function get_current_device_type (line 101) | def get_current_device_type() -> tuple[str, str]: function _nvidia_smi (line 154) | def _nvidia_smi(): function get_gpu_info (line 169) | def get_gpu_info(): function get_driver_version (line 187) | def get_driver_version(): function check_cuda_p2p_ib_support (line 200) | def check_cuda_p2p_ib_support(): function check_cuda_fp8_capability (line 229) | def check_cuda_fp8_capability(): class CPUInformation (line 251) | class CPUInformation: function get_cpu_distributed_information (line 266) | def get_cpu_distributed_information() -> CPUInformation: function override_numa_affinity (line 286) | def override_numa_affinity(local_process_index: int, verbose: Optional[b... function set_numa_affinity (line 326) | def set_numa_affinity(local_process_index: int, verbose: Optional[bool] ... function clear_environment (line 344) | def clear_environment(): function patch_environment (line 379) | def patch_environment(**kwargs): function purge_accelerate_environment (line 415) | def purge_accelerate_environment(func_or_cls): FILE: src/accelerate/utils/fsdp_utils.py function enable_fsdp_ram_efficient_loading (line 39) | def enable_fsdp_ram_efficient_loading(): function disable_fsdp_ram_efficient_loading (line 49) | def disable_fsdp_ram_efficient_loading(): function _get_model_state_dict (line 56) | def _get_model_state_dict(model, adapter_only=False, sd_options=None): function _set_model_state_dict (line 71) | def _set_model_state_dict(model, state_dict, adapter_only=False, sd_opti... function _prepare_sd_options (line 86) | def _prepare_sd_options(fsdp_plugin): function save_fsdp_model (line 103) | def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_i... function load_fsdp_model (line 161) | def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_in... function save_fsdp_optimizer (line 233) | def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, outp... function load_fsdp_optimizer (line 281) | def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, inpu... function _distributed_checkpoint_to_merged_weights (line 338) | def _distributed_checkpoint_to_merged_weights(checkpoint_dir: str, save_... function merge_fsdp_weights (line 366) | def merge_fsdp_weights( function ensure_weights_retied (line 421) | def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device:... function fsdp2_load_full_state_dict (line 467) | def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full... function fsdp2_switch_optimizer_parameters (line 557) | def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, ... function fsdp2_apply_ac (line 588) | def fsdp2_apply_ac(accelerator, model: torch.nn.Module): function fsdp2_prepare_model (line 621) | def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn... function fsdp2_prepare_auto_wrap_policy (line 749) | def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module)... function get_fsdp2_grad_scaler (line 802) | def get_fsdp2_grad_scaler(**kwargs): function fsdp2_canonicalize_names (line 813) | def fsdp2_canonicalize_names(named_params: dict) -> dict: function get_parameters_from_modules (line 832) | def get_parameters_from_modules( FILE: src/accelerate/utils/imports.py function _is_package_available (line 50) | def _is_package_available(pkg_name, metadata_name=None): function is_torch_distributed_available (line 62) | def is_torch_distributed_available() -> bool: function is_xccl_available (line 66) | def is_xccl_available(): function is_import_timer_available (line 72) | def is_import_timer_available(): function is_pynvml_available (line 76) | def is_pynvml_available(): function is_pytest_available (line 80) | def is_pytest_available(): function is_msamp_available (line 84) | def is_msamp_available(): function is_schedulefree_available (line 88) | def is_schedulefree_available(): function is_transformer_engine_available (line 92) | def is_transformer_engine_available(): function is_transformer_engine_mxfp8_available (line 99) | def is_transformer_engine_mxfp8_available(): function is_lomo_available (line 107) | def is_lomo_available(): function is_cuda_available (line 111) | def is_cuda_available(): function is_torch_xla_available (line 123) | def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): function is_torchao_available (line 140) | def is_torchao_available(): function is_deepspeed_available (line 148) | def is_deepspeed_available(): function is_pippy_available (line 152) | def is_pippy_available(): function is_bf16_available (line 156) | def is_bf16_available(ignore_tpu=False): function is_fp16_available (line 171) | def is_fp16_available(): function is_fp8_available (line 179) | def is_fp8_available(): function is_4bit_bnb_available (line 184) | def is_4bit_bnb_available(): function is_8bit_bnb_available (line 192) | def is_8bit_bnb_available(): function is_bnb_available (line 200) | def is_bnb_available(min_version=None): function is_bitsandbytes_multi_backend_available (line 209) | def is_bitsandbytes_multi_backend_available(): function is_torchvision_available (line 217) | def is_torchvision_available(): function is_megatron_lm_available (line 221) | def is_megatron_lm_available(): function is_transformers_available (line 233) | def is_transformers_available(): function is_datasets_available (line 237) | def is_datasets_available(): function is_peft_available (line 241) | def is_peft_available(): function is_timm_available (line 245) | def is_timm_available(): function is_triton_available (line 249) | def is_triton_available(): function is_aim_available (line 255) | def is_aim_available(): function is_tensorboard_available (line 263) | def is_tensorboard_available(): function is_wandb_available (line 267) | def is_wandb_available(): function is_comet_ml_available (line 271) | def is_comet_ml_available(): function is_swanlab_available (line 275) | def is_swanlab_available(): function is_trackio_available (line 279) | def is_trackio_available(): function is_boto3_available (line 283) | def is_boto3_available(): function is_rich_available (line 287) | def is_rich_available(): function is_sagemaker_available (line 293) | def is_sagemaker_available(): function is_tqdm_available (line 297) | def is_tqdm_available(): function is_clearml_available (line 301) | def is_clearml_available(): function is_pandas_available (line 305) | def is_pandas_available(): function is_matplotlib_available (line 309) | def is_matplotlib_available(): function is_mlflow_available (line 313) | def is_mlflow_available(): function is_mps_available (line 326) | def is_mps_available(min_version="1.12"): function is_mlu_available (line 334) | def is_mlu_available(check_device=False): function is_musa_available (line 351) | def is_musa_available(check_device=False): function is_npu_available (line 369) | def is_npu_available(check_device=False): function is_sdaa_available (line 392) | def is_sdaa_available(check_device=False): function is_hpu_available (line 410) | def is_hpu_available(init_hccl=False): function is_habana_gaudi1 (line 426) | def is_habana_gaudi1(): function is_xpu_available (line 437) | def is_xpu_available(check_device=False): function is_neuron_available (line 457) | def is_neuron_available(check_device=False): function is_dvclive_available (line 474) | def is_dvclive_available(): function is_torchdata_available (line 478) | def is_torchdata_available(): function is_torchdata_stateful_dataloader_available (line 483) | def is_torchdata_stateful_dataloader_available(): function torchao_required (line 491) | def torchao_required(func): function deepspeed_required (line 508) | def deepspeed_required(func): function is_weights_only_available (line 528) | def is_weights_only_available(): function is_numpy_available (line 534) | def is_numpy_available(min_version="1.25.0"): FILE: src/accelerate/utils/launch.py function _filter_args (line 47) | def _filter_args(args, parser, default_args=[]): function _get_mpirun_args (line 58) | def _get_mpirun_args(): function setup_fp8_env (line 82) | def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]): function prepare_simple_launcher_cmd_env (line 100) | def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[l... function prepare_multi_gpu_env (line 201) | def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]: function prepare_extend_env_parallelism_config (line 402) | def prepare_extend_env_parallelism_config( function prepare_deepspeed_cmd_env (line 429) | def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> tuple[list[st... function prepare_tpu (line 593) | def prepare_tpu( function _convert_nargs_to_dict (line 613) | def _convert_nargs_to_dict(nargs: list[str]) -> dict[str, str]: function prepare_sagemager_args_inputs (line 655) | def prepare_sagemager_args_inputs( function env_var_path_add (line 773) | def env_var_path_add(env_var_name, path_to_add): class PrepareForLaunch (line 783) | class PrepareForLaunch: method __init__ (line 796) | def __init__(self, launcher, distributed_type="NO", debug=False): method __call__ (line 801) | def __call__(self, index, *args): FILE: src/accelerate/utils/megatron_lm.py function model_provider_func (line 85) | def model_provider_func(pre_process=True, post_process=True, add_encoder... function prepare_model_optimizer_scheduler (line 134) | def prepare_model_optimizer_scheduler(accelerator): class MegatronLMDummyDataLoader (line 162) | class MegatronLMDummyDataLoader: method __init__ (line 170) | def __init__(self, **dataset_kwargs): method set_megatron_data_args (line 179) | def set_megatron_data_args(self): method get_train_valid_test_datasets_provider (line 189) | def get_train_valid_test_datasets_provider(self, accelerator): method build_train_valid_test_data_iterators (line 249) | def build_train_valid_test_data_iterators(self, accelerator): function _handle_megatron_data_iterator (line 271) | def _handle_megatron_data_iterator(accelerator, data_iterator): function prepare_data_loader (line 289) | def prepare_data_loader(accelerator, dataloader): class MegatronLMOptimizerWrapper (line 355) | class MegatronLMOptimizerWrapper(AcceleratedOptimizer): method __init__ (line 356) | def __init__(self, optimizer): method zero_grad (line 359) | def zero_grad(self, set_to_none=None): method step (line 362) | def step(self): method step_was_skipped (line 366) | def step_was_skipped(self): function prepare_optimizer (line 371) | def prepare_optimizer(accelerator, model): class MegatronLMDummyScheduler (line 378) | class MegatronLMDummyScheduler: method __init__ (line 394) | def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0... class MegatronLMSchedulerWrapper (line 401) | class MegatronLMSchedulerWrapper(AcceleratedScheduler): method __init__ (line 402) | def __init__(self, scheduler, optimizers): method step (line 405) | def step(self, *args, **kwargs): function prepare_scheduler (line 409) | def prepare_scheduler(accelerator, optimizer, scheduler): class AbstractTrainStep (line 415) | class AbstractTrainStep(ABC): method __init__ (line 418) | def __init__(self, name): method get_batch_func (line 422) | def get_batch_func(self, accelerator, megatron_dataset_flag): method get_forward_step_func (line 425) | def get_forward_step_func(self): method get_loss_func (line 428) | def get_loss_func(self, accelerator): class BertTrainStep (line 432) | class BertTrainStep(AbstractTrainStep): method __init__ (line 440) | def __init__(self, accelerator, args): method get_batch_func (line 452) | def get_batch_func(self, accelerator, megatron_dataset_flag): method get_loss_func (line 516) | def get_loss_func(self, accelerator, pretraining_flag, num_labels): method get_forward_step_func (line 557) | def get_forward_step_func(self, pretraining_flag, bert_binary_head): class GPTTrainStep (line 574) | class GPTTrainStep(AbstractTrainStep): method __init__ (line 582) | def __init__(self, accelerator, args): method get_batch_func (line 602) | def get_batch_func(self, accelerator, megatron_dataset_flag): method get_loss_func (line 669) | def get_loss_func(self, accelerator): method get_forward_step_func (line 706) | def get_forward_step_func(self): class T5TrainStep (line 718) | class T5TrainStep(AbstractTrainStep): method __init__ (line 726) | def __init__(self, accelerator, args): method attn_mask_postprocess (line 739) | def attn_mask_postprocess(attention_mask): method get_decoder_mask (line 752) | def get_decoder_mask(seq_length, device): method get_enc_dec_mask (line 758) | def get_enc_dec_mask(attention_mask, dec_seq_length, device): method get_batch_func (line 769) | def get_batch_func(self, accelerator, megatron_dataset_flag): method get_loss_func (line 832) | def get_loss_func(self, accelerator): method get_forward_step_func (line 846) | def get_forward_step_func(self): function finish_mpu_init (line 863) | def finish_mpu_init(): function initialize (line 876) | def initialize(accelerator, extra_args_provider=None, args_defaults=None): class MegatronEngine (line 926) | class MegatronEngine(torch.nn.Module): method __init__ (line 937) | def __init__(self, accelerator, model, optimizer, scheduler): method get_module_config (line 968) | def get_module_config(self): method train (line 994) | def train(self): method eval (line 1003) | def eval(self): method get_batch_data_iterator (line 1010) | def get_batch_data_iterator(self, batch_data): method train_step (line 1035) | def train_step(self, **batch_data): method eval_step (line 1059) | def eval_step(self, **batch_data): method forward (line 1099) | def forward(self, **batch_data): method log_eval_results (line 1162) | def log_eval_results(self): method save_checkpoint (line 1188) | def save_checkpoint(self, output_dir): method load_checkpoint (line 1202) | def load_checkpoint(self, input_dir): function avg_losses_across_data_parallel_group (line 1217) | def avg_losses_across_data_parallel_group(losses): function gather_across_data_parallel_groups (line 1228) | def gather_across_data_parallel_groups(tensor): FILE: src/accelerate/utils/memory.py function clear_device_cache (line 40) | def clear_device_cache(garbage_collection=False): function release_memory (line 70) | def release_memory(*objects): function should_reduce_batch_size (line 100) | def should_reduce_batch_size(exception: Exception) -> bool: function find_executable_batch_size (line 119) | def find_executable_batch_size( FILE: src/accelerate/utils/modeling.py function is_peft_model (line 73) | def is_peft_model(model): function check_device_same (line 82) | def check_device_same(first_device, second_device): function convert_file_size_to_int (line 109) | def convert_file_size_to_int(size: Union[int, str]): function dtype_byte_size (line 153) | def dtype_byte_size(dtype: torch.dtype): function id_tensor_storage (line 181) | def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, ... function set_module_tensor_to_device (line 217) | def set_module_tensor_to_device( function named_module_tensors (line 430) | def named_module_tensors( function get_non_persistent_buffers (line 460) | def get_non_persistent_buffers(module: nn.Module, recurse: bool = False,... function check_tied_parameters_in_config (line 484) | def check_tied_parameters_in_config(model: nn.Module): function _get_param_device (line 524) | def _get_param_device(param, device_map): function check_tied_parameters_on_same_device (line 534) | def check_tied_parameters_on_same_device(tied_params, device_map): function find_tied_parameters (line 557) | def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[... function retie_parameters (line 612) | def retie_parameters(model, tied_params): function _get_proper_dtype (line 643) | def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype: function compute_module_sizes (line 654) | def compute_module_sizes( function compute_module_total_buffer_size (line 696) | def compute_module_total_buffer_size( function get_max_layer_size (line 708) | def get_max_layer_size( function get_max_memory (line 747) | def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int,... function clean_device_map (line 858) | def clean_device_map(device_map: dict[str, Union[int, str, torch.device]... function load_offloaded_weights (line 880) | def load_offloaded_weights(model, index, offload_folder): function get_module_leaves (line 910) | def get_module_leaves(module_sizes): function get_balanced_memory (line 921) | def get_balanced_memory( function calculate_maximum_sizes (line 1055) | def calculate_maximum_sizes(model: torch.nn.Module): function _init_infer_auto_device_map (line 1073) | def _init_infer_auto_device_map( function get_module_size_with_ties (line 1137) | def get_module_size_with_ties( function fallback_allocate (line 1173) | def fallback_allocate( function infer_auto_device_map (line 1281) | def infer_auto_device_map( function check_device_map (line 1589) | def check_device_map(model: nn.Module, device_map: dict[str, Union[int, ... function load_state_dict (line 1623) | def load_state_dict(checkpoint_file, device_map=None): function get_state_dict_offloaded_model (line 1718) | def get_state_dict_offloaded_model(model: nn.Module): function get_state_dict_from_offload (line 1755) | def get_state_dict_from_offload( function load_checkpoint_in_model (line 1791) | def load_checkpoint_in_model( function get_mixed_precision_context_manager (line 2052) | def get_mixed_precision_context_manager(native_amp: bool = False, autoca... function get_grad_scaler (line 2096) | def get_grad_scaler(distributed_type: DistributedType = None, **kwargs): function has_offloaded_params (line 2138) | def has_offloaded_params(module: torch.nn.Module) -> bool: function align_module_device (line 2155) | def align_module_device(module: torch.nn.Module, execution_device: Optio... FILE: src/accelerate/utils/offload.py function offload_weight (line 25) | def offload_weight(weight, weight_name, offload_folder, index=None): function load_offloaded_weight (line 46) | def load_offloaded_weight(weight_file, weight_info): function save_offload_index (line 68) | def save_offload_index(index, offload_folder): function offload_state_dict (line 85) | def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: di... class PrefixedDataset (line 104) | class PrefixedDataset(Mapping): method __init__ (line 113) | def __init__(self, dataset: Mapping, prefix: str): method __getitem__ (line 117) | def __getitem__(self, key): method __iter__ (line 120) | def __iter__(self): method __len__ (line 123) | def __len__(self): class OffloadedWeightsLoader (line 127) | class OffloadedWeightsLoader(Mapping): method __init__ (line 141) | def __init__( method __getitem__ (line 161) | def __getitem__(self, key: str): method __iter__ (line 187) | def __iter__(self): method __len__ (line 190) | def __len__(self): function extract_submodules_state_dict (line 194) | def extract_submodules_state_dict(state_dict: dict[str, torch.Tensor], s... FILE: src/accelerate/utils/operations.py function is_torch_tensor (line 45) | def is_torch_tensor(tensor): function is_torch_xpu_tensor (line 49) | def is_torch_xpu_tensor(tensor): function is_tensor_information (line 62) | def is_tensor_information(tensor_info): function is_namedtuple (line 66) | def is_namedtuple(data): function honor_type (line 74) | def honor_type(obj, generator): function recursively_apply (line 85) | def recursively_apply(func, data, *args, test_type=is_torch_tensor, erro... function send_to_device (line 136) | def send_to_device(tensor, device, non_blocking=False, skip_keys=None): function get_data_structure (line 188) | def get_data_structure(data): function get_shape (line 206) | def get_shape(data): function initialize_tensors (line 224) | def initialize_tensors(data_structure): function find_batch_size (line 238) | def find_batch_size(data): function ignorant_find_batch_size (line 261) | def ignorant_find_batch_size(data): function listify (line 278) | def listify(data): function _tpu_gather (line 301) | def _tpu_gather(tensor): function _gpu_gather (line 316) | def _gpu_gather(tensor): class DistributedOperationException (line 355) | class DistributedOperationException(Exception): function verify_operation (line 364) | def verify_operation(function): function chained_operation (line 399) | def chained_operation(function): function gather (line 419) | def gather(tensor): function _gpu_gather_object (line 438) | def _gpu_gather_object(object: Any): function gather_object (line 445) | def gather_object(object: Any): function _gpu_broadcast (line 464) | def _gpu_broadcast(data, src=0): function _tpu_broadcast (line 472) | def _tpu_broadcast(tensor, src=0, name="broadcast tensor"): function gather_tensor_shape (line 496) | def gather_tensor_shape(tensor): function copy_tensor_to_devices (line 521) | def copy_tensor_to_devices(tensor=None) -> torch.Tensor: function broadcast (line 539) | def broadcast(tensor, from_process: int = 0): function broadcast_object_list (line 560) | def broadcast_object_list(object_list, from_process: int = 0): function slice_tensors (line 581) | def slice_tensors(data, tensor_slice, process_index=None, num_processes=... function concatenate (line 601) | def concatenate(data, dim=0): class CannotPadNestedTensorWarning (line 627) | class CannotPadNestedTensorWarning(UserWarning): function pad_across_processes (line 632) | def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): function pad_input_tensors (line 687) | def pad_input_tensors(tensor, batch_size, num_processes, dim=0): function reduce (line 728) | def reduce(tensor, reduction="mean", scale=1.0): function convert_to_fp32 (line 769) | def convert_to_fp32(tensor): class ConvertOutputsToFp32 (line 793) | class ConvertOutputsToFp32: method __init__ (line 806) | def __init__(self, model_forward): method __call__ (line 810) | def __call__(self, *args, **kwargs): method __getstate__ (line 813) | def __getstate__(self): function convert_outputs_to_fp32 (line 819) | def convert_outputs_to_fp32(model_forward): function find_device (line 831) | def find_device(data): function GatheredParameters (line 853) | def GatheredParameters(params, modifier_rank=None, fwd_module=None, enab... FILE: src/accelerate/utils/other.py function is_compiled_module (line 54) | def is_compiled_module(module: torch.nn.Module) -> bool: function has_compiled_regions (line 64) | def has_compiled_regions(module: torch.nn.Module) -> bool: function is_repeated_blocks (line 79) | def is_repeated_blocks(module: torch.nn.Module) -> bool: function has_repeated_blocks (line 92) | def has_repeated_blocks(module: torch.nn.Module) -> bool: function compile_regions (line 106) | def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.... function compile_regions_deepspeed (line 178) | def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs): function model_has_dtensor (line 202) | def model_has_dtensor(model: torch.nn.Module) -> bool: function extract_model_from_parallel (line 222) | def extract_model_from_parallel( function wait_for_everyone (line 310) | def wait_for_everyone(): function clean_state_dict_for_safetensors (line 323) | def clean_state_dict_for_safetensors(state_dict: dict): function save (line 358) | def save(obj, f, save_on_each_node: bool = False, safe_serialization: bo... function load (line 408) | def load(f, map_location=None, **kwargs): function get_pretty_name (line 440) | def get_pretty_name(obj): function merge_dicts (line 453) | def merge_dicts(source, destination): function is_port_in_use (line 471) | def is_port_in_use(port: Optional[int] = None) -> bool: function get_free_port (line 482) | def get_free_port() -> int: function convert_bytes (line 495) | def convert_bytes(size): function check_os_kernel (line 505) | def check_os_kernel(): function recursive_getattr (line 523) | def recursive_getattr(obj, attr: str): function get_module_children_bottom_up (line 540) | def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: b... FILE: src/accelerate/utils/random.py function set_seed (line 40) | def set_seed(seed: int, device_specific: bool = False, deterministic: bo... function synchronize_rng_state (line 81) | def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator:... function synchronize_rng_states (line 163) | def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generat... FILE: src/accelerate/utils/torch_xla.py function install_xla (line 20) | def install_xla(upgrade: bool = False): FILE: src/accelerate/utils/tqdm.py function tqdm (line 25) | def tqdm(*args, main_process_only: bool = True, **kwargs): FILE: src/accelerate/utils/transformer_engine.py function convert_model (line 26) | def convert_model(model, to_transformer_engine=True, _convert_linear=Tru... function has_transformer_engine_layers (line 95) | def has_transformer_engine_layers(model): function contextual_fp8_autocast (line 118) | def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=F... function apply_fp8_autowrap (line 142) | def apply_fp8_autowrap(model, fp8_recipe_handler): FILE: src/accelerate/utils/versions.py function compare_versions (line 26) | def compare_versions(library_or_version: Union[str, Version], operation:... function is_torch_version (line 46) | def is_torch_version(operation: str, version: str): FILE: tests/deepspeed/test_alst_ulysses_sp.py class DeepSpeedALSTUlyssesSPTest (line 29) | class DeepSpeedALSTUlyssesSPTest(TempDirTestCase): method test_deepspeed_alst_ulysses_sp (line 33) | def test_deepspeed_alst_ulysses_sp(self, stage): FILE: tests/deepspeed/test_deepspeed.py function parameterized_custom_name_func (line 90) | def parameterized_custom_name_func(func, param_num, param): class DummyConfig (line 102) | class DummyConfig: method __init__ (line 103) | def __init__(self): class DeepSpeedConfigIntegration (line 109) | class DeepSpeedConfigIntegration(AccelerateTestCase): method setUp (line 110) | def setUp(self): method get_config_dict (line 142) | def get_config_dict(self, stage): method test_deepspeed_plugin (line 147) | def test_deepspeed_plugin(self, stage): method test_accelerate_state_deepspeed (line 247) | def test_accelerate_state_deepspeed(self, dtype): method test_init_zero3 (line 262) | def test_init_zero3(self): method test_prepare_deepspeed (line 281) | def test_prepare_deepspeed(self, optim_type, scheduler_type): method test_dataloader_with_batch_sampler (line 517) | def test_dataloader_with_batch_sampler(self): method test_save_checkpoints (line 559) | def test_save_checkpoints(self): method test_autofill_dsconfig (line 610) | def test_autofill_dsconfig(self): method test_autofill_comm_buffers_dsconfig (line 650) | def test_autofill_comm_buffers_dsconfig(self, model_type): method test_autofill_dsconfig_from_ds_plugin (line 706) | def test_autofill_dsconfig_from_ds_plugin(self, dtype): method test_ds_config_assertions (line 788) | def test_ds_config_assertions(self): method test_ds_zero3_no_init_autofill (line 812) | def test_ds_zero3_no_init_autofill(self): method test_ds_config (line 842) | def test_ds_config(self, stage): method test_prepare_deepspeed_prepare_moe (line 850) | def test_prepare_deepspeed_prepare_moe(self): method test_basic_run (line 876) | def test_basic_run(self): class DeepSpeedIntegrationTest (line 904) | class DeepSpeedIntegrationTest(TempDirTestCase): method setUp (line 907) | def setUp(self): method test_performance (line 934) | def test_performance(self): method test_checkpointing (line 979) | def test_checkpointing(self): method test_peak_memory_usage (line 1034) | def test_peak_memory_usage(self): method test_lr_scheduler (line 1102) | def test_lr_scheduler(self): method test_zero3_integration (line 1127) | def test_zero3_integration(self): FILE: tests/deepspeed/test_deepspeed_gradient_accumulation.py class DeepSpeedGradientAccumulationTest (line 40) | class DeepSpeedGradientAccumulationTest(AccelerateTestCase): method setUp (line 41) | def setUp(self): method test_gradient_accumulation_boundary_integration (line 71) | def test_gradient_accumulation_boundary_integration(self): method test_clip_grad_norm_returns_deepspeed_grad_norm (line 136) | def test_clip_grad_norm_returns_deepspeed_grad_norm(self): method test_accelerator_backward_passes_sync_gradients (line 185) | def test_accelerator_backward_passes_sync_gradients(self): FILE: tests/deepspeed/test_deepspeed_multiple_model.py class DeepSpeedConfigIntegration (line 45) | class DeepSpeedConfigIntegration(AccelerateTestCase): method setUp (line 49) | def setUp(self): method get_ds_plugins (line 80) | def get_ds_plugins(self, zero3_inference=False): method test_select_plugin (line 89) | def test_select_plugin(self): method test_config_reference_update (line 115) | def test_config_reference_update(self): method test_enable_disable_manually_set (line 132) | def test_enable_disable_manually_set(self): method test_multiple_accelerators (line 143) | def test_multiple_accelerators(self): method test_prepare_multiple_models_zero3_inference (line 152) | def test_prepare_multiple_models_zero3_inference(self): method test_train_multiple_models (line 179) | def test_train_multiple_models(self): FILE: tests/fsdp/test_fsdp.py class FSDPPluginIntegration (line 68) | class FSDPPluginIntegration(AccelerateTestCase): method setUp (line 69) | def setUp(self): method test_sharding_strategy (line 90) | def test_sharding_strategy(self): method test_backward_prefetch (line 139) | def test_backward_prefetch(self): method test_state_dict_type (line 181) | def test_state_dict_type(self): method test_auto_wrap_policy (line 214) | def test_auto_wrap_policy(self): method test_mixed_precision (line 291) | def test_mixed_precision(self): method test_mixed_precision_buffer_autocast_override (line 332) | def test_mixed_precision_buffer_autocast_override(self): method test_cpu_offload (line 361) | def test_cpu_offload(self): method test_cpu_ram_efficient_loading (line 388) | def test_cpu_ram_efficient_loading(self): method test_ignored_modules_regex (line 404) | def test_ignored_modules_regex(self): class FSDP2PluginIntegration (line 424) | class FSDP2PluginIntegration(FSDPPluginIntegration): method setUp (line 425) | def setUp(self): method test_param_mapping_error_handling (line 429) | def test_param_mapping_error_handling(self): class FSDPIntegrationTest (line 480) | class FSDPIntegrationTest(TempDirTestCase): method setUp (line 483) | def setUp(self): method test_performance (line 519) | def test_performance(self): method test_checkpointing (line 569) | def test_checkpointing(self): method test_peak_memory_usage (line 621) | def test_peak_memory_usage(self): class FSDP2IntegrationTest (line 678) | class FSDP2IntegrationTest(FSDPIntegrationTest): method setUp (line 679) | def setUp(self): FILE: tests/test_accelerator.py class ModelWithTiedWeights (line 62) | class ModelWithTiedWeights(torch.nn.Module): method __init__ (line 63) | def __init__(self): method forward (line 70) | def forward(self, x): function create_components (line 74) | def create_components(tied_weights=False): class ModelForTest (line 83) | class ModelForTest(torch.nn.Module): method __init__ (line 84) | def __init__(self): method forward (line 90) | def forward(self, x): function create_dataloaders_for_test (line 94) | def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12,... function get_signature (line 109) | def get_signature(model): function load_random_weights (line 113) | def load_random_weights(model): function parameterized_custom_name_func (line 121) | def parameterized_custom_name_func(func, param_num, param): class AcceleratorTester (line 134) | class AcceleratorTester(AccelerateTestCase): method test_partial_state_after_reset (line 135) | def test_partial_state_after_reset(self): method test_accelerator_state_after_reset (line 156) | def test_accelerator_state_after_reset(self): method test_accelerator_can_be_reinstantiated (line 178) | def test_accelerator_can_be_reinstantiated(self): method test_setting_cpu_affinity (line 186) | def test_setting_cpu_affinity(self): method test_mutable_states (line 193) | def test_mutable_states(self): method test_prepared_objects_are_referenced (line 205) | def test_prepared_objects_are_referenced(self): method test_free_memory_dereferences_prepared_components (line 224) | def test_free_memory_dereferences_prepared_components(self): method test_env_var_device (line 253) | def test_env_var_device(self): method test_save_load_model (line 269) | def test_save_load_model(self, use_safetensors, tied_weights): method test_save_model (line 288) | def test_save_model(self, use_safetensors): method test_save_sharded_model (line 300) | def test_save_sharded_model(self, use_safetensors): method test_save_model_offload (line 316) | def test_save_model_offload(self, use_safetensors): method test_get_state_dict_from_offload (line 337) | def test_get_state_dict_from_offload(self, use_safetensors): method test_save_load_model_with_hooks (line 364) | def test_save_load_model_with_hooks(self, use_safetensors): method test_accelerator_none (line 426) | def test_accelerator_none(self): method test_is_accelerator_prepared (line 438) | def test_is_accelerator_prepared(self): method test_accelerator_bnb (line 470) | def test_accelerator_bnb(self): method test_accelerator_bnb_cpu_error (line 488) | def test_accelerator_bnb_cpu_error(self): method test_accelerator_bnb_multi_device (line 520) | def test_accelerator_bnb_multi_device(self): method test_accelerator_bnb_multi_device_no_distributed (line 557) | def test_accelerator_bnb_multi_device_no_distributed(self): method test_accelerator_cpu_flag_prepare (line 579) | def test_accelerator_cpu_flag_prepare(self): method test_can_unwrap_model_te (line 587) | def test_can_unwrap_model_te(self): method test_can_unwrap_model_fp16 (line 604) | def test_can_unwrap_model_fp16(self): method test_can_unwrap_model (line 621) | def test_can_unwrap_model(self): method test_can_unwrap_distributed_compiled_model_keep_torch_compile (line 635) | def test_can_unwrap_distributed_compiled_model_keep_torch_compile(self): method test_can_unwrap_distributed_compiled_model_remove_torch_compile (line 647) | def test_can_unwrap_distributed_compiled_model_remove_torch_compile(se... method test_can_pickle_dataloader (line 660) | def test_can_pickle_dataloader(self, dispatch_batches): method test_prepared_objects_are_referenced_with_stateful_dataloader (line 707) | def test_prepared_objects_are_referenced_with_stateful_dataloader(self): method test_save_model_with_stateful_dataloader (line 734) | def test_save_model_with_stateful_dataloader(self, use_safetensors, ti... method test_nested_hook (line 814) | def test_nested_hook(self): method test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerror (line 875) | def test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerro... FILE: tests/test_big_modeling.py class ModelForTest (line 65) | class ModelForTest(nn.Module): method __init__ (line 66) | def __init__(self): method forward (line 72) | def forward(self, x): class LinearWithNonPersistentBuffers (line 76) | class LinearWithNonPersistentBuffers(nn.Module): method __init__ (line 77) | def __init__(self, in_features: int, out_features: int, bias: bool = T... method forward (line 88) | def forward(self, input: torch.Tensor) -> torch.Tensor: class ModelForTestNonPersistentBuffers (line 92) | class ModelForTestNonPersistentBuffers(nn.Module): method __init__ (line 93) | def __init__(self): method forward (line 99) | def forward(self, x): class ModelForTestCopy (line 103) | class ModelForTestCopy(nn.Module): method __init__ (line 104) | def __init__(self, id: int): method forward (line 111) | def forward(self, x): class ModelForTestTiedWeights (line 115) | class ModelForTestTiedWeights(nn.Module): method __init__ (line 116) | def __init__(self): method forward (line 122) | def forward(self, x): class BiggerModelForTest (line 126) | class BiggerModelForTest(nn.Module): method __init__ (line 127) | def __init__(self): method forward (line 135) | def forward(self, x): class ModuleWithUnusedSubModules (line 140) | class ModuleWithUnusedSubModules(nn.Module): method __init__ (line 141) | def __init__(self, input_dim, output_dim): method forward (line 145) | def forward(self, x): class ModelWithUnusedSubModulesForTest (line 149) | class ModelWithUnusedSubModulesForTest(nn.Module): method __init__ (line 150) | def __init__(self): method forward (line 158) | def forward(self, x): class BigModelingTester (line 162) | class BigModelingTester(unittest.TestCase): method test_init_empty_weights (line 163) | def test_init_empty_weights(self): method test_init_empty_weights_very_large_model (line 191) | def test_init_empty_weights_very_large_model(self): method test_init_on_device (line 197) | def test_init_on_device(self): method test_cpu_offload (line 204) | def test_cpu_offload(self): method test_cpu_offload_with_unused_submodules (line 222) | def test_cpu_offload_with_unused_submodules(self): method test_cpu_offload_gpt2 (line 247) | def test_cpu_offload_gpt2(self): method test_disk_offload (line 256) | def test_disk_offload(self): method test_disk_offload_with_unused_submodules (line 276) | def test_disk_offload_with_unused_submodules(self): method test_disk_offload_gpt2 (line 306) | def test_disk_offload_gpt2(self): method test_dispatch_model_and_remove_hook (line 317) | def test_dispatch_model_and_remove_hook(self): method test_dispatch_model (line 343) | def test_dispatch_model(self): method test_dispatch_model_with_non_persistent_buffers (line 356) | def test_dispatch_model_with_non_persistent_buffers(self): method test_dispatch_model_tied_weights (line 368) | def test_dispatch_model_tied_weights(self): method test_dispatch_model_tied_weights_memory (line 377) | def test_dispatch_model_tied_weights_memory(self): method test_dispatch_model_tied_weights_memory_with_nested_offload_cpu (line 442) | def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(se... method test_dispatch_model_tied_weights_memory_with_nested_offload_disk (line 543) | def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(s... method test_dispatch_model_multi_devices (line 649) | def test_dispatch_model_multi_devices(self): method test_dispatch_model_copy (line 663) | def test_dispatch_model_copy(self): method test_dispatch_model_move_offloaded_model (line 682) | def test_dispatch_model_move_offloaded_model(self): method test_dispatch_model_move_model_warning (line 692) | def test_dispatch_model_move_model_warning(self): method test_dispatch_model_gpt2_on_two_devices (line 708) | def test_dispatch_model_gpt2_on_two_devices(self): method test_dispatch_model_with_unused_submodules (line 749) | def test_dispatch_model_with_unused_submodules(self): method test_dispatch_model_with_unused_submodules_multi_device (line 765) | def test_dispatch_model_with_unused_submodules_multi_device(self): method test_dispatch_model_force_hooks (line 781) | def test_dispatch_model_force_hooks(self): method test_load_checkpoint_and_dispatch (line 793) | def test_load_checkpoint_and_dispatch(self): method test_load_checkpoint_and_dispatch_device_map_none (line 814) | def test_load_checkpoint_and_dispatch_device_map_none(self): method test_load_checkpoint_and_dispatch_multi_device (line 833) | def test_load_checkpoint_and_dispatch_multi_device(self): method test_load_checkpoint_and_dispatch_with_unused_submodules (line 858) | def test_load_checkpoint_and_dispatch_with_unused_submodules(self): method test_load_checkpoint_and_dispatch_multi_device_with_unused_submodules (line 885) | def test_load_checkpoint_and_dispatch_multi_device_with_unused_submodu... method test_cpu_offload_with_hook (line 912) | def test_cpu_offload_with_hook(self): method test_dispatch_model_bnb (line 946) | def test_dispatch_model_bnb(self): method test_dispatch_model_int8_simple (line 977) | def test_dispatch_model_int8_simple(self): method test_dipatch_model_fp4_simple (line 1040) | def test_dipatch_model_fp4_simple(self): FILE: tests/test_cli.py class AccelerateLauncherTester (line 42) | class AccelerateLauncherTester(unittest.TestCase): method setUpClass (line 61) | def setUpClass(cls): method tearDownClass (line 66) | def tearDownClass(cls): method test_no_config (line 71) | def test_no_config(self): method test_config_compatibility (line 81) | def test_config_compatibility(self): method test_invalid_keys (line 91) | def test_invalid_keys(self): method test_accelerate_test (line 101) | def test_accelerate_test(self): method test_notebook_launcher (line 108) | def test_notebook_launcher(self): method test_mpi_multicpu_config_cmd (line 117) | def test_mpi_multicpu_config_cmd(self): method test_validate_launch_command (line 148) | def test_validate_launch_command(self): class LaunchArgTester (line 171) | class LaunchArgTester(unittest.TestCase): method test_hyphen (line 178) | def test_hyphen(self): method test_underscore (line 197) | def test_underscore(self): method test_duplicate_entities (line 215) | def test_duplicate_entities(self): class ClusterConfigTester (line 228) | class ClusterConfigTester(unittest.TestCase): method test_base_config (line 235) | def test_base_config(self): method test_cluster_config (line 250) | def test_cluster_config(self): method test_sagemaker_config (line 281) | def test_sagemaker_config(self): class TpuConfigTester (line 299) | class TpuConfigTester(unittest.TestCase): method setUp (line 312) | def setUp(self): method test_base (line 315) | def test_base(self): method test_base_backward_compatibility (line 322) | def test_base_backward_compatibility(self): method test_with_config_file (line 339) | def test_with_config_file(self): method test_with_config_file_and_command (line 347) | def test_with_config_file_and_command(self): method test_with_config_file_and_multiple_command (line 354) | def test_with_config_file_and_multiple_command(self): method test_with_config_file_and_command_file (line 372) | def test_with_config_file_and_command_file(self): method test_with_config_file_and_command_file_backward_compatibility (line 382) | def test_with_config_file_and_command_file_backward_compatibility(self): method test_accelerate_install (line 402) | def test_accelerate_install(self): method test_accelerate_install_version (line 412) | def test_accelerate_install_version(self): class ModelEstimatorTester (line 430) | class ModelEstimatorTester(unittest.TestCase): method test_invalid_model_name (line 440) | def test_invalid_model_name(self): method test_invalid_model_name_timm (line 446) | def test_invalid_model_name_timm(self): method test_invalid_model_name_transformers (line 452) | def test_invalid_model_name_transformers(self): method test_no_metadata (line 457) | def test_no_metadata(self): method test_gated (line 464) | def test_gated(self): method test_remote_code (line 474) | def test_remote_code(self): method test_explicit_dtypes (line 485) | def test_explicit_dtypes(self): method test_transformers_model (line 513) | def test_transformers_model(self): method test_no_split_modules (line 526) | def test_no_split_modules(self): method test_timm_model (line 536) | def test_timm_model(self): class ToFSDP2Tester (line 549) | class ToFSDP2Tester(unittest.TestCase): method setUpClass (line 558) | def setUpClass(cls): method tearDownClass (line 563) | def tearDownClass(cls): method tearDown (line 567) | def tearDown(self): method test_nonexistent_config_file (line 571) | def test_nonexistent_config_file(self): method test_no_output_without_overwrite (line 576) | def test_no_output_without_overwrite(self): method test_overwrite_when_output_file_exists (line 582) | def test_overwrite_when_output_file_exists(self, mock_exists): method test_fsdp2_config (line 595) | def test_fsdp2_config(self): method test_config_already_fsdp2 (line 610) | def test_config_already_fsdp2(self): method test_fsdp2_overwrite (line 629) | def test_fsdp2_overwrite(self): FILE: tests/test_compile.py class RegionalCompilationTester (line 40) | class RegionalCompilationTester(unittest.TestCase): method _get_model_and_inputs (line 41) | def _get_model_and_inputs(self): method test_regions_are_compiled (line 51) | def test_regions_are_compiled(self): method test_extract_model_keep_torch_compile (line 65) | def test_extract_model_keep_torch_compile(self): method test_extract_model_remove_torch_compile (line 75) | def test_extract_model_remove_torch_compile(self): method test_regional_compilation_cold_start (line 87) | def test_regional_compilation_cold_start(self): method test_regional_compilation_inference_speedup (line 116) | def test_regional_compilation_inference_speedup(self): FILE: tests/test_cpu.py class MultiCPUTester (line 22) | class MultiCPUTester(unittest.TestCase): method test_cpu (line 23) | def test_cpu(self): method test_ops (line 26) | def test_ops(self): FILE: tests/test_data_loader.py function parameterized_custom_name_func (line 46) | def parameterized_custom_name_func(func, param_num, param): class RandomIterableDataset (line 53) | class RandomIterableDataset(IterableDataset): method __init__ (line 55) | def __init__(self, p_stop=0.01, max_length=1000): method __iter__ (line 59) | def __iter__(self): class SimpleIterableDataset (line 68) | class SimpleIterableDataset(IterableDataset): method __init__ (line 69) | def __init__(self, num_samples=1000): method __iter__ (line 72) | def __iter__(self): method __len__ (line 76) | def __len__(self): method set_epoch (line 79) | def set_epoch(self, epoch): class SimpleBatchSampler (line 83) | class SimpleBatchSampler(BatchSampler): method __init__ (line 84) | def __init__(self, sampler, batch_size, drop_last, generator, seed): method __iter__ (line 90) | def __iter__(self): method set_epoch (line 94) | def set_epoch(self, epoch): class DataLoaderTester (line 98) | class DataLoaderTester(AccelerateTestCase): method check_batch_sampler_shards (line 99) | def check_batch_sampler_shards(self, batch_sampler, expected, split_ba... method test_batch_sampler_shards_with_no_splits (line 109) | def test_batch_sampler_shards_with_no_splits(self): method test_batch_sampler_shards_with_splits (line 178) | def test_batch_sampler_shards_with_splits(self): method test_batch_sampler_shards_with_no_splits_no_even (line 230) | def test_batch_sampler_shards_with_no_splits_no_even(self): method test_batch_sampler_shards_with_splits_no_even (line 299) | def test_batch_sampler_shards_with_splits_no_even(self): method test_batch_sampler_with_varying_batch_size (line 351) | def test_batch_sampler_with_varying_batch_size(self): method check_iterable_dataset_shards (line 361) | def check_iterable_dataset_shards( method test_iterable_dataset_shard (line 401) | def test_iterable_dataset_shard(self): method test_iterable_dataset_using_none_batch_size (line 418) | def test_iterable_dataset_using_none_batch_size(self): method test_iterable_dataset_with_non_tensor_samples (line 425) | def test_iterable_dataset_with_non_tensor_samples(self): method test_reproducibility (line 442) | def test_reproducibility(self, num_processes): method test_skip_batch_sampler (line 471) | def test_skip_batch_sampler(self): method test_dataloader_inheritance (line 476) | def test_dataloader_inheritance(self): method test_skip_data_loader (line 506) | def test_skip_data_loader(self): method test_skip_first_batches (line 510) | def test_skip_first_batches(self): method test_end_of_dataloader (line 515) | def test_end_of_dataloader(self): method test_end_of_dataloader_dispatcher (line 524) | def test_end_of_dataloader_dispatcher(self): method test_set_epoch_in_batch_sampler (line 533) | def test_set_epoch_in_batch_sampler(self): method test_iterable_dataset_native_sharding_when_n_shards_equals_num_processes (line 548) | def test_iterable_dataset_native_sharding_when_n_shards_equals_num_pro... method test_ensure_dataloader_gets_cleaned_up (line 561) | def test_ensure_dataloader_gets_cleaned_up(self): class StatefulDataLoaderTester (line 591) | class StatefulDataLoaderTester(AccelerateTestCase): method test_skip_data_loader (line 593) | def test_skip_data_loader(self): method test_end_of_dataloader (line 599) | def test_end_of_dataloader(self): method test_end_of_dataloader_dispatcher (line 611) | def test_end_of_dataloader_dispatcher(self): method test_dataloader_state_dict (line 623) | def test_dataloader_state_dict(self, num_workers): method test_dataloader_dispatcher_state_dict (line 650) | def test_dataloader_dispatcher_state_dict(self, num_workers): method test_dataloader_inheritance (line 677) | def test_dataloader_inheritance(self): method test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader (line 705) | def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_... method test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader (line 815) | def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata... FILE: tests/test_dataclasses.py function _should_skip_cp_test (line 31) | def _should_skip_cp_test(cp_size): function _should_skip_sp_test (line 36) | def _should_skip_sp_test(sp_size): function _should_skip_tp_test (line 45) | def _should_skip_tp_test(tp_size): class TestParallelismConfig (line 62) | class TestParallelismConfig: method mock_init_device_mesh (line 64) | def mock_init_device_mesh(self): method test_get_mesh (line 107) | def test_get_mesh( method test_build_device_mesh (line 146) | def test_build_device_mesh( method test_from_env (line 199) | def test_from_env( method test_cp_torch_handler (line 225) | def test_cp_torch_handler(self): method test_sp_deepspeed_handler (line 259) | def test_sp_deepspeed_handler(self): method test_tp_handler (line 276) | def test_tp_handler(self): FILE: tests/test_examples.py class ExampleDifferenceTests (line 70) | class ExampleDifferenceTests(unittest.TestCase): method one_complete_example (line 93) | def one_complete_example( method test_nlp_examples (line 134) | def test_nlp_examples(self): method test_cv_examples (line 138) | def test_cv_examples(self): class FeatureExamplesTests (line 158) | class FeatureExamplesTests(TempDirTestCase): method setUpClass (line 162) | def setUpClass(cls): method tearDownClass (line 171) | def tearDownClass(cls): method test_checkpointing_by_epoch (line 175) | def test_checkpointing_by_epoch(self): method test_checkpointing_by_steps (line 184) | def test_checkpointing_by_steps(self): method test_load_states_by_epoch (line 193) | def test_load_states_by_epoch(self): method test_load_states_by_steps (line 202) | def test_load_states_by_steps(self): method test_cross_validation (line 225) | def test_cross_validation(self): method test_multi_process_metrics (line 237) | def test_multi_process_metrics(self): method test_schedulefree (line 242) | def test_schedulefree(self): method test_tracking (line 251) | def test_tracking(self): method test_gradient_accumulation (line 260) | def test_gradient_accumulation(self): method test_gradient_accumulation_for_autoregressive_models (line 264) | def test_gradient_accumulation_for_autoregressive_models(self): method test_local_sgd (line 272) | def test_local_sgd(self): method test_early_stopping (line 276) | def test_early_stopping(self): method test_profiler (line 280) | def test_profiler(self): method test_ddp_comm_hook (line 286) | def test_ddp_comm_hook(self): method test_distributed_inference_examples_stable_diffusion (line 292) | def test_distributed_inference_examples_stable_diffusion(self): method test_distributed_inference_examples_phi2 (line 298) | def test_distributed_inference_examples_phi2(self): method test_pippy_examples_bert (line 305) | def test_pippy_examples_bert(self): method test_pippy_examples_gpt2 (line 312) | def test_pippy_examples_gpt2(self): FILE: tests/test_fp8.py function can_convert_te_model (line 46) | def can_convert_te_model(from_config=False): function maintain_proper_deepspeed_config (line 64) | def maintain_proper_deepspeed_config(expected_version): function can_convert_ao_model (line 70) | def can_convert_ao_model(from_config=False): class TestTransformerEngine (line 91) | class TestTransformerEngine(unittest.TestCase): method test_can_prepare_model_single_gpu (line 92) | def test_can_prepare_model_single_gpu(self): method test_can_prepare_model_single_gpu_from_config (line 97) | def test_can_prepare_model_single_gpu_from_config(self): method test_can_prepare_model_with_mxfp8_block_scaling (line 116) | def test_can_prepare_model_with_mxfp8_block_scaling(self): method test_can_prepare_model_multi_gpu (line 136) | def test_can_prepare_model_multi_gpu(self): method test_can_prepare_model_multigpu_deepspeed (line 143) | def test_can_prepare_model_multigpu_deepspeed(self): method test_can_prepare_model_multigpu_deepspeed_from_config (line 175) | def test_can_prepare_model_multigpu_deepspeed_from_config(self): class TestTorchAO (line 205) | class TestTorchAO(unittest.TestCase): method test_can_prepare_model_single_accelerator (line 206) | def test_can_prepare_model_single_accelerator(self): method test_can_prepare_model_single_gpu_from_config (line 211) | def test_can_prepare_model_single_gpu_from_config(self): method test_can_prepare_model_single_gpu_from_config_with_additional_params (line 229) | def test_can_prepare_model_single_gpu_from_config_with_additional_para... method test_can_prepare_model_multi_accelerator (line 250) | def test_can_prepare_model_multi_accelerator(self): method test_can_prepare_model_multi_accelerator_deepspeed (line 257) | def test_can_prepare_model_multi_accelerator_deepspeed(self): FILE: tests/test_grad_sync.py class SyncScheduler (line 31) | class SyncScheduler(AccelerateTestCase): method test_gradient_sync_cpu_noop (line 35) | def test_gradient_sync_cpu_noop(self): method test_gradient_sync_cpu_multi (line 39) | def test_gradient_sync_cpu_multi(self): method test_gradient_sync_gpu (line 43) | def test_gradient_sync_gpu(self): method test_gradient_sync_gpu_multi (line 48) | def test_gradient_sync_gpu_multi(self): FILE: tests/test_hooks.py class ModelForTest (line 44) | class ModelForTest(nn.Module): method __init__ (line 45) | def __init__(self): method forward (line 51) | def forward(self, x): class PreForwardHook (line 55) | class PreForwardHook(ModelHook): method pre_forward (line 56) | def pre_forward(self, module, *args, **kwargs): class PostForwardHook (line 60) | class PostForwardHook(ModelHook): method post_forward (line 61) | def post_forward(self, module, output): class HooksModelTester (line 65) | class HooksModelTester(unittest.TestCase): method check_dtype_for_layerwise_upcasting (line 66) | def check_dtype_for_layerwise_upcasting( method test_add_and_remove_hooks (line 94) | def test_add_and_remove_hooks(self): method test_append_and_remove_hooks (line 110) | def test_append_and_remove_hooks(self): method test_pre_forward_hook_is_executed (line 129) | def test_pre_forward_hook_is_executed(self): method test_post_forward_hook_is_executed (line 153) | def test_post_forward_hook_is_executed(self): method test_no_grad_in_hook (line 176) | def test_no_grad_in_hook(self): method test_align_devices_as_model_parallelism (line 193) | def test_align_devices_as_model_parallelism(self): method test_align_devices_as_cpu_offload (line 221) | def test_align_devices_as_cpu_offload(self): method test_attach_align_device_hook_as_cpu_offload (line 285) | def test_attach_align_device_hook_as_cpu_offload(self): method test_attach_align_device_hook_as_cpu_offload_with_weight_map (line 334) | def test_attach_align_device_hook_as_cpu_offload_with_weight_map(self): method test_add_remove_hook_fx_graph_module (line 391) | def test_add_remove_hook_fx_graph_module(self): method test_layerwise_upcasting_inference (line 446) | def test_layerwise_upcasting_inference(self, storage_dtype, compute_dt... method test_cpu_offload_hook_moves_model (line 464) | def test_cpu_offload_hook_moves_model(self): method test_cpu_offload_hook_with_prev_module (line 486) | def test_cpu_offload_hook_with_prev_module(self): FILE: tests/test_imports.py function convert_list_to_string (line 27) | def convert_list_to_string(data): function run_import_time (line 35) | def run_import_time(command: str): class ImportSpeedTester (line 41) | class ImportSpeedTester(TempDirTestCase): method setUpClass (line 56) | def setUpClass(cls): method test_base_import (line 63) | def test_base_import(self): method test_cli_import (line 75) | def test_cli_import(self): class LazyImportTester (line 89) | class LazyImportTester(TempDirTestCase): method test_te_import (line 97) | def test_te_import(self): FILE: tests/test_kwargs_handlers.py class MockClass (line 44) | class MockClass(KwargsHandler): class KwargsHandlerTester (line 50) | class KwargsHandlerTester(AccelerateTestCase): method test_kwargs_handler (line 51) | def test_kwargs_handler(self): method test_grad_scaler_kwargs (line 60) | def test_grad_scaler_kwargs(self): method test_ddp_kwargs (line 79) | def test_ddp_kwargs(self): method test_autocast_kwargs (line 85) | def test_autocast_kwargs(self): method test_profile_kwargs (line 109) | def test_profile_kwargs(self): method test_torch_dynamo_plugin (line 154) | def test_torch_dynamo_plugin(self): method test_ddp_comm_hook (line 168) | def test_ddp_comm_hook(self): function main (line 173) | def main(): FILE: tests/test_launch.py class TestPrepareMultiGpuEnv (line 21) | class TestPrepareMultiGpuEnv(unittest.TestCase): method test_auto_port_selection (line 22) | def test_auto_port_selection(self): FILE: tests/test_load_checkpoint_and_dispatch_with_broadcast.py function manage_process_group (line 47) | def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]: function load_checkpoint_and_dispatch_fsdp2 (line 69) | def load_checkpoint_and_dispatch_fsdp2(): function load_checkpoint_and_dispatch_no_broadcast_from_rank0 (line 121) | def load_checkpoint_and_dispatch_no_broadcast_from_rank0(): function load_checkpoint_and_dispatch_ddp (line 161) | def load_checkpoint_and_dispatch_ddp(): class TestLoadCheckpointAndDispatchWithBroadcast (line 196) | class TestLoadCheckpointAndDispatchWithBroadcast(unittest.TestCase): method setUp (line 197) | def setUp(self): method test_load_checkpoint_and_dispatch_fsdp2 (line 200) | def test_load_checkpoint_and_dispatch_fsdp2(self): method test_load_checkpoint_and_dispatch_no_broadcast_from_rank0 (line 212) | def test_load_checkpoint_and_dispatch_no_broadcast_from_rank0(self): method test_load_checkpoint_and_dispatch_ddp (line 224) | def test_load_checkpoint_and_dispatch_ddp(self): class CLIArgs (line 242) | class CLIArgs(argparse.Namespace): FILE: tests/test_logging.py function current_lineno (line 25) | def current_lineno() -> int: class CustomLogger (line 32) | class CustomLogger(logging.LoggerAdapter): method log (line 34) | def log(self, level, msg, *args, **kwargs): function accelerator (line 44) | def accelerator(): function test_log_stack (line 51) | def test_log_stack(caplog): function test_custom_stacklevel (line 74) | def test_custom_stacklevel(caplog): FILE: tests/test_memory_utils.py function raise_fake_out_of_memory (line 28) | def raise_fake_out_of_memory(): class ModelForTest (line 32) | class ModelForTest(nn.Module): method __init__ (line 33) | def __init__(self): method forward (line 39) | def forward(self, x): class BigModelForTest (line 43) | class BigModelForTest(ModelForTest): method __init__ (line 44) | def __init__(self): method forward (line 48) | def forward(self, x): class MemoryTest (line 52) | class MemoryTest(unittest.TestCase): method test_memory_implicit (line 53) | def test_memory_implicit(self): method test_memory_explicit (line 90) | def test_memory_explicit(self): method test_start_zero (line 129) | def test_start_zero(self): method test_approach_zero (line 138) | def test_approach_zero(self): method test_verbose_guard (line 149) | def test_verbose_guard(self): method test_any_other_error (line 160) | def test_any_other_error(self): method test_release_memory (line 171) | def test_release_memory(self): FILE: tests/test_metrics.py class MetricTester (line 37) | class MetricTester(unittest.TestCase): method setUp (line 38) | def setUp(self): method test_metric_cpu_noop (line 46) | def test_metric_cpu_noop(self): method test_metric_cpu_multi (line 50) | def test_metric_cpu_multi(self): method test_metric_accelerator (line 54) | def test_metric_accelerator(self): method test_metric_accelerator_multi (line 59) | def test_metric_accelerator_multi(self): FILE: tests/test_modeling_utils.py class ModelForTest (line 61) | class ModelForTest(nn.Module): method __init__ (line 62) | def __init__(self): method forward (line 68) | def forward(self, x): class NestedModelForTest (line 72) | class NestedModelForTest(nn.Module): method __init__ (line 73) | def __init__(self): method forward (line 77) | def forward(self, x): class LinearWithNonPersistentBuffers (line 81) | class LinearWithNonPersistentBuffers(nn.Module): method __init__ (line 82) | def __init__(self, in_features: int, out_features: int, bias: bool = T... method forward (line 93) | def forward(self, input: torch.Tensor) -> torch.Tensor: class ModelSeveralDtypes (line 97) | class ModelSeveralDtypes(nn.Module): method __init__ (line 98) | def __init__(self): method forward (line 103) | def forward(self, x): function sequential_model (line 107) | def sequential_model(num_layers): class ModelingUtilsTester (line 112) | class ModelingUtilsTester(unittest.TestCase): method check_set_module_tensor_for_device (line 113) | def check_set_module_tensor_for_device(self, model, device1, device2): method test_set_module_tensor_to_meta_and_cpu (line 172) | def test_set_module_tensor_to_meta_and_cpu(self): method test_set_module_tensor_to_cpu_and_gpu (line 177) | def test_set_module_tensor_to_cpu_and_gpu(self): method test_set_module_tensor_to_meta_and_gpu (line 182) | def test_set_module_tensor_to_meta_and_gpu(self): method test_set_module_tensor_between_gpus (line 188) | def test_set_module_tensor_between_gpus(self): method test_set_module_tensor_sets_dtype (line 192) | def test_set_module_tensor_sets_dtype(self): method test_set_module_tensor_checks_shape (line 197) | def test_set_module_tensor_checks_shape(self): method test_named_tensors (line 207) | def test_named_tensors(self): method test_find_tied_parameters (line 256) | def test_find_tied_parameters(self): method test_retie_parameters (line 285) | def test_retie_parameters(self): method test_compute_module_sizes (line 310) | def test_compute_module_sizes(self): method test_compute_module_total_buffer_size (line 333) | def test_compute_module_total_buffer_size(self): method test_check_device_map (line 345) | def test_check_device_map(self): method test_check_device_map_invalid_keys (line 353) | def test_check_device_map_invalid_keys(self): method shard_test_model (line 373) | def shard_test_model(self, model, tmp_dir): method test_load_checkpoint_in_model (line 392) | def test_load_checkpoint_in_model(self): method test_load_checkpoint_in_model_one_gpu (line 414) | def test_load_checkpoint_in_model_one_gpu(self): method test_load_checkpoint_in_model_disk_offload (line 449) | def test_load_checkpoint_in_model_disk_offload(self): method test_load_checkpoint_in_model_two_gpu (line 475) | def test_load_checkpoint_in_model_two_gpu(self): method test_load_checkpoint_in_model_dtype (line 509) | def test_load_checkpoint_in_model_dtype(self): method test_load_checkpoint_in_model_unexpected_keys (line 523) | def test_load_checkpoint_in_model_unexpected_keys(self, device_map: Op... method test_clean_device_map (line 541) | def test_clean_device_map(self): method test_infer_auto_device_map (line 554) | def test_infer_auto_device_map(self): method test_infer_auto_device_map_with_tied_weights (line 590) | def test_infer_auto_device_map_with_tied_weights(self): method test_infer_auto_device_map_on_t0pp (line 675) | def test_infer_auto_device_map_on_t0pp(self): method test_infer_auto_device_map_with_buffer_check (line 698) | def test_infer_auto_device_map_with_buffer_check(self): method test_infer_auto_device_map_with_buffer_check_and_multi_devices (line 721) | def test_infer_auto_device_map_with_buffer_check_and_multi_devices(self): method test_infer_auto_device_map_with_fallback_allocation (line 754) | def test_infer_auto_device_map_with_fallback_allocation(self): method test_infer_auto_device_map_with_fallback_allocation_no_fit (line 788) | def test_infer_auto_device_map_with_fallback_allocation_no_fit(self): method test_infer_auto_device_map_with_fallback_allocation_partial_fit (line 813) | def test_infer_auto_device_map_with_fallback_allocation_partial_fit(se... method test_infer_auto_device_map_with_fallback_allocation_tied_weights (line 833) | def test_infer_auto_device_map_with_fallback_allocation_tied_weights(s... method test_infer_auto_device_map_with_fallback_allocation_and_buffers (line 852) | def test_infer_auto_device_map_with_fallback_allocation_and_buffers(se... method test_get_balanced_memory (line 880) | def test_get_balanced_memory(self): method test_get_module_size_with_ties (line 912) | def test_get_module_size_with_ties(self): method test_load_state_dict (line 954) | def test_load_state_dict(self): method test_convert_file_size (line 969) | def test_convert_file_size(self): method test_get_state_dict_offloaded_model (line 1000) | def test_get_state_dict_offloaded_model(self): method test_align_module_device_simple (line 1013) | def test_align_module_device_simple(self): method test_align_module_device_offloaded (line 1036) | def test_align_module_device_offloaded(self): method test_align_module_device_offloaded_nested (line 1060) | def test_align_module_device_offloaded_nested(self): method test_extract_model_from_parallel_partial_compile (line 1070) | def test_extract_model_from_parallel_partial_compile(self): FILE: tests/test_multidevice.py class MultiDeviceTester (line 41) | class MultiDeviceTester(unittest.TestCase): method test_multi_device (line 50) | def test_multi_device(self): method test_multi_device_ops (line 58) | def test_multi_device_ops(self): method test_pad_across_processes (line 66) | def test_pad_across_processes(self): method test_multi_device_merge_fsdp_weights (line 74) | def test_multi_device_merge_fsdp_weights(self): method test_distributed_data_loop (line 85) | def test_distributed_data_loop(self): method test_pippy (line 114) | def test_pippy(self): class ModelForTest (line 156) | class ModelForTest(torch.nn.Module): method __init__ (line 157) | def __init__(self): method forward (line 163) | def forward(self, x): FILE: tests/test_offload.py class ModelForTest (line 31) | class ModelForTest(nn.Module): method __init__ (line 32) | def __init__(self): method forward (line 38) | def forward(self, x): class OffloadTester (line 42) | class OffloadTester(unittest.TestCase): method test_offload_state_dict (line 43) | def test_offload_state_dict(self): method test_offload_weight (line 56) | def test_offload_weight(self): method test_offload_weights_loader (line 70) | def test_offload_weights_loader(self): method test_extract_submodules_state_dict (line 107) | def test_extract_submodules_state_dict(self): FILE: tests/test_optimizer.py class CPUOptimizerTester (line 25) | class CPUOptimizerTester(AccelerateTestCase): method test_accelerated_optimizer_pickling (line 26) | def test_accelerated_optimizer_pickling(self): class OptimizerTester (line 39) | class OptimizerTester(AccelerateTestCase): method test_accelerated_optimizer_step_was_skipped (line 40) | def test_accelerated_optimizer_step_was_skipped(self): FILE: tests/test_quantization.py class BitsAndBytesConfigIntegration (line 36) | class BitsAndBytesConfigIntegration(unittest.TestCase): method test_BnbQuantizationConfig (line 37) | def test_BnbQuantizationConfig(self): class MixedInt8EmptyModelTest (line 47) | class MixedInt8EmptyModelTest(AccelerateTestCase): method setUp (line 62) | def setUp(self): method tearDown (line 93) | def tearDown(self): method test_memory_footprint (line 103) | def test_memory_footprint(self): method test_linear_are_8bit (line 116) | def test_linear_are_8bit(self): method test_llm_skip (line 133) | def test_llm_skip(self): method check_inference_correctness (line 161) | def check_inference_correctness(self, model): method test_generate_quality (line 177) | def test_generate_quality(self): method test_fp32_8bit_conversion (line 180) | def test_fp32_8bit_conversion(self): method test_cpu_gpu_loading_custom_device_map (line 202) | def test_cpu_gpu_loading_custom_device_map(self): method test_cpu_gpu_loading_custom_device_map_offload_state_dict (line 257) | def test_cpu_gpu_loading_custom_device_map_offload_state_dict(self): method test_cpu_gpu_disk_loading_custom_device_map_kwargs (line 314) | def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self): method test_int8_serialization (line 372) | def test_int8_serialization(self): method test_int8_serialization_offload (line 405) | def test_int8_serialization_offload(self): method test_int8_serialization_shard (line 465) | def test_int8_serialization_shard(self): class MixedInt8LoaddedModelTest (line 504) | class MixedInt8LoaddedModelTest(unittest.TestCase): method setUp (line 519) | def setUp(self): method tearDown (line 537) | def tearDown(self): method test_memory_footprint (line 547) | def test_memory_footprint(self): method test_linear_are_8bit (line 560) | def test_linear_are_8bit(self): method test_generate_quality (line 577) | def test_generate_quality(self): method test_fp32_8bit_conversion (line 591) | def test_fp32_8bit_conversion(self): class Bnb4BitEmptyModelTest (line 609) | class Bnb4BitEmptyModelTest(unittest.TestCase): method setUp (line 626) | def setUp(self): method tearDown (line 655) | def tearDown(self): method test_memory_footprint (line 666) | def test_memory_footprint(self): method check_inference_correctness (line 679) | def check_inference_correctness(self, model): method test_generate_quality (line 693) | def test_generate_quality(self): method test_linear_are_4bit (line 696) | def test_linear_are_4bit(self): method test_fp32_4bit_conversion (line 715) | def test_fp32_4bit_conversion(self): method test_cpu_gpu_loading_random_device_map (line 737) | def test_cpu_gpu_loading_random_device_map(self): method test_cpu_gpu_loading_custom_device_map (line 790) | def test_cpu_gpu_loading_custom_device_map(self): method test_cpu_gpu_disk_loading_custom_device_map_kwargs (line 820) | def test_cpu_gpu_disk_loading_custom_device_map_kwargs(self): class Bnb4BitTestLoadedModel (line 858) | class Bnb4BitTestLoadedModel(unittest.TestCase): method setUp (line 875) | def setUp(self): method tearDown (line 895) | def tearDown(self): method test_memory_footprint (line 906) | def test_memory_footprint(self): method test_linear_are_4bit (line 919) | def test_linear_are_4bit(self): method test_generate_quality (line 938) | def test_generate_quality(self): method test_fp32_4bit_conversion (line 952) | def test_fp32_4bit_conversion(self): FILE: tests/test_sagemaker.py class MockLaunchConfig (line 25) | class MockLaunchConfig(SageMakerConfig): class SageMakerLaunch (line 65) | class SageMakerLaunch(unittest.TestCase): method test_args_convert (line 66) | def test_args_convert(self): FILE: tests/test_scheduler.py function one_cycle_test (line 26) | def one_cycle_test(num_processes=2, step_scheduler_with_optimizer=True, ... function lambda_test (line 45) | def lambda_test(num_processes=2, step_scheduler_with_optimizer=True, spl... function accumulation_test (line 70) | def accumulation_test(num_processes: int = 2): class SchedulerTester (line 105) | class SchedulerTester(unittest.TestCase): method test_lambda_scheduler_steps_with_optimizer_single_process (line 106) | def test_lambda_scheduler_steps_with_optimizer_single_process(self): method test_one_cycle_scheduler_steps_with_optimizer_single_process (line 110) | def test_one_cycle_scheduler_steps_with_optimizer_single_process(self): method test_lambda_scheduler_not_step_with_optimizer_single_process (line 114) | def test_lambda_scheduler_not_step_with_optimizer_single_process(self): method test_one_cycle_scheduler_not_step_with_optimizer_single_process (line 117) | def test_one_cycle_scheduler_not_step_with_optimizer_single_process(se... method test_lambda_scheduler_steps_with_optimizer_multiprocess (line 120) | def test_lambda_scheduler_steps_with_optimizer_multiprocess(self): method test_one_cycle_scheduler_steps_with_optimizer_multiprocess (line 125) | def test_one_cycle_scheduler_steps_with_optimizer_multiprocess(self): method test_lambda_scheduler_not_step_with_optimizer_multiprocess (line 130) | def test_lambda_scheduler_not_step_with_optimizer_multiprocess(self): method test_one_cycle_scheduler_not_step_with_optimizer_multiprocess (line 134) | def test_one_cycle_scheduler_not_step_with_optimizer_multiprocess(self): method test_accumulation (line 139) | def test_accumulation(self): FILE: tests/test_state_checkpointing.py function dummy_dataloaders (line 45) | def dummy_dataloaders(a=2, b=3, batch_size=16, n_train_batches: int = 10... function train (line 59) | def train(num_epochs, model, dataloader, optimizer, accelerator, schedul... class DummyModel (line 78) | class DummyModel(nn.Module): method __init__ (line 81) | def __init__(self): method forward (line 86) | def forward(self, x): function parameterized_custom_name_func (line 90) | def parameterized_custom_name_func(func, param_num, param): class CheckpointTest (line 98) | class CheckpointTest(AccelerateTestCase): method check_adam_state (line 99) | def check_adam_state(self, state1, state2, distributed_type): method test_with_save_limit (line 108) | def test_with_save_limit(self): method test_can_resume_training_with_folder (line 127) | def test_can_resume_training_with_folder(self): method test_can_resume_training (line 180) | def test_can_resume_training(self): method test_can_resume_training_checkpoints_relative_path (line 232) | def test_can_resume_training_checkpoints_relative_path(self): method test_invalid_registration (line 298) | def test_invalid_registration(self): method test_with_scheduler (line 312) | def test_with_scheduler(self): method test_automatic_loading (line 335) | def test_automatic_loading(self): method test_checkpoint_deletion (line 364) | def test_checkpoint_deletion(self): method test_map_location (line 382) | def test_map_location(self): FILE: tests/test_tpu.py class MultiTPUTester (line 22) | class MultiTPUTester(unittest.TestCase): method test_tpu (line 27) | def test_tpu(self): FILE: tests/test_tracking.py class TensorBoardTrackingTest (line 89) | class TensorBoardTrackingTest(unittest.TestCase): method test_init_trackers (line 91) | def test_init_trackers(self): method test_log (line 102) | def test_log(self): method test_log_with_tensor (line 115) | def test_log_with_tensor(self): method test_project_dir (line 144) | def test_project_dir(self): method test_project_dir_with_config (line 150) | def test_project_dir_with_config(self): class WandBTrackingTest (line 158) | class WandBTrackingTest(TempDirTestCase, MockingTestCase): method setUp (line 159) | def setUp(self): method parse_log (line 165) | def parse_log(log: str, section: str, record: bool = True): method test_wandb (line 186) | def test_wandb(self): class MLflowTrackingTest (line 223) | class MLflowTrackingTest(unittest.TestCase): method setUp (line 224) | def setUp(self): method create_mock_figure (line 231) | def create_mock_figure(self): method test_log (line 238) | def test_log(self): method test_log_figure (line 259) | def test_log_figure(self): method test_log_artifact (line 277) | def test_log_artifact(self): method test_log_artifacts (line 300) | def test_log_artifacts(self): class CometMLTest (line 327) | class CometMLTest(unittest.TestCase): method get_value_from_key (line 329) | def get_value_from_key(log_list, key: str, is_param: bool = False): method test_init_trackers (line 345) | def test_init_trackers(self): method test_log (line 365) | def test_log(self): class ClearMLTest (line 388) | class ClearMLTest(TempDirTestCase, MockingTestCase): method setUp (line 389) | def setUp(self): method _get_offline_dir (line 395) | def _get_offline_dir(accelerator): method _get_metrics (line 401) | def _get_metrics(offline_dir): method test_init_trackers (line 409) | def test_init_trackers(self): method test_log (line 426) | def test_log(self): method test_log_images (line 456) | def test_log_images(self): method test_log_table (line 477) | def test_log_table(self): method test_log_table_pandas (line 505) | def test_log_table_pandas(self): class SwanLabTrackingTest (line 530) | class SwanLabTrackingTest(TempDirTestCase, MockingTestCase): method setUp (line 531) | def setUp(self): method test_swanlab (line 537) | def test_swanlab(self): class MyCustomTracker (line 645) | class MyCustomTracker(GeneralTracker): method __init__ (line 661) | def __init__(self, dir: str, **kwargs): method start (line 667) | def start(self): method tracker (line 674) | def tracker(self): method store_init_configuration (line 677) | def store_init_configuration(self, values: dict): method log (line 681) | def log(self, values: dict, step: Optional[int]): method finish (line 685) | def finish(self): class CustomTrackerTestCase (line 689) | class CustomTrackerTestCase(unittest.TestCase): method test_init_trackers (line 690) | def test_init_trackers(self): method test_log (line 711) | def test_log(self): class DVCLiveTrackingTest (line 736) | class DVCLiveTrackingTest(unittest.TestCase): method test_init_trackers (line 737) | def test_init_trackers(self, mock_repo): method test_log (line 754) | def test_log(self, mock_repo): class TrackerDeferredInitializationTest (line 779) | class TrackerDeferredInitializationTest(unittest.TestCase): method test_tensorboard_deferred_init (line 788) | def test_tensorboard_deferred_init(self): method test_wandb_deferred_init (line 798) | def test_wandb_deferred_init(self): method test_trackio_deferred_init (line 807) | def test_trackio_deferred_init(self): method test_comet_ml_deferred_init (line 816) | def test_comet_ml_deferred_init(self): method test_aim_deferred_init (line 825) | def test_aim_deferred_init(self): method test_mlflow_deferred_init (line 835) | def test_mlflow_deferred_init(self): method test_clearml_deferred_init (line 845) | def test_clearml_deferred_init(self): method test_dvclive_deferred_init (line 854) | def test_dvclive_deferred_init(self): method test_swanlab_deferred_init (line 864) | def test_swanlab_deferred_init(self): FILE: tests/test_utils.py class UtilsTester (line 73) | class UtilsTester(unittest.TestCase): method setUp (line 74) | def setUp(self): method test_send_to_device (line 78) | def test_send_to_device(self): method test_honor_type (line 117) | def test_honor_type(self): method test_listify (line 125) | def test_listify(self): method test_patch_environment (line 135) | def test_patch_environment(self): method test_patch_environment_key_exists (line 143) | def test_patch_environment_key_exists(self): method test_patch_environment_restores_on_error (line 162) | def test_patch_environment_restores_on_error(self): method test_clear_environment (line 172) | def test_clear_environment(self): method test_can_undo_convert_outputs (line 181) | def test_can_undo_convert_outputs(self): method test_can_undo_fp16_conversion (line 189) | def test_can_undo_fp16_conversion(self): method test_dynamo (line 199) | def test_dynamo(self): method test_extract_model (line 208) | def test_extract_model(self): method test_extract_model_recursive_fsdpv2 (line 218) | def test_extract_model_recursive_fsdpv2(self): method test_dynamo_extract_model_keep_torch_compile (line 242) | def test_dynamo_extract_model_keep_torch_compile(self): method test_dynamo_extract_model_remove_torch_compile (line 253) | def test_dynamo_extract_model_remove_torch_compile(self): method test_find_device (line 264) | def test_find_device(self): method test_check_os_kernel_no_warning_when_release_gt_min (line 269) | def test_check_os_kernel_no_warning_when_release_gt_min(self): method test_check_os_kernel_no_warning_when_not_linux (line 276) | def test_check_os_kernel_no_warning_when_not_linux(self): method test_check_os_kernel_warning_when_release_lt_min (line 283) | def test_check_os_kernel_warning_when_release_lt_min(self): method test_save_safetensor_shared_memory (line 294) | def test_save_safetensor_shared_memory(self): method test_pad_across_processes (line 313) | def test_pad_across_processes(self): method test_slice_and_concatenate (line 330) | def test_slice_and_concatenate(self): method test_send_to_device_compiles (line 395) | def test_send_to_device_compiles(self): method test_convert_to_fp32 (line 399) | def test_convert_to_fp32(self): method test_named_tuples (line 403) | def test_named_tuples(self): method test_convert_dict_to_env_variables (line 425) | def test_convert_dict_to_env_variables(self): method test_has_offloaded_params (line 431) | def test_has_offloaded_params(self): method test_concatenate (line 446) | def test_concatenate(self): function set_dummy_accelerate_env_var (line 535) | def set_dummy_accelerate_env_var(): class MyUnittest (line 549) | class MyUnittest(unittest.TestCase): method test_purge_env_vars_unittest_1 (line 550) | def test_purge_env_vars_unittest_1(self): method test_purge_env_vars_unittest_2 (line 555) | def test_purge_env_vars_unittest_2(self): class MyUnittestWithDecorators (line 562) | class MyUnittestWithDecorators(unittest.TestCase): method test_purge_env_vars_unittest_with_wrapper_1 (line 563) | def test_purge_env_vars_unittest_with_wrapper_1(self): method test_purge_env_vars_unittest_with_wrapper_2 (line 568) | def test_purge_env_vars_unittest_with_wrapper_2(self): method test_purge_env_vars_unittest_with_wrapper_3 (line 572) | def test_purge_env_vars_unittest_with_wrapper_3(self): method test_purge_env_vars_unittest_with_wrapper_4 (line 576) | def test_purge_env_vars_unittest_with_wrapper_4(self): class _BaseCls (line 582) | class _BaseCls(unittest.TestCase): method test_purge_env_vars_unittest_with_inheritance_3 (line 583) | def test_purge_env_vars_unittest_with_inheritance_3(self): class MyUnittestWithInheritance (line 587) | class MyUnittestWithInheritance(_BaseCls): method test_purge_env_vars_unittest_with_inheritance_1 (line 588) | def test_purge_env_vars_unittest_with_inheritance_1(self): method test_purge_env_vars_unittest_with_inheritance_2 (line 593) | def test_purge_env_vars_unittest_with_inheritance_2(self): class TestMyPytest (line 598) | class TestMyPytest: method test_purge_env_vars_pytest_1 (line 599) | def test_purge_env_vars_pytest_1(self): method test_purge_env_vars_pytest_2 (line 604) | def test_purge_env_vars_pytest_2(self): function dummy_fixture (line 609) | def dummy_fixture(): class TestPytestWithWrapper (line 618) | class TestPytestWithWrapper: method test_purge_env_vars_pytest_with_wrapper_1 (line 619) | def test_purge_env_vars_pytest_with_wrapper_1(self): method test_purge_env_vars_pytest_with_wrapper_2 (line 624) | def test_purge_env_vars_pytest_with_wrapper_2(self): method test_purge_env_vars_pytest_with_wrapper_3 (line 629) | def test_purge_env_vars_pytest_with_wrapper_3(self): method test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped (line 633) | def test_purge_env_vars_pytest_with_wrapper_4_should_be_skipped(self): class _PytestBaseCls (line 639) | class _PytestBaseCls: method test_purge_env_vars_pytest_with_inheritance_3 (line 640) | def test_purge_env_vars_pytest_with_inheritance_3(self): class TestPytestWithInheritance (line 644) | class TestPytestWithInheritance(_PytestBaseCls): method test_purge_env_vars_pytest_with_inheritance_1 (line 645) | def test_purge_env_vars_pytest_with_inheritance_1(self): method test_purge_env_vars_pytest_with_inheritance_2 (line 650) | def test_purge_env_vars_pytest_with_inheritance_2(self): function test_purge_env_vars_standalone_1 (line 655) | def test_purge_env_vars_standalone_1(): function test_purge_env_vars_standalone_2 (line 661) | def test_purge_env_vars_standalone_2(): function test_purge_env_vars_restores_previous_values (line 665) | def test_purge_env_vars_restores_previous_values(): FILE: tests/tp/fsdp2_tp_preparation.py class LmHeadWrapper (line 27) | class LmHeadWrapper(torch.nn.Module): method __init__ (line 28) | def __init__(self, lm_head): method forward (line 32) | def forward(self, x): function build_simple_dataloader (line 36) | def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2): function main (line 60) | def main(): FILE: tests/tp/test_tp.py class TPIntegrationTest (line 39) | class TPIntegrationTest(TempDirTestCase): method setUp (line 42) | def setUp(self): method test_working_of_tp (line 51) | def test_working_of_tp(self): method test_working_of_tp_and_fsdp (line 67) | def test_working_of_tp_and_fsdp(self): FILE: tests/xla_spawn.py function parse_args (line 36) | def parse_args(): function main (line 74) | def main(): FILE: utils/stale.py function main (line 33) | def main():