SYMBOL INDEX (574 symbols across 44 files) FILE: example/mezo_runner/metrics.py function normalize_answer (line 10) | def normalize_answer(s): function calculate_metric (line 29) | def calculate_metric(predictions, metric_name): function f1 (line 62) | def f1(pred, gold): FILE: example/mezo_runner/run.py class OurArguments (line 40) | class OurArguments(TrainingArguments): function parse_args (line 125) | def parse_args(): function set_seed (line 133) | def set_seed(seed: int): class Framework (line 140) | class Framework: method __init__ (line 142) | def __init__(self, args, task): method load_model (line 148) | def load_model(self): method forward (line 272) | def forward(self, input_ids, option_len=None, generation=False): method one_step_pred (line 305) | def one_step_pred(self, train_samples, eval_sample, verbose=False): method evaluate (line 382) | def evaluate(self, train_samples, eval_samples, one_train_set_per_eval... method train (line 404) | def train(self, train_samples, eval_samples): function result_file_tag (line 524) | def result_file_tag(args): function main (line 538) | def main(): FILE: example/mezo_runner/tasks.py function get_task (line 26) | def get_task(task_name): class Sample (line 39) | class Sample: class Dataset (line 46) | class Dataset: method __init__ (line 51) | def __init__(self, subtask=None, **kwargs) -> None: method get_task_name (line 54) | def get_task_name(self): method load_dataset (line 57) | def load_dataset(): method get_template (line 60) | def get_template(self, template_version=0): method build_sample (line 64) | def build_sample(self, example): method sample_train_sets (line 67) | def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None,... method sample_subset (line 98) | def sample_subset(self, data_split="train", seed=0, num=100, exclude=N... method valid_samples (line 110) | def valid_samples(self): class SST2Dataset (line 114) | class SST2Dataset(Dataset): method __init__ (line 116) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 119) | def load_dataset(self, path, **kwargs): method build_sample (line 130) | def build_sample(self, example): method get_template (line 134) | def get_template(self, template_version=0): class CopaDataset (line 138) | class CopaDataset(Dataset): method __init__ (line 142) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 145) | def load_dataset(self, path, **kwargs): method build_sample (line 154) | def build_sample(self, example): method get_template (line 165) | def get_template(self, template_version=0): class BoolQDataset (line 169) | class BoolQDataset(Dataset): method __init__ (line 170) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 173) | def load_dataset(self, path, **kwargs): method build_sample (line 182) | def build_sample(self, example): method get_template (line 192) | def get_template(self, template_version=2): class MultiRCDataset (line 196) | class MultiRCDataset(Dataset): method __init__ (line 198) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 201) | def load_dataset(self, path, **kwargs): method build_sample (line 210) | def build_sample(self, example): method get_template (line 220) | def get_template(self, template_version=0): class CBDataset (line 224) | class CBDataset(Dataset): method __init__ (line 226) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 229) | def load_dataset(self, path, **kwargs): method build_sample (line 238) | def build_sample(self, example): method get_template (line 248) | def get_template(self, template_version=0): class WICDataset (line 252) | class WICDataset(Dataset): method __init__ (line 254) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 257) | def load_dataset(self, path, **kwargs): method build_sample (line 266) | def build_sample(self, example): method get_template (line 276) | def get_template(self, template_version=0): class WSCDataset (line 280) | class WSCDataset(Dataset): method __init__ (line 282) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 285) | def load_dataset(self, path, **kwargs): method build_sample (line 294) | def build_sample(self, example): method get_template (line 304) | def get_template(self, template_version=0): class ReCoRDDataset (line 308) | class ReCoRDDataset(Dataset): method __init__ (line 310) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 313) | def load_dataset(self, path, **kwargs): method build_sample (line 322) | def build_sample(self, example): method get_template (line 332) | def get_template(self, template_version=0): class RTEDataset (line 336) | class RTEDataset(Dataset): method __init__ (line 338) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 341) | def load_dataset(self, path, **kwargs): method build_sample (line 350) | def build_sample(self, example): method get_template (line 360) | def get_template(self, template_version=0): class SQuADDataset (line 364) | class SQuADDataset(Dataset): method __init__ (line 368) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 371) | def load_dataset(self): method build_sample (line 381) | def build_sample(self, example, idx): method get_template (line 396) | def get_template(self, template_version=0): class DROPDataset (line 400) | class DROPDataset(Dataset): method __init__ (line 404) | def __init__(self, subtask=None, **kwargs) -> None: method load_dataset (line 407) | def load_dataset(self): method build_sample (line 417) | def build_sample(self, example, idx): method get_template (line 431) | def get_template(self, template_version=0): FILE: example/mezo_runner/templates.py class Template (line 8) | class Template: method encode (line 9) | def encode(self, sample): method verbalize (line 15) | def verbalize(self, sample, candidate): method encode_sfc (line 21) | def encode_sfc(self, sample): method verbalize_sfc (line 27) | def verbalize_sfc(self, sample, candidate): class SST2Template (line 34) | class SST2Template(Template): method encode (line 36) | def encode(self, sample): method verbalize (line 40) | def verbalize(self, sample, candidate): method encode_sfc (line 44) | def encode_sfc(self, sample): method verbalize_sfc (line 47) | def verbalize_sfc(self, sample, candidate): class CopaTemplate (line 51) | class CopaTemplate(Template): method get_conjucture (line 56) | def get_conjucture(self, sample): method get_prompt (line 65) | def get_prompt(self, sample): method encode (line 77) | def encode(self, sample): method capitalize (line 81) | def capitalize(self, c): method verbalize (line 96) | def verbalize(self, sample, candidate): method encode_sfc (line 100) | def encode_sfc(self, sample): method verbalize_sfc (line 104) | def verbalize_sfc(self, sample, candidate): class BoolQTemplate (line 110) | class BoolQTemplate(Template): method encode (line 111) | def encode(self, sample): method verbalize (line 119) | def verbalize(self, sample, candidate): method encode_sfc (line 127) | def encode_sfc(self, sample): method verbalize_sfc (line 130) | def verbalize_sfc(self, sample, candidate): class BoolQTemplateV2 (line 134) | class BoolQTemplateV2(Template): method encode (line 135) | def encode(self, sample): method verbalize (line 143) | def verbalize(self, sample, candidate): method encode_sfc (line 151) | def encode_sfc(self, sample): method verbalize_sfc (line 154) | def verbalize_sfc(self, sample, candidate): class BoolQTemplateV3 (line 158) | class BoolQTemplateV3(Template): method encode (line 159) | def encode(self, sample): method verbalize (line 167) | def verbalize(self, sample, candidate): method encode_sfc (line 175) | def encode_sfc(self, sample): method verbalize_sfc (line 178) | def verbalize_sfc(self, sample, candidate): class MultiRCTemplate (line 182) | class MultiRCTemplate(Template): method encode (line 186) | def encode(self, sample): method verbalize (line 192) | def verbalize(self, sample, candidate): method encode_sfc (line 198) | def encode_sfc(self, sample): method verbalize_sfc (line 201) | def verbalize_sfc(self, sample, candidate): class CBTemplate (line 205) | class CBTemplate(Template): method encode (line 209) | def encode(self, sample): method verbalize (line 214) | def verbalize(self, sample, candidate): method encode_sfc (line 219) | def encode_sfc(self, sample): method verbalize_sfc (line 222) | def verbalize_sfc(self, sample, candidate): class WICTemplate (line 226) | class WICTemplate(Template): method encode (line 230) | def encode(self, sample): method verbalize (line 236) | def verbalize(self, sample, candidate): method encode_sfc (line 242) | def encode_sfc(self, sample): method verbalize_sfc (line 245) | def verbalize_sfc(self, sample, candidate): class WSCTemplate (line 249) | class WSCTemplate(Template): method encode (line 253) | def encode(self, sample): method verbalize (line 259) | def verbalize(self, sample, candidate): method encode_sfc (line 265) | def encode_sfc(self, sample): method verbalize_sfc (line 268) | def verbalize_sfc(self, sample, candidate): class ReCoRDTemplate (line 272) | class ReCoRDTemplate(Template): method encode (line 275) | def encode(self, sample): method verbalize (line 280) | def verbalize(self, sample, candidate): method encode_sfc (line 285) | def encode_sfc(self, sample): method verbalize_sfc (line 288) | def verbalize_sfc(self, sample, candidate): class ReCoRDTemplateGPT3 (line 292) | class ReCoRDTemplateGPT3(Template): method encode (line 295) | def encode(self, sample): method verbalize (line 299) | def verbalize(self, sample, candidate): method encode_sfc (line 308) | def encode_sfc(self, sample): method verbalize_sfc (line 311) | def verbalize_sfc(self, sample, candidate): class RTETemplate (line 316) | class RTETemplate(Template): method encode (line 320) | def encode(self, sample): method verbalize (line 325) | def verbalize(self, sample, candidate): method encode_sfc (line 330) | def encode_sfc(self, sample): method verbalize_sfc (line 333) | def verbalize_sfc(self, sample, candidate): class SQuADv2Template (line 337) | class SQuADv2Template(Template): method encode (line 339) | def encode(self, sample): method verbalize (line 347) | def verbalize(self, sample, candidate): method encode_sfc (line 356) | def encode_sfc(self, sample): method verbalize_sfc (line 359) | def verbalize_sfc(self, sample, candidate): class DROPTemplate (line 363) | class DROPTemplate(Template): method encode (line 365) | def encode(self, sample): method verbalize (line 373) | def verbalize(self, sample, candidate): method encode_sfc (line 382) | def encode_sfc(self, sample): method verbalize_sfc (line 385) | def verbalize_sfc(self, sample, candidate): FILE: example/mezo_runner/utils.py function custom_loss_fn_with_option_len (line 37) | def custom_loss_fn_with_option_len(self, input_ids, logits, labels, opti... function forward_wrap_with_option_len (line 87) | def forward_wrap_with_option_len(self, input_ids=None, labels=None, opti... function encode_prompt (line 161) | def encode_prompt(task, template, train_samples, eval_sample, tokenizer,... class ICLCollator (line 233) | class ICLCollator: method __call__ (line 239) | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: class DataCollatorWithPaddingAndNesting (line 260) | class DataCollatorWithPaddingAndNesting: method __call__ (line 271) | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: class NondiffCollator (line 290) | class NondiffCollator(DataCollatorMixin): method torch_call (line 301) | def torch_call(self, features): class SIGUSR1Callback (line 344) | class SIGUSR1Callback(transformers.TrainerCallback): method __init__ (line 350) | def __init__(self) -> None: method handle_signal (line 357) | def handle_signal(self, signum, frame): method on_step_end (line 361) | def on_step_end(self, args, state, control, **kwargs): method on_train_end (line 366) | def on_train_end(self, args, state, control, **kwargs): class Prediction (line 372) | class Prediction: function count_time (line 378) | def count_time(name): function temp_seed (line 388) | def temp_seed(seed): class EnhancedJSONEncoder (line 397) | class EnhancedJSONEncoder(json.JSONEncoder): method default (line 398) | def default(self, o): function write_predictions_to_file (line 404) | def write_predictions_to_file(final_preds, output): function write_metrics_to_file (line 410) | def write_metrics_to_file(metrics, output): FILE: script/add-copyright.py function add_license_header (line 9) | def add_license_header(file_path, comment_style): FILE: test/mezo_sgd/hf_opt/test_acc.py function train_mezo_sgd_causalLM (line 22) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo2_sgd_causalLM (line 38) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo_sgd_causalLM (line 54) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo2_sgd_causalLM (line 70) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo_sgd_sequence_classification (line 87) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo2_sgd_sequence_classification (line 103) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev... function eval_mezo_sgd_sequence_classification (line 119) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic... function eval_mezo2_sgd_sequence_classification (line 135) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo_sgd_question_answering (line 152) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c... function train_mezo2_sgd_question_answering (line 168) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='... function eval_mezo_sgd_question_answering (line 184) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu... function eval_mezo2_sgd_question_answering (line 200) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c... function test_mezo_sgd_causalLM_training (line 217) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 227) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 238) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 248) | def test_mezo2_sgd_causalLM_eval(): function test_mezo_sgd_sequence_classification_training (line 260) | def test_mezo_sgd_sequence_classification_training(): function test_mezo2_sgd_sequence_classification_training (line 270) | def test_mezo2_sgd_sequence_classification_training(): function test_mezo_sgd_sequence_classification_eval (line 281) | def test_mezo_sgd_sequence_classification_eval(): function test_mezo2_sgd_sequence_classification_eval (line 291) | def test_mezo2_sgd_sequence_classification_eval(): function test_mezo_sgd_question_answering_training (line 303) | def test_mezo_sgd_question_answering_training(): function test_mezo2_sgd_question_answering_training (line 313) | def test_mezo2_sgd_question_answering_training(): function test_mezo_sgd_question_answering_eval (line 324) | def test_mezo_sgd_question_answering_eval(): function test_mezo2_sgd_question_answering_eval (line 334) | def test_mezo2_sgd_question_answering_eval(): FILE: test/mezo_sgd/hf_opt/test_memory.py function train_mezo_sgd_causalLM (line 25) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): function train_mezo2_sgd_causalLM (line 42) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): function eval_mezo_sgd_causalLM (line 59) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): function eval_mezo2_sgd_causalLM (line 76) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): function train_mezo_sgd_sequence_classification (line 94) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo2_sgd_sequence_classification (line 111) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev... function eval_mezo_sgd_sequence_classification (line 128) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic... function eval_mezo2_sgd_sequence_classification (line 145) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo_sgd_question_answering (line 163) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c... function train_mezo2_sgd_question_answering (line 180) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='... function eval_mezo_sgd_question_answering (line 197) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu... function eval_mezo2_sgd_question_answering (line 214) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c... function test_mezo_sgd_causalLM_training (line 232) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 243) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 254) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 265) | def test_mezo2_sgd_causalLM_eval(): function test_mezo_sgd_sequence_classification_training (line 277) | def test_mezo_sgd_sequence_classification_training(): function test_mezo2_sgd_sequence_classification_training (line 288) | def test_mezo2_sgd_sequence_classification_training(): function test_mezo_sgd_sequence_classification_eval (line 299) | def test_mezo_sgd_sequence_classification_eval(): function test_mezo2_sgd_sequence_classification_eval (line 310) | def test_mezo2_sgd_sequence_classification_eval(): function test_mezo_sgd_question_answering_training (line 322) | def test_mezo_sgd_question_answering_training(): function test_mezo2_sgd_question_answering_training (line 333) | def test_mezo2_sgd_question_answering_training(): function test_mezo_sgd_question_answering_eval (line 344) | def test_mezo_sgd_question_answering_eval(): function test_mezo2_sgd_question_answering_eval (line 355) | def test_mezo2_sgd_question_answering_eval(): FILE: test/mezo_sgd/hf_opt/test_speed.py function train_mezo_sgd_causalLM (line 23) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo2_sgd_causalLM (line 36) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo_sgd_causalLM (line 49) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo2_sgd_causalLM (line 62) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo_sgd_sequence_classification (line 76) | def train_mezo_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo2_sgd_sequence_classification (line 89) | def train_mezo2_sgd_sequence_classification(model_config, zo_config, dev... function eval_mezo_sgd_sequence_classification (line 102) | def eval_mezo_sgd_sequence_classification(model_config, zo_config, devic... function eval_mezo2_sgd_sequence_classification (line 115) | def eval_mezo2_sgd_sequence_classification(model_config, zo_config, devi... function train_mezo_sgd_question_answering (line 129) | def train_mezo_sgd_question_answering(model_config, zo_config, device='c... function train_mezo2_sgd_question_answering (line 142) | def train_mezo2_sgd_question_answering(model_config, zo_config, device='... function eval_mezo_sgd_question_answering (line 155) | def eval_mezo_sgd_question_answering(model_config, zo_config, device='cu... function eval_mezo2_sgd_question_answering (line 168) | def eval_mezo2_sgd_question_answering(model_config, zo_config, device='c... function test_mezo_sgd_causalLM_training (line 182) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 193) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 204) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 215) | def test_mezo2_sgd_causalLM_eval(): function test_mezo_sgd_sequence_classification_training (line 227) | def test_mezo_sgd_sequence_classification_training(): function test_mezo2_sgd_sequence_classification_training (line 238) | def test_mezo2_sgd_sequence_classification_training(): function test_mezo_sgd_sequence_classification_eval (line 249) | def test_mezo_sgd_sequence_classification_eval(): function test_mezo2_sgd_sequence_classification_eval (line 260) | def test_mezo2_sgd_sequence_classification_eval(): function test_mezo_sgd_question_answering_training (line 272) | def test_mezo_sgd_question_answering_training(): function test_mezo2_sgd_question_answering_training (line 283) | def test_mezo2_sgd_question_answering_training(): function test_mezo_sgd_question_answering_eval (line 294) | def test_mezo_sgd_question_answering_eval(): function test_mezo2_sgd_question_answering_eval (line 305) | def test_mezo2_sgd_question_answering_eval(): FILE: test/mezo_sgd/hf_opt/utils.py function get_args (line 13) | def get_args(): class OPTConfigs (line 36) | class OPTConfigs: function model_size (line 56) | def model_size(model: torch.nn.Module): function prepare_data_for_causalLM (line 62) | def prepare_data_for_causalLM(V, B, T, device='cuda'): function prepare_data_for_sequence_classification (line 68) | def prepare_data_for_sequence_classification(V, B, T, device='cuda'): function prepare_data_for_question_answering (line 73) | def prepare_data_for_question_answering(V, B, T, device='cuda'): function check_peak_gpu_memory_usage (line 82) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False): function check_and_update_peak_cpu_memory_usage (line 94) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False): function reset_peak_cpu_memory_usage (line 105) | def reset_peak_cpu_memory_usage(): function check_throughput (line 112) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u... FILE: test/mezo_sgd/hf_qwen3/test_acc.py function train_mezo_sgd_causalLM (line 20) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo2_sgd_causalLM (line 36) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo_sgd_causalLM (line 52) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo2_sgd_causalLM (line 68) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function test_mezo_sgd_causalLM_training (line 85) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 97) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 110) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 121) | def test_mezo2_sgd_causalLM_eval(): FILE: test/mezo_sgd/hf_qwen3/test_memory.py function train_mezo_sgd_causalLM (line 23) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): function train_mezo2_sgd_causalLM (line 40) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): function eval_mezo_sgd_causalLM (line 57) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda:0'): function eval_mezo2_sgd_causalLM (line 74) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda:0'): function test_mezo_sgd_causalLM_training (line 93) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 104) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 115) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 126) | def test_mezo2_sgd_causalLM_eval(): FILE: test/mezo_sgd/hf_qwen3/test_speed.py function train_mezo_sgd_causalLM (line 21) | def train_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function train_mezo2_sgd_causalLM (line 34) | def train_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo_sgd_causalLM (line 47) | def eval_mezo_sgd_causalLM(model_config, zo_config, device='cuda'): function eval_mezo2_sgd_causalLM (line 60) | def eval_mezo2_sgd_causalLM(model_config, zo_config, device='cuda'): function test_mezo_sgd_causalLM_training (line 74) | def test_mezo_sgd_causalLM_training(): function test_mezo2_sgd_causalLM_training (line 85) | def test_mezo2_sgd_causalLM_training(): function test_mezo_sgd_causalLM_eval (line 96) | def test_mezo_sgd_causalLM_eval(): function test_mezo2_sgd_causalLM_eval (line 107) | def test_mezo2_sgd_causalLM_eval(): FILE: test/mezo_sgd/hf_qwen3/utils.py function get_args (line 13) | def get_args(): class Qwen3Configs (line 35) | class Qwen3Configs: function model_size (line 52) | def model_size(model: torch.nn.Module): function prepare_data_for_causalLM (line 58) | def prepare_data_for_causalLM(V, B, T, device='cuda'): function check_peak_gpu_memory_usage (line 67) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False): function check_and_update_peak_cpu_memory_usage (line 79) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False): function reset_peak_cpu_memory_usage (line 90) | def reset_peak_cpu_memory_usage(): function check_throughput (line 97) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u... FILE: test/mezo_sgd/nanogpt/test_acc.py function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, model_config, device='cuda'): function train_mezo2_sgd (line 28) | def train_mezo2_sgd(model, args, model_config, device='cuda'): function eval_mezo_sgd (line 40) | def eval_mezo_sgd(model, args, model_config, device='cuda'): function eval_mezo2_sgd (line 52) | def eval_mezo2_sgd(model, args, model_config, device='cuda'): function test_mezo_sgd_training (line 64) | def test_mezo_sgd_training(): function test_mezo2_sgd_training (line 79) | def test_mezo2_sgd_training(): function test_mezo_sgd_eval (line 94) | def test_mezo_sgd_eval(): function test_mezo2_sgd_eval (line 109) | def test_mezo2_sgd_eval(): FILE: test/mezo_sgd/nanogpt/test_memory.py function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, modelConfig, device='cuda:0'): function train_mezo2_sgd (line 30) | def train_mezo2_sgd(model, args, modelConfig, device='cuda:0'): function eval_mezo_sgd (line 44) | def eval_mezo_sgd(model, args, modelConfig, device='cuda:0'): function eval_mezo2_sgd (line 58) | def eval_mezo2_sgd(model, args, modelConfig, device='cuda:0'): function test_mezo_sgd_training (line 72) | def test_mezo_sgd_training(): function test_mezo2_sgd_training (line 87) | def test_mezo2_sgd_training(): function test_mezo_sgd_eval (line 102) | def test_mezo_sgd_eval(): function test_mezo2_sgd_eval (line 117) | def test_mezo2_sgd_eval(): FILE: test/mezo_sgd/nanogpt/test_speed.py function train_mezo_sgd (line 16) | def train_mezo_sgd(model, args, modelConfig, device='cuda'): function train_mezo2_sgd (line 26) | def train_mezo2_sgd(model, args, modelConfig, device='cuda'): function eval_mezo_sgd (line 36) | def eval_mezo_sgd(model, args, modelConfig, device='cuda'): function eval_mezo2_sgd (line 46) | def eval_mezo2_sgd(model, args, modelConfig, device='cuda'): function test_mezo_sgd_training (line 56) | def test_mezo_sgd_training(): function test_mezo2_sgd_training (line 71) | def test_mezo2_sgd_training(): function test_mezo_sgd_eval (line 86) | def test_mezo_sgd_eval(): function test_mezo2_sgd_eval (line 101) | def test_mezo2_sgd_eval(): FILE: test/mezo_sgd/nanogpt/utils.py function get_args (line 13) | def get_args(): function model_size (line 41) | def model_size(model: torch.nn.Module): function prepare_data (line 47) | def prepare_data(V, B, T, device='cuda'): function check_peak_gpu_memory_usage (line 57) | def check_peak_gpu_memory_usage(iter, device=0, use_tqdm=False): function check_and_update_peak_cpu_memory_usage (line 69) | def check_and_update_peak_cpu_memory_usage(iter, use_tqdm=False): function reset_peak_cpu_memory_usage (line 80) | def reset_peak_cpu_memory_usage(): function check_throughput (line 86) | def check_throughput(iter, total_token_batch_size_per_iter, fn, *args, u... FILE: zo2/config/__init__.py function ZOConfig (line 7) | def ZOConfig(method: str = "mezo-sgd", **kwargs): FILE: zo2/config/mezo_sgd.py class MeZOSGDConfig (line 8) | class MeZOSGDConfig: FILE: zo2/model/base.py class BaseZOModel (line 6) | class BaseZOModel(torch.nn.Module): method __init__ (line 7) | def __init__(self): method zo_train (line 17) | def zo_train(self): method zo_eval (line 24) | def zo_eval(self): method register_zo_train_loss_fn_pre_hook (line 31) | def register_zo_train_loss_fn_pre_hook(self, hook_fn): method register_zo_train_loss_fn_post_hook (line 34) | def register_zo_train_loss_fn_post_hook(self, hook_fn): method register_zo_eval_loss_fn_pre_hook (line 37) | def register_zo_eval_loss_fn_pre_hook(self, hook_fn): method register_zo_eval_loss_fn_post_hook (line 40) | def register_zo_eval_loss_fn_post_hook(self, hook_fn): method register_custom_opt (line 43) | def register_custom_opt(self, custom_opt_obj): FILE: zo2/model/huggingface/opt/__init__.py function get_opt_for_causalLM (line 8) | def get_opt_for_causalLM(zo_config): function get_opt_for_sequence_classification (line 14) | def get_opt_for_sequence_classification(zo_config): function get_opt_for_question_answering (line 20) | def get_opt_for_question_answering(zo_config): FILE: zo2/model/huggingface/opt/mezo_sgd/__init__.py function get_opt_for_causalLM_mezo_sgd (line 7) | def get_opt_for_causalLM_mezo_sgd(config: MeZOSGDConfig): function get_opt_for_sequence_classification_mezo_sgd (line 10) | def get_opt_for_sequence_classification_mezo_sgd(config: MeZOSGDConfig): function get_opt_for_question_answering_mezo_sgd (line 13) | def get_opt_for_question_answering_mezo_sgd(config: MeZOSGDConfig): FILE: zo2/model/huggingface/opt/mezo_sgd/utils.py function fn_get_opt_decoder_hidden_states_from_layer_outputs (line 6) | def fn_get_opt_decoder_hidden_states_from_layer_outputs(input): function get_shift_logits (line 9) | def get_shift_logits(logits): function get_shift_labels (line 12) | def get_shift_labels(labels): function get_pooled_logits (line 15) | def get_pooled_logits(logits, batch_size, sequence_lengths): function get_start_logits_and_end_logits (line 18) | def get_start_logits_and_end_logits(logits): function get_qa_loss (line 24) | def get_qa_loss(loss_fct, start_logits, start_positions, end_logits, end... function init_all_hidden_states (line 30) | def init_all_hidden_states(output_hidden_states): function init_all_self_attns (line 33) | def init_all_self_attns(output_attentions): function init_next_decoder_cache (line 36) | def init_next_decoder_cache(use_cache): function update_next_decoder_cache (line 39) | def update_next_decoder_cache(use_cache, next_decoder_cache, layer_outpu... function update_all_self_attns (line 44) | def update_all_self_attns(output_attentions, all_self_attns, layer_outpu... function update_all_hidden_states (line 49) | def update_all_hidden_states(output_hidden_states, all_hidden_states, hi... function get_past_key_value (line 54) | def get_past_key_value(past_key_values, idx): function get_opt_sequence_classification_pooled_logits (line 57) | def get_opt_sequence_classification_pooled_logits(self, logits, input_id... function get_opt_sequence_classification_loss (line 71) | def get_opt_sequence_classification_loss(self, loss, pooled_logits, labe... function get_opt_question_answering_start_end_logits (line 93) | def get_opt_question_answering_start_end_logits(logits): function get_opt_question_answering_loss (line 99) | def get_opt_question_answering_loss(total_loss, start_logits, start_posi... FILE: zo2/model/huggingface/opt/mezo_sgd/zo.py class OPTDecoder (line 42) | class OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel): method __init__ (line 50) | def __init__(self, config: OPTConfig): class OPTModel (line 91) | class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel): method __init__ (line 92) | def __init__(self, config: OPTConfig): class OPTForCausalLM (line 99) | class OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, Ba... method __init__ (line 102) | def __init__(self, config: OPTConfig): method zo_init (line 113) | def zo_init(self, zo_config): method forward (line 117) | def forward( class OPTForSequenceClassification (line 221) | class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassifica... method __init__ (line 222) | def __init__(self, config: OPTConfig): method zo_init (line 232) | def zo_init(self, zo_config): method forward (line 243) | def forward( class OPTForQuestionAnswering (line 274) | class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTP... method __init__ (line 275) | def __init__(self, config: OPTConfig): method zo_init (line 284) | def zo_init(self, zo_config): method forward (line 289) | def forward( class OptimizerOPTForCausalLM (line 358) | class OptimizerOPTForCausalLM(MeZOSGD): method inner_zo_forward (line 361) | def inner_zo_forward( method inner_zo_eval_forward (line 427) | def inner_zo_eval_forward( class OptimizerOPTForSequenceClassification (line 478) | class OptimizerOPTForSequenceClassification(MeZOSGD): method inner_zo_forward (line 481) | def inner_zo_forward( method inner_zo_eval_forward (line 572) | def inner_zo_eval_forward( class OptimizerOPTForQuestionAnswering (line 620) | class OptimizerOPTForQuestionAnswering(MeZOSGD): method inner_zo_forward (line 623) | def inner_zo_forward( method inner_zo_eval_forward (line 696) | def inner_zo_eval_forward( FILE: zo2/model/huggingface/opt/mezo_sgd/zo2.py class OPTDecoder (line 42) | class OPTDecoder(modeling_opt.OPTDecoder, OPTPreTrainedModel, BaseZOModel): method __init__ (line 50) | def __init__(self, config: OPTConfig): method zo_init (line 90) | def zo_init(self, zo_config): method forward (line 94) | def forward( class OPTModel (line 120) | class OPTModel(modeling_opt.OPTModel, OPTPreTrainedModel, BaseZOModel): method __init__ (line 121) | def __init__(self, config: OPTConfig): method zo_init (line 128) | def zo_init(self, zo_config): method forward (line 140) | def forward( class OPTForCausalLM (line 162) | class OPTForCausalLM(modeling_opt.OPTForCausalLM, OPTPreTrainedModel, Ba... method __init__ (line 165) | def __init__(self, config: OPTConfig): method zo_init (line 174) | def zo_init(self, zo_config): method forward (line 180) | def forward( class OPTForSequenceClassification (line 280) | class OPTForSequenceClassification(modeling_opt.OPTForSequenceClassifica... method __init__ (line 281) | def __init__(self, config: OPTConfig): method zo_init (line 291) | def zo_init(self, zo_config): method forward (line 303) | def forward( class OPTForQuestionAnswering (line 335) | class OPTForQuestionAnswering(modeling_opt.OPTForQuestionAnswering, OPTP... method __init__ (line 336) | def __init__(self, config: OPTConfig): method zo_init (line 345) | def zo_init(self, zo_config): method forward (line 351) | def forward( class OptimizerOPTDecoder (line 421) | class OptimizerOPTDecoder(MeZO2SGD): method init_zo2 (line 423) | def init_zo2(self): method init_zo2_upload (line 434) | def init_zo2_upload(self): method inner_zo_forward (line 457) | def inner_zo_forward( method inner_zo_eval_forward (line 655) | def inner_zo_eval_forward( class OptimizerOPTModel (line 919) | class OptimizerOPTModel(MeZO2SGD): method init_zo2 (line 921) | def init_zo2(self): method init_zo2_upload (line 932) | def init_zo2_upload(self): method inner_zo_forward (line 936) | def inner_zo_forward( method inner_zo_eval_forward (line 974) | def inner_zo_eval_forward( class OptimizerOPTForCausalLM (line 1020) | class OptimizerOPTForCausalLM(MeZO2SGD): method init_zo2_upload (line 1022) | def init_zo2_upload(self): method inner_zo_forward (line 1026) | def inner_zo_forward( method inner_zo_eval_forward (line 1114) | def inner_zo_eval_forward( class OptimizerOPTForSequenceClassification (line 1216) | class OptimizerOPTForSequenceClassification(MeZO2SGD): method init_zo2_upload (line 1218) | def init_zo2_upload(self): method inner_zo_forward (line 1222) | def inner_zo_forward( method inner_zo_eval_forward (line 1347) | def inner_zo_eval_forward( class OptimizerOPTForQuestionAnswering (line 1441) | class OptimizerOPTForQuestionAnswering(MeZO2SGD): method init_zo2_upload (line 1443) | def init_zo2_upload(self): method inner_zo_forward (line 1447) | def inner_zo_forward( method inner_zo_eval_forward (line 1540) | def inner_zo_eval_forward( FILE: zo2/model/huggingface/qwen3/__init__.py function get_qwen3_for_causalLM (line 8) | def get_qwen3_for_causalLM(zo_config): FILE: zo2/model/huggingface/qwen3/mezo_sgd/__init__.py function get_qwen3_for_causalLM_mezo_sgd (line 7) | def get_qwen3_for_causalLM_mezo_sgd(config: MeZOSGDConfig): FILE: zo2/model/huggingface/qwen3/mezo_sgd/utils.py function fn_get_qwen3_decoder_hidden_states_from_layer_outputs (line 6) | def fn_get_qwen3_decoder_hidden_states_from_layer_outputs(input): function fn_get_qwen3_sliced_logits_from_hidden_states (line 9) | def fn_get_qwen3_sliced_logits_from_hidden_states(hidden_states, slice_i... FILE: zo2/model/huggingface/qwen3/mezo_sgd/zo.py class Qwen3Model (line 38) | class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel): method __init__ (line 46) | def __init__(self, config: Qwen3Config): class Qwen3ForCausalLM (line 64) | class Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedM... method __init__ (line 69) | def __init__(self, config: Qwen3Config): method zo_init (line 79) | def zo_init(self, zo_config): method forward (line 86) | def forward( class OptimizerQwen3ForCausalLM (line 148) | class OptimizerQwen3ForCausalLM(MeZOSGD): method inner_zo_forward (line 151) | def inner_zo_forward( method inner_zo_eval_forward (line 213) | def inner_zo_eval_forward( FILE: zo2/model/huggingface/qwen3/mezo_sgd/zo2.py class Qwen3Model (line 41) | class Qwen3Model(modeling_qwen3.Qwen3Model, Qwen3PreTrainedModel, BaseZO... method __init__ (line 48) | def __init__(self, config: Qwen3Config): method zo_init (line 68) | def zo_init(self, zo_config): method forward (line 74) | def forward( class Qwen3ForCausalLM (line 99) | class Qwen3ForCausalLM(modeling_qwen3.Qwen3ForCausalLM, Qwen3PreTrainedM... method __init__ (line 104) | def __init__(self, config): method zo_init (line 114) | def zo_init(self, zo_config): method forward (line 123) | def forward( class OptimizerQwen3Model (line 183) | class OptimizerQwen3Model(MeZO2SGD): method init_zo2 (line 185) | def init_zo2(self): method init_zo2_upload (line 196) | def init_zo2_upload(self): method inner_zo_forward (line 214) | def inner_zo_forward( method inner_zo_eval_forward (line 365) | def inner_zo_eval_forward( class OptimizerQwen3ForCausalLM (line 509) | class OptimizerQwen3ForCausalLM(MeZO2SGD): method init_zo2_upload (line 511) | def init_zo2_upload(self): method inner_zo_forward (line 515) | def inner_zo_forward( method inner_zo_eval_forward (line 594) | def inner_zo_eval_forward( FILE: zo2/model/huggingface/zo_init.py function zo_hf_init (line 25) | def zo_hf_init(zo_config): function main (line 37) | def main(): FILE: zo2/model/nanogpt/__init__.py function get_nanogpt (line 8) | def get_nanogpt(zo_config): FILE: zo2/model/nanogpt/mezo_sgd/__init__.py function get_nanogpt_mezo_sgd (line 9) | def get_nanogpt_mezo_sgd(config: MeZOSGDConfig): FILE: zo2/model/nanogpt/mezo_sgd/zo.py class GPT (line 13) | class GPT(model.GPT, BaseZOModel): method __init__ (line 14) | def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig): method forward (line 18) | def forward(self, idx, pos, targets=None): class Optimizer (line 26) | class Optimizer(MeZOSGD): method inner_zo_forward (line 29) | def inner_zo_forward(self, idx, pos, targets): method inner_zo_eval_forward (line 44) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): FILE: zo2/model/nanogpt/mezo_sgd/zo2.py class GPT (line 14) | class GPT(model.GPT, BaseZOModel): method __init__ (line 15) | def __init__(self, config: model.GPTConfig, zo_config: MeZOSGDConfig): method forward (line 19) | def forward(self, idx, pos, targets=None): class Optimizer (line 27) | class Optimizer(MeZO2SGD): method init_zo2_upload (line 29) | def init_zo2_upload(self): method inner_zo_forward (line 50) | def inner_zo_forward(self, idx, pos, targets): method inner_zo_eval_forward (line 114) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): FILE: zo2/model/nanogpt/model.py class GPTConfig (line 19) | class GPTConfig: class GPTConfigs (line 28) | class GPTConfigs: class LayerNorm (line 43) | class LayerNorm(nn.Module): method __init__ (line 46) | def __init__(self, ndim, bias): method forward (line 51) | def forward(self, input): class CausalSelfAttention (line 54) | class CausalSelfAttention(nn.Module): method __init__ (line 56) | def __init__(self, config): method forward (line 77) | def forward(self, x): class MLP (line 103) | class MLP(nn.Module): method __init__ (line 105) | def __init__(self, config): method forward (line 112) | def forward(self, x): class Block (line 119) | class Block(nn.Module): method __init__ (line 121) | def __init__(self, config): method forward (line 128) | def forward(self, x): class GPT (line 134) | class GPT(nn.Module): method __init__ (line 136) | def __init__(self, config): method get_num_params (line 166) | def get_num_params(self, non_embedding=True): method _init_weights (line 178) | def _init_weights(self, module): method forward (line 186) | def forward(self, idx, pos, targets=None): method crop_block_size (line 205) | def crop_block_size(self, block_size): method from_pretrained (line 217) | def from_pretrained(cls, model_type, override_args=None): method configure_optimizers (line 273) | def configure_optimizers(self, weight_decay, learning_rate, betas, dev... method estimate_mfu (line 299) | def estimate_mfu(self, fwdbwd_per_iter, dt): method generate (line 315) | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): FILE: zo2/optimizer/base.py class BaseOptimizer (line 7) | class BaseOptimizer(Optimizer): method __init__ (line 12) | def __init__(self, params, defaults): method _update_lr (line 25) | def _update_lr(self): method _set_lr (line 28) | def _set_lr(self): FILE: zo2/optimizer/mezo_sgd/utils/comm.py function module_to_bucket_inplace (line 9) | def module_to_bucket_inplace(module: nn.Module): function bucket_to_module_inplace (line 13) | def bucket_to_module_inplace(bucket: torch.Tensor, module: nn.Module): function create_disk_offload_path (line 23) | def create_disk_offload_path(path, module_id): function get_disk_offload_path (line 35) | def get_disk_offload_path(path, module_id): function clear_disk_offload_path (line 38) | def clear_disk_offload_path(path, module_id): function set_nested_attr (line 46) | def set_nested_attr(obj, attr, value): FILE: zo2/optimizer/mezo_sgd/zo.py class MeZOSGD (line 16) | class MeZOSGD(BaseOptimizer): method __init__ (line 21) | def __init__(self, model: nn.Module, config: MeZOSGDConfig): method zo_perturb_parameters (line 47) | def zo_perturb_parameters(self, module: nn.Module, scaling_factor: flo... method zo_update (line 65) | def zo_update(self, module, weight_decay=None): method zo_perturb_shifts (line 90) | def zo_perturb_shifts(self, first_perturb_shift=1, stride=2): method compute_grad (line 99) | def compute_grad(self, loss1, loss2): method zo_forward (line 103) | def zo_forward(self, *args, zo_random_seed: int=None, **kwargs): method zo_eval_forward (line 129) | def zo_eval_forward(self, *args, **kwargs): method inner_zo_forward (line 139) | def inner_zo_forward(self, idx, pos, targets): method inner_zo_eval_forward (line 159) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): FILE: zo2/optimizer/mezo_sgd/zo2.py class MeZO2SGD (line 17) | class MeZO2SGD(MeZOSGD): method __init__ (line 26) | def __init__(self, model, config: MeZOSGDConfig): method init_zo2 (line 52) | def init_zo2(self): method init_zo2_amp (line 68) | def init_zo2_amp(self): method assign_zo2_attributes (line 84) | def assign_zo2_attributes(self, source, target): method zo_update (line 100) | def zo_update(self, module, weight_decay=None): method module_dual_forward (line 116) | def module_dual_forward(self, module, inputs1, inputs2, projected_grad... method function_dual_forward (line 146) | def function_dual_forward(self, fn, inputs1, inputs2): method zo_forward (line 164) | def zo_forward(self, *args, seed: int=None, **kwargs): method task_upload (line 188) | def task_upload(self, module, device='cuda', upload_sync=False, *args,... method task_offload (line 212) | def task_offload(self, module, device='cpu', offload_sync=False, *args... method task_compute_module (line 237) | def task_compute_module(self, module, inputs1, inputs2, grad, compute_... method task_compute_function (line 297) | def task_compute_function(self, fn, inputs1, inputs2, compute_sync=Fal... method zo_eval_forward (line 355) | def zo_eval_forward(self, *args, **kwargs): method add_zo2_eval_comm_hooks (line 371) | def add_zo2_eval_comm_hooks(self, blocks): method clear_zo2_eval_comm_hooks (line 391) | def clear_zo2_eval_comm_hooks(self, handles): method eval_upload_hook (line 401) | def eval_upload_hook(self, module, input): method eval_offload_hook (line 416) | def eval_offload_hook(self, module, input, output): method upload_impl (line 442) | def upload_impl( method offload_impl (line 484) | def offload_impl( method compute_module_impl (line 525) | def compute_module_impl( method compute_function_impl (line 551) | def compute_function_impl( method amp_decompress_impl (line 577) | def amp_decompress_impl(self, module: nn.Module) -> nn.Module: method amp_compress_impl (line 597) | def amp_compress_impl(self, module: nn.Module) -> nn.Module: method init_zo2_upload (line 618) | def init_zo2_upload(self): method inner_zo_forward (line 646) | def inner_zo_forward(self, idx, pos, targets): method inner_zo_eval_forward (line 721) | def inner_zo_eval_forward(self, eval_fn, idx, pos, targets): FILE: zo2/trainer/hf_transformers/trainer.py function _is_peft_model (line 259) | def _is_peft_model(model): function _get_fsdp_ckpt_kwargs (line 271) | def _get_fsdp_ckpt_kwargs(): function safe_globals (line 279) | def safe_globals(): class ZOTrainer (line 317) | class ZOTrainer(Trainer): method __init__ (line 322) | def __init__( method _inner_training_loop (line 359) | def _inner_training_loop( method _load_optimizer_and_scheduler (line 875) | def _load_optimizer_and_scheduler(self, checkpoint, model=None): method create_optimizer_and_scheduler (line 885) | def create_optimizer_and_scheduler(self, num_training_steps: int, mode... method _move_model_to_device (line 903) | def _move_model_to_device(self, model, device): method _zo2_unsupported_conditions (line 908) | def _zo2_unsupported_conditions(self, args): method register_zo2_training_step_pre_hook (line 920) | def register_zo2_training_step_pre_hook(self, hook_fn): method register_zo2_training_step_post_hook (line 931) | def register_zo2_training_step_post_hook(self, hook_fn): method zo2_training_step (line 952) | def zo2_training_step(self, model: nn.Module, inputs: dict[str, Union[... FILE: zo2/trainer/hf_trl/sft_trainer.py class ZOSFTTrainer (line 226) | class ZOSFTTrainer(SFTTrainer): method __init__ (line 228) | def __init__( method _inner_training_loop (line 278) | def _inner_training_loop(self, batch_size=None, args=None, resume_from... method train (line 713) | def train(self, *args, **kwargs): method _load_optimizer_and_scheduler (line 738) | def _load_optimizer_and_scheduler(self, checkpoint, model=None): method create_optimizer_and_scheduler (line 748) | def create_optimizer_and_scheduler(self, num_training_steps: int, mode... method _move_model_to_device (line 766) | def _move_model_to_device(self, model, device): method _zo2_unsupported_conditions (line 771) | def _zo2_unsupported_conditions(self, args): method register_zo2_training_step_pre_hook (line 787) | def register_zo2_training_step_pre_hook(self, hook_fn): method register_zo2_training_step_post_hook (line 790) | def register_zo2_training_step_post_hook(self, hook_fn): method zo2_training_step (line 793) | def zo2_training_step(self, model: nn.Module, inputs: Dict[str, Union[... FILE: zo2/utils/utils.py function print_all (line 10) | def print_all(module: nn.Module, inputs, outputs): function print_hook (line 29) | def print_hook(module, input, output): function print_para_and_device (line 33) | def print_para_and_device(model): function cal_self_reg_loss (line 37) | def cal_self_reg_loss(logits, labels): function seed_everything (line 44) | def seed_everything(seed):