SYMBOL INDEX (348 symbols across 11 files) FILE: algo.py class AR (line 16) | class AR(trainer_base.TrainerBase): method __init__ (line 17) | def __init__(self, config, tokenizer): method _validate_configuration (line 30) | def _validate_configuration(self): method _process_model_input (line 35) | def _process_model_input(self, x0, valid_tokens): method nll (line 41) | def nll(self, input_tokens, output_tokens, method generate_samples (line 50) | def generate_samples(self, num_samples, **kwargs): method _process_sigma (line 72) | def _process_sigma(self, sigma): class MDLM (line 77) | class MDLM(trainer_base.AbsorbingState): method __init__ (line 78) | def __init__(self, config, tokenizer): method _validate_configuration (line 82) | def _validate_configuration(self): method _process_model_output (line 87) | def _process_model_output(self, model_output, xt, sigma): method nll_per_token (line 104) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t, method _get_score (line 113) | def _get_score(self, x, sigma): class D3PMAbsorb (line 158) | class D3PMAbsorb(trainer_base.AbsorbingState): method __init__ (line 159) | def __init__(self, config, tokenizer): method _validate_configuration (line 163) | def _validate_configuration(self): method _process_model_output (line 168) | def _process_model_output(self, model_output, xt, sigma): method nll_per_token (line 175) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t, class SEDDAbsorb (line 207) | class SEDDAbsorb(trainer_base.AbsorbingState): method __init__ (line 208) | def __init__(self, config, tokenizer): method _validate_configuration (line 212) | def _validate_configuration(self): method _get_score (line 216) | def _get_score(self, x, sigma): method _process_model_output (line 219) | def _process_model_output(self, model_output, xt, sigma): method nll_per_token (line 236) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t, class DUO_BASE (line 286) | class DUO_BASE(trainer_base.UniformState): method __init__ (line 287) | def __init__(self, config, tokenizer): method on_save_checkpoint (line 291) | def on_save_checkpoint(self, checkpoint): method on_load_checkpoint (line 297) | def on_load_checkpoint(self, checkpoint): method _process_model_output (line 303) | def _process_model_output(self, model_output, xt, sigma): method _posterior_from_x0 (line 307) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t): method nll_per_token (line 337) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t, class Integral (line 375) | class Integral(torch.autograd.Function): method forward (line 381) | def forward(ctx, gamma_t, data): method backward (line 400) | def backward(ctx, grad_output): class DUO (line 404) | class DUO(DUO_BASE): method __init__ (line 405) | def __init__(self, config, tokenizer): method _initialize_curriculum_coefficients (line 419) | def _initialize_curriculum_coefficients(self): method _init_curriculum_cached (line 432) | def _init_curriculum_cached(self): method _init_curriculum_series (line 441) | def _init_curriculum_series(self): method _init_curriculum_approx (line 454) | def _init_curriculum_approx(self): method to (line 486) | def to(self, *args, **kwargs): method cuda (line 494) | def cuda(self, device=None): method cpu (line 503) | def cpu(self): method to (line 512) | def to(self, *args, **kwargs): method _compute_gumbel_tau_inverse (line 521) | def _compute_gumbel_tau_inverse(self): method training_step (line 535) | def training_step(self, batch, batch_idx): method _gamma_to_alpha_dalpha (line 543) | def _gamma_to_alpha_dalpha(self, gamma_t, t): method _gamma_to_alphat_integral (line 559) | def _gamma_to_alphat_integral(self, gamma_t): method _gamma_to_alpha_dalpha_cached (line 564) | def _gamma_to_alpha_dalpha_cached(self, gamma_t): method _prior_loss (line 573) | def _prior_loss(self): method _q_xt_gaussian (line 582) | def _q_xt_gaussian(self, x, gamma_t): method nll (line 593) | def nll(self, x0, output_tokens, class Distillation (line 639) | class Distillation(DUO): method __init__ (line 640) | def __init__(self, config, tokenizer): method _validate_configuration (line 650) | def _validate_configuration(self): method _maybe_update_teacher_weights (line 663) | def _maybe_update_teacher_weights(self): method _teacher_logits (line 675) | def _teacher_logits(self, xt, sigma): method _sample_trajectory (line 687) | def _sample_trajectory(self, x0, gamma_t, gamma_s): method _compute_dt (line 708) | def _compute_dt(self): method nll (line 716) | def nll(self, x0, output_tokens, method training_step (line 752) | def training_step(self, batch, batch_idx): FILE: dataloader.py class RawPixelsVisionTokenizer (line 31) | class RawPixelsVisionTokenizer: method __init__ (line 32) | def __init__(self, vocab_size, image_size, method __call__ (line 53) | def __call__(self, x): method batch_decode (line 56) | def batch_decode(self, x): method decode (line 62) | def decode(self, x): method __len__ (line 68) | def __len__(self): class DiscreteCIFAR10 (line 72) | class DiscreteCIFAR10(torch.utils.data.Dataset): method __init__ (line 73) | def __init__(self, cache_dir, train): method __len__ (line 88) | def __len__(self): method __getitem__ (line 91) | def __getitem__(self, index): function wt_detokenizer (line 100) | def wt_detokenizer(string): function ptb_detokenizer (line 132) | def ptb_detokenizer(x): function lm1b_detokenizer (line 146) | def lm1b_detokenizer(x): function lambada_detokenizer (line 169) | def lambada_detokenizer(text): function scientific_papers_detokenizer (line 175) | def scientific_papers_detokenizer(x): class SyntheticTokenizer (line 181) | class SyntheticTokenizer( method __init__ (line 184) | def __init__( method vocab_size (line 221) | def vocab_size(self) -> int: method _tokenize (line 224) | def _tokenize(self, text: str, **kwargs) -> typing.List[str]: method _convert_token_to_id (line 227) | def _convert_token_to_id(self, token: str) -> int: method _convert_id_to_token (line 231) | def _convert_id_to_token(self, index: int) -> str: method convert_tokens_to_string (line 234) | def convert_tokens_to_string(self, tokens): method get_vocab (line 237) | def get_vocab(self) -> typing.Dict[str, int]: function _generate_synthetic_data (line 241) | def _generate_synthetic_data(dataset_size, function generate_synthetic_dataset (line 261) | def generate_synthetic_dataset(train_dataset_size, class Text8Tokenizer (line 290) | class Text8Tokenizer(transformers.PreTrainedTokenizer): method __init__ (line 291) | def __init__( method vocab_size (line 325) | def vocab_size(self) -> int: method _tokenize (line 328) | def _tokenize(self, text: str, **kwargs) -> typing.List[str]: method _convert_token_to_id (line 331) | def _convert_token_to_id(self, token: str) -> int: method _convert_id_to_token (line 335) | def _convert_id_to_token(self, index: int) -> str: method convert_tokens_to_string (line 338) | def convert_tokens_to_string(self, tokens): method get_vocab (line 341) | def get_vocab(self) -> typing.Dict[str, int]: function get_lambada_test_dataset (line 345) | def get_lambada_test_dataset(): function get_text8_dataset (line 365) | def get_text8_dataset(cache_dir, max_seq_length=256, function _group_texts (line 462) | def _group_texts(examples, block_size, bos, eos): function get_dataset (line 488) | def get_dataset(dataset_name, function get_tokenizer (line 712) | def get_tokenizer(config): function get_dataloaders (line 755) | def get_dataloaders(config, tokenizer, skip_train=False, class RandomFaultTolerantSampler (line 843) | class RandomFaultTolerantSampler(torch.utils.data.RandomSampler): method __init__ (line 845) | def __init__(self, *args, generator=None, **kwargs): method state_dict (line 858) | def state_dict(self): method load_state_dict (line 862) | def load_state_dict(self, state_dict): method __iter__ (line 871) | def __iter__(self) -> typing.Iterator[int]: class FaultTolerantDistributedSampler (line 890) | class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler): method __init__ (line 892) | def __init__(self, *args, **kwargs): method state_dict (line 897) | def state_dict(self): method load_state_dict (line 900) | def load_state_dict(self, state_dict): method __iter__ (line 907) | def __iter__(self): FILE: discrete_diffusion_harness.py function requests_to_dataset (line 132) | def requests_to_dataset(config, requests, tokenizer, num_proc): function _eval_suffix_nll_generators (line 173) | def _eval_suffix_nll_generators(config, module, prefix, function eval_suffix_nll (line 205) | def eval_suffix_nll(config, module, prefix, suffix, batch_size, function eval_suffix_nll_ar (line 218) | def eval_suffix_nll_ar(config, module, prefix, suffix, function eval_suffix_nll_diffusion (line 234) | def eval_suffix_nll_diffusion(config, module, prefix, suffix, class DiscreteDiffusionHarness (line 252) | class DiscreteDiffusionHarness(LM): method __init__ (line 253) | def __init__(self, pretrained="NONE", max_length=1024, method suffix_greedy_prediction (line 290) | def suffix_greedy_prediction(self, prefix, target): method _suffix_greedy_prediction_ar (line 303) | def _suffix_greedy_prediction_ar(self, prefix, target): method _suffix_greedy_prediction_mdlm (line 316) | def _suffix_greedy_prediction_mdlm(self, prefix, target): method _suffix_greedy_prediction_duo_base (line 337) | def _suffix_greedy_prediction_duo_base(self, prefix, target): method loglikelihood (line 362) | def loglikelihood(self, requests: list[Instance]) \ method loglikelihood_rolling (line 387) | def loglikelihood_rolling( method generate_until (line 392) | def generate_until(self, context, max_length, stop, FILE: main.py function _load_from_checkpoint (line 31) | def _load_from_checkpoint(diffusion_model, config, tokenizer): function _print_config (line 43) | def _print_config( function _print_batch (line 78) | def _print_batch(config, train_ds, valid_ds, tokenizer, k=64): function _generate_samples (line 93) | def _generate_samples(diffusion_model, config, logger, function _eval_ppl (line 142) | def _eval_ppl(diffusion_model, config, logger, tokenizer): function _train (line 173) | def _train(diffusion_model, config, logger, tokenizer): function _eval_fid (line 217) | def _eval_fid(diffusion_model, config, logger, tokenizer): function main (line 298) | def main(config): FILE: metrics.py class NLL (line 13) | class NLL(torchmetrics.aggregation.MeanMetric): method update (line 14) | def update(self, class BPD (line 47) | class BPD(NLL): method compute (line 48) | def compute(self) -> torch.Tensor: class Perplexity (line 57) | class Perplexity(NLL): method compute (line 58) | def compute(self) -> torch.Tensor: class Metrics (line 67) | class Metrics: method __init__ (line 68) | def __init__(self, gen_ppl_eval_model_name_or_path=None, method to (line 87) | def to(self, *args, **kwargs): method reset (line 95) | def reset(self): method update_train (line 103) | def update_train(self, nll, aux_loss, num_tokens): method update_valid (line 107) | def update_valid(self, nll, aux_loss, num_tokens): method _eval_retokenize (line 113) | def _eval_retokenize(self, text_samples, max_length, method record_entropy (line 155) | def record_entropy(self, tokens): method record_generative_perplexity (line 164) | def record_generative_perplexity( FILE: models/dit.py function bias_dropout_add_scale (line 20) | def bias_dropout_add_scale( function get_bias_dropout_add_scale (line 37) | def get_bias_dropout_add_scale(training): function modulate (line 46) | def modulate(x: torch.Tensor, function bias_dropout_add_scale_fused_train (line 53) | def bias_dropout_add_scale_fused_train( function bias_dropout_add_scale_fused_inference (line 64) | def bias_dropout_add_scale_fused_inference( function modulate_fused (line 75) | def modulate_fused(x: torch.Tensor, class Rotary (line 81) | class Rotary(torch.nn.Module): method __init__ (line 82) | def __init__(self, dim, base=10_000): method forward (line 90) | def forward(self, x, seq_dim=1): function rotate_half (line 107) | def rotate_half(x): function split_and_apply_rotary_pos_emb (line 112) | def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin): function apply_rotary_pos_emb (line 128) | def apply_rotary_pos_emb(qkv, cos, sin): function regular_attention_multi_headed (line 134) | def regular_attention_multi_headed(q, k, v): class LayerNorm (line 152) | class LayerNorm(nn.Module): method __init__ (line 153) | def __init__(self, dim): method forward (line 157) | def forward(self, x): function residual_linear (line 163) | def residual_linear(x, W, x_skip, residual_scale): class TimestepEmbedder (line 176) | class TimestepEmbedder(nn.Module): method __init__ (line 180) | def __init__(self, hidden_size, frequency_embedding_size=256): method timestep_embedding (line 189) | def timestep_embedding(t, dim, max_period=10000): method forward (line 212) | def forward(self, t): class LabelEmbedder (line 218) | class LabelEmbedder(nn.Module): method __init__ (line 223) | def __init__(self, num_classes, cond_size): method forward (line 230) | def forward(self, labels): class DDiTBlockCausal (line 239) | class DDiTBlockCausal(nn.Module): method __init__ (line 240) | def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1): method _get_bias_dropout_scale (line 257) | def _get_bias_dropout_scale(self): method forward (line 263) | def forward(self, x, rotary_cos_sin, **kwargs): class DDiTBlock (line 305) | class DDiTBlock(nn.Module): method __init__ (line 306) | def __init__(self, dim, n_heads, adaLN, method _get_bias_dropout_scale (line 332) | def _get_bias_dropout_scale(self): method forward (line 339) | def forward(self, x, rotary_cos_sin, c=None): class EmbeddingLayer (line 382) | class EmbeddingLayer(nn.Module): method __init__ (line 383) | def __init__(self, dim, vocab_dim): method forward (line 388) | def forward(self, x, weights=None): class DDiTFinalLayer (line 406) | class DDiTFinalLayer(nn.Module): method __init__ (line 407) | def __init__(self, hidden_size, out_channels, cond_dim, method forward (line 422) | def forward(self, x, c): class DIT (line 431) | class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin): method __init__ (line 432) | def __init__(self, config, vocab_size: int): method _get_bias_dropout_scale (line 471) | def _get_bias_dropout_scale(self): method forward (line 477) | def forward(self, x, sigma, class_cond=None, weights=None): FILE: models/ema.py class ExponentialMovingAverage (line 4) | class ExponentialMovingAverage: method __init__ (line 9) | def __init__(self, parameters, decay, use_num_updates=True): method move_shadow_params_to_device (line 26) | def move_shadow_params_to_device(self, device): method update (line 29) | def update(self, parameters): method copy_to (line 51) | def copy_to(self, parameters): method store (line 64) | def store(self, parameters): method restore (line 74) | def restore(self, parameters): method state_dict (line 89) | def state_dict(self): method load_state_dict (line 94) | def load_state_dict(self, state_dict): FILE: models/unet.py function transformer_timestep_embedding (line 15) | def transformer_timestep_embedding(timesteps, embedding_dim, max_positio... function variance_scaling (line 33) | def variance_scaling(scale, mode, distribution, function default_init (line 67) | def default_init(scale=1.): class NiN (line 73) | class NiN(nn.Module): method __init__ (line 74) | def __init__(self, in_ch, out_ch, init_scale=0.1): method forward (line 79) | def forward(self, x, # ["batch", "in_ch", "H", "W"] class AttnBlock (line 88) | class AttnBlock(nn.Module): method __init__ (line 90) | def __init__(self, channels, skip_rescale=True): method forward (line 100) | def forward(self, x, # ["batch", "channels", "H", "W"] class ResBlock (line 122) | class ResBlock(nn.Module): method __init__ (line 123) | def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_res... method forward (line 157) | def forward(self, x, # ["batch", "in_ch", "H", "W"] class Downsample (line 184) | class Downsample(nn.Module): method __init__ (line 185) | def __init__(self, channels): method forward (line 190) | def forward(self, x, # ["batch", "ch", "inH", "inW"] class Upsample (line 199) | class Upsample(nn.Module): method __init__ (line 200) | def __init__(self, channels): method forward (line 204) | def forward(self, x, # ["batch", "ch", "inH", "inW"] class UNet (line 214) | class UNet(nn.Module): method __init__ (line 215) | def __init__(self, config, vocab_size=None): method _center_data (line 344) | def _center_data(self, x): method _time_embedding (line 348) | def _time_embedding(self, timesteps): method _do_input_conv (line 360) | def _do_input_conv(self, h): method _do_downsampling (line 365) | def _do_downsampling(self, h, hs, temb): method _do_middle (line 385) | def _do_middle(self, h, temb): method _do_upsampling (line 398) | def _do_upsampling(self, h, hs, temb): method _do_output (line 418) | def _do_output(self, h): method _logistic_output_res (line 426) | def _logistic_output_res(self, method _log_minus_exp (line 435) | def _log_minus_exp(self, a, b, eps=1e-6): method _truncated_logistic_output (line 443) | def _truncated_logistic_output(self, net_out): method forward (line 477) | def forward(self, FILE: models/unit_test_attention.py function attention_inner_heads_flash (line 9) | def attention_inner_heads_flash(qkv, num_heads): class TestAttentionInnerHeadsFlash (line 62) | class TestAttentionInnerHeadsFlash(unittest.TestCase): method setUp (line 63) | def setUp(self): method attention_inner_heads_old (line 74) | def attention_inner_heads_old(self, qkv, num_heads): method test_attention_inner_heads_flash (line 90) | def test_attention_inner_heads_flash(self): FILE: trainer_base.py class Loss (line 18) | class Loss: class LogLinear (line 25) | class LogLinear(torch.nn.Module): method __init__ (line 26) | def __init__(self, eps): method forward (line 30) | def forward(self, t): method get_t_for_alpha (line 36) | def get_t_for_alpha(self, alpha_t): class Cosine (line 40) | class Cosine(torch.nn.Module): method __init__ (line 41) | def __init__(self, eps): method forward (line 46) | def forward(self, t): method get_t_for_alpha (line 52) | def get_t_for_alpha(self, alpha_t): function sample_categorical (line 62) | def sample_categorical(categorical_probs): function _unsqueeze (line 69) | def _unsqueeze(x, reference): class TrainerBase (line 75) | class TrainerBase(L.LightningModule): method __init__ (line 76) | def __init__( method _validate_configuration (line 147) | def _validate_configuration(self): method to (line 161) | def to(self, *args, **kwargs): method q_xt (line 166) | def q_xt(self, x, alpha_t): method _get_parameters (line 169) | def _get_parameters(self): method _eval_mode (line 173) | def _eval_mode(self): method _train_mode (line 180) | def _train_mode(self): method on_load_checkpoint (line 186) | def on_load_checkpoint(self, checkpoint): method on_save_checkpoint (line 197) | def on_save_checkpoint(self, checkpoint): method on_train_start (line 236) | def on_train_start(self): method optimizer_step (line 273) | def optimizer_step(self, *args, **kwargs): method _process_sigma (line 278) | def _process_sigma(self, sigma): method _process_model_output (line 281) | def _process_model_output(self, model_output, xt, sigma): method forward (line 284) | def forward(self, xt, sigma, labels=None, weights=None, method on_train_epoch_start (line 297) | def on_train_epoch_start(self): method training_step (line 302) | def training_step(self, batch, batch_idx): method on_train_epoch_end (line 319) | def on_train_epoch_end(self): method on_validation_epoch_start (line 324) | def on_validation_epoch_start(self): method validation_step (line 330) | def validation_step(self, batch, batch_idx): method on_validation_epoch_end (line 339) | def on_validation_epoch_end(self): method configure_optimizers (line 381) | def configure_optimizers(self): method generate_samples (line 398) | def generate_samples(self, num_samples, num_steps, eps): method restore_model_and_sample (line 401) | def restore_model_and_sample(self, num_steps, eps=1e-5): method _process_model_input (line 412) | def _process_model_input(self, x0, valid_tokens): method nll (line 415) | def nll(self, input_tokens, labels, output_tokens, method _loss (line 419) | def _loss(self, x0, labels, valid_tokens, class Diffusion (line 442) | class Diffusion(TrainerBase): method _validate_configuration (line 443) | def _validate_configuration(self): method _process_model_input (line 452) | def _process_model_input(self, x0, valid_tokens): method _process_sigma (line 455) | def _process_sigma(self, sigma): method _sample_t (line 465) | def _sample_t(self, n, accum_step): method _sigma_from_alphat (line 484) | def _sigma_from_alphat(self, alpha_t): method _reconstruction_loss (line 487) | def _reconstruction_loss(self, x0): method nll_per_token (line 496) | def nll_per_token(self, model_output, xt, x0, alpha_t, method nll (line 500) | def nll(self, x0, labels, output_tokens, method _get_score (line 540) | def _get_score(self, **kwargs): method _denoiser_update (line 544) | def _denoiser_update(self, x, t): method _analytic_update (line 547) | def _analytic_update(self, x, t, dt): method _posterior_from_x0 (line 550) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t): method _forward_process (line 555) | def _forward_process(self, q_x0, alpha_s): method _get_ancestral_posterior (line 560) | def _get_ancestral_posterior(self, xt, sigma, labels, method _get_posterior_from_xt (line 584) | def _get_posterior_from_xt(self, xt, sigma, labels, alpha_s, method _get_guided_posterior_from_xt (line 602) | def _get_guided_posterior_from_xt(self, xt, sigma, labels, method _ancestral_update (line 633) | def _ancestral_update(self, x, t, labels, dt, p_x0=None, method _psi_update (line 650) | def _psi_update(self, x, t, labels, dt, kappa, p_x0=None, method _get_sampling_time_profile (line 674) | def _get_sampling_time_profile(self, eps, num_steps): method _mode_to_psi_kappas (line 694) | def _mode_to_psi_kappas(self, mode, timesteps): method _get_kappas (line 727) | def _get_kappas(self, timesteps): method generate_samples (line 743) | def generate_samples(self, num_samples, labels=None, method _semi_ar_sampler (line 800) | def _semi_ar_sampler( method restore_model_and_semi_ar_sample (line 840) | def restore_model_and_semi_ar_sample( class AbsorbingState (line 856) | class AbsorbingState(Diffusion): method __init__ (line 857) | def __init__(self, config, tokenizer): method _validate_configuration (line 875) | def _validate_configuration(self): method q_xt (line 886) | def q_xt(self, x, alpha_t): method prior_sample (line 901) | def prior_sample(self, *batch_dims): method _posterior_from_x0 (line 905) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t): method _forward_process (line 919) | def _forward_process(self, x0, alpha_s): method _staggered_score (line 924) | def _staggered_score(self, score, dsigma): method _analytic_update (line 931) | def _analytic_update(self, x, t, dt): method _denoiser_update (line 942) | def _denoiser_update(self, x, t): method _transp_transition (line 953) | def _transp_transition(self, i, sigma): class UniformState (line 963) | class UniformState(Diffusion): method _validate_configuration (line 964) | def _validate_configuration(self): method _forward_process (line 971) | def _forward_process(self, x0, alpha_s): method q_xt (line 975) | def q_xt(self, x, alpha_t): method prior_sample (line 993) | def prior_sample(self, *batch_dims): FILE: utils.py function count_parameters (line 24) | def count_parameters(model): function fsspec_exists (line 29) | def fsspec_exists(filename): function fsspec_listdir (line 35) | def fsspec_listdir(dirname): function fsspec_mkdirs (line 41) | def fsspec_mkdirs(dirname, exist_ok=True): function print_nans (line 47) | def print_nans(tensor, name): class LRHalveScheduler (line 52) | class LRHalveScheduler: method __init__ (line 53) | def __init__(self, warmup_steps, n_halve_steps): method __call__ (line 57) | def __call__(self, current_step): class CosineDecayWarmupLRScheduler (line 64) | class CosineDecayWarmupLRScheduler( method __init__ (line 74) | def __init__(self, *args, **kwargs): method step (line 79) | def step(self, epoch=None): class LoggingContext (line 97) | class LoggingContext: method __init__ (line 99) | def __init__(self, logger, level=None, handler=None, close=True): method __enter__ (line 105) | def __enter__(self): method __exit__ (line 112) | def __exit__(self, et, ev, tb): class GradientInspectionCallback (line 121) | class GradientInspectionCallback(lightning.Callback): method __init__ (line 122) | def __init__(self, num_grads_log): method on_before_optimizer_step (line 125) | def on_before_optimizer_step(self, trainer, pl_module, optimizer): function get_logger (line 158) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: function top_k_top_p_filtering (line 177) | def top_k_top_p_filtering( function _discrete_prob_map (line 244) | def _discrete_prob_map(gamma_t, N=10): function _discrete_prob_grad (line 253) | def _discrete_prob_grad(gamma_t, N=10): function _cache_prob_usdm_in_partition (line 263) | def _cache_prob_usdm_in_partition( function test_cache_prob_usdm_in_partition (line 303) | def test_cache_prob_usdm_in_partition( function compute_duo_series_coefficients (line 340) | def compute_duo_series_coefficients(num_coefficients, function compute_duo_gamma_to_alpha_dalpha_series (line 373) | def compute_duo_gamma_to_alpha_dalpha_series( function duo_t_to_alpha_dalpha_sigm_corrected (line 411) | def duo_t_to_alpha_dalpha_sigm_corrected( function duo_to_alpha_dalpha_sigmoid (line 432) | def duo_to_alpha_dalpha_sigmoid(t: torch.Tensor, a: float, function duo_to_alpha_dalpha_poly (line 440) | def duo_to_alpha_dalpha_poly(t: torch.Tensor, function compute_duo_operator_approx (line 453) | def compute_duo_operator_approx(num_coefficients, vocab_size, function _sample_k_int (line 511) | def _sample_k_int(bs: int, l: int, k: int, max_value: int, function _sample_topk_gaussian (line 529) | def _sample_topk_gaussian(N: int, function _sample_topk_and_extra (line 559) | def _sample_topk_and_extra(N: int, alpha: torch.Tensor, function _log_mean_exp_trunc_normal (line 575) | def _log_mean_exp_trunc_normal(c: torch.Tensor, function sample_tempered_softmax_topk (line 591) | def sample_tempered_softmax_topk(