SYMBOL INDEX (110 symbols across 10 files) FILE: diffusion/__init__.py function create_diffusion (line 10) | def create_diffusion( FILE: diffusion/diffusion_utils.py function normal_kl (line 10) | def normal_kl(mean1, logvar1, mean2, logvar2): function approx_standard_normal_cdf (line 39) | def approx_standard_normal_cdf(x): function continuous_gaussian_log_likelihood (line 47) | def continuous_gaussian_log_likelihood(x, *, means, log_scales): function discretized_gaussian_log_likelihood (line 62) | def discretized_gaussian_log_likelihood(x, *, means, log_scales): FILE: diffusion/gaussian_diffusion.py function mean_flat (line 16) | def mean_flat(tensor): class ModelMeanType (line 23) | class ModelMeanType(enum.Enum): class ModelVarType (line 33) | class ModelVarType(enum.Enum): class LossType (line 46) | class LossType(enum.Enum): method is_vb (line 54) | def is_vb(self): function _warmup_beta (line 58) | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_f... function get_beta_schedule (line 65) | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffus... function get_named_beta_schedule (line 98) | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): function betas_for_alpha_bar (line 125) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9... class GaussianDiffusion (line 144) | class GaussianDiffusion: method __init__ (line 153) | def __init__( method q_mean_variance (line 203) | def q_mean_variance(self, x_start, t): method q_sample (line 215) | def q_sample(self, x_start, t, noise=None): method q_posterior_mean_variance (line 232) | def q_posterior_mean_variance(self, x_start, x_t, t): method p_mean_variance (line 254) | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn... method _predict_xstart_from_eps (line 334) | def _predict_xstart_from_eps(self, x_t, t, eps): method _predict_eps_from_xstart (line 341) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): method condition_mean (line 346) | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): method condition_score (line 358) | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): method p_sample (line 376) | def p_sample( method p_sample_loop (line 419) | def p_sample_loop( method p_sample_loop_progressive (line 464) | def p_sample_loop_progressive( method ddim_sample (line 513) | def ddim_sample( method ddim_reverse_sample (line 562) | def ddim_reverse_sample( method ddim_sample_loop (line 600) | def ddim_sample_loop( method ddim_sample_loop_progressive (line 633) | def ddim_sample_loop_progressive( method _vb_terms_bpd (line 682) | def _vb_terms_bpd( method training_losses (line 715) | def training_losses(self, model, x_start, t, model_kwargs=None, noise=... method _prior_bpd (line 789) | def _prior_bpd(self, x_start): method calc_bpd_loop (line 805) | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwar... function _extract_into_tensor (line 861) | def _extract_into_tensor(arr, timesteps, broadcast_shape): FILE: diffusion/respace.py function space_timesteps (line 12) | def space_timesteps(num_timesteps, section_counts): class SpacedDiffusion (line 65) | class SpacedDiffusion(GaussianDiffusion): method __init__ (line 73) | def __init__(self, use_timesteps, **kwargs): method p_mean_variance (line 89) | def p_mean_variance( method training_losses (line 94) | def training_losses( method condition_mean (line 99) | def condition_mean(self, cond_fn, *args, **kwargs): method condition_score (line 102) | def condition_score(self, cond_fn, *args, **kwargs): method _wrap_model (line 105) | def _wrap_model(self, model): method _scale_timesteps (line 112) | def _scale_timesteps(self, t): class _WrappedModel (line 117) | class _WrappedModel: method __init__ (line 118) | def __init__(self, model, timestep_map, original_num_steps): method __call__ (line 124) | def __call__(self, x, ts, **kwargs): FILE: diffusion/timestep_sampler.py function create_named_schedule_sampler (line 13) | def create_named_schedule_sampler(name, diffusion): class ScheduleSampler (line 27) | class ScheduleSampler(ABC): method weights (line 38) | def weights(self): method sample (line 44) | def sample(self, batch_size, device): class UniformSampler (line 62) | class UniformSampler(ScheduleSampler): method __init__ (line 63) | def __init__(self, diffusion): method weights (line 67) | def weights(self): class LossAwareSampler (line 71) | class LossAwareSampler(ScheduleSampler): method update_with_local_losses (line 72) | def update_with_local_losses(self, local_ts, local_losses): method update_with_all_losses (line 106) | def update_with_all_losses(self, ts, losses): class LossSecondMomentResampler (line 120) | class LossSecondMomentResampler(LossAwareSampler): method __init__ (line 121) | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): method weights (line 130) | def weights(self): method update_with_all_losses (line 139) | def update_with_all_losses(self, ts, losses): method _warmed_up (line 149) | def _warmed_up(self): FILE: download.py function find_model (line 18) | def find_model(model_name): function download_model (line 32) | def download_model(model_name): FILE: models.py function modulate (line 19) | def modulate(x, shift, scale): class TimestepEmbedder (line 27) | class TimestepEmbedder(nn.Module): method __init__ (line 31) | def __init__(self, hidden_size, frequency_embedding_size=256): method timestep_embedding (line 41) | def timestep_embedding(t, dim, max_period=10000): method forward (line 61) | def forward(self, t): class LabelEmbedder (line 67) | class LabelEmbedder(nn.Module): method __init__ (line 71) | def __init__(self, num_classes, hidden_size, dropout_prob): method token_drop (line 78) | def token_drop(self, labels, force_drop_ids=None): method forward (line 89) | def forward(self, labels, train, force_drop_ids=None): class DiTBlock (line 101) | class DiTBlock(nn.Module): method __init__ (line 105) | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwar... method forward (line 118) | def forward(self, x, c): class FinalLayer (line 125) | class FinalLayer(nn.Module): method __init__ (line 129) | def __init__(self, hidden_size, patch_size, out_channels): method forward (line 138) | def forward(self, x, c): class DiT (line 145) | class DiT(nn.Module): method __init__ (line 149) | def __init__( method initialize_weights (line 182) | def initialize_weights(self): method unpatchify (line 218) | def unpatchify(self, x): method forward (line 233) | def forward(self, x, t, y): method forward_with_cfg (line 250) | def forward_with_cfg(self, x, t, y, cfg_scale): function get_2d_sincos_pos_embed (line 274) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra... function get_2d_sincos_pos_embed_from_grid (line 292) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): function get_1d_sincos_pos_embed_from_grid (line 303) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): function DiT_XL_2 (line 328) | def DiT_XL_2(**kwargs): function DiT_XL_4 (line 331) | def DiT_XL_4(**kwargs): function DiT_XL_8 (line 334) | def DiT_XL_8(**kwargs): function DiT_L_2 (line 337) | def DiT_L_2(**kwargs): function DiT_L_4 (line 340) | def DiT_L_4(**kwargs): function DiT_L_8 (line 343) | def DiT_L_8(**kwargs): function DiT_B_2 (line 346) | def DiT_B_2(**kwargs): function DiT_B_4 (line 349) | def DiT_B_4(**kwargs): function DiT_B_8 (line 352) | def DiT_B_8(**kwargs): function DiT_S_2 (line 355) | def DiT_S_2(**kwargs): function DiT_S_4 (line 358) | def DiT_S_4(**kwargs): function DiT_S_8 (line 361) | def DiT_S_8(**kwargs): FILE: sample.py function main (line 21) | def main(args): FILE: sample_ddp.py function create_npz_from_sample_folder (line 28) | def create_npz_from_sample_folder(sample_dir, num=50_000): function main (line 45) | def main(args): FILE: train.py function update_ema (line 40) | def update_ema(ema_model, model, decay=0.9999): function requires_grad (line 52) | def requires_grad(model, flag=True): function cleanup (line 60) | def cleanup(): function create_logger (line 67) | def create_logger(logging_dir): function center_crop_arr (line 85) | def center_crop_arr(pil_image, image_size): function main (line 110) | def main(args):