SYMBOL INDEX (62 symbols across 4 files) FILE: d3pm_runner.py class DummyX0Model (line 39) | class DummyX0Model(nn.Module): method __init__ (line 41) | def __init__(self, n_channel: int, N: int = 16) -> None: method forward (line 72) | def forward(self, x, t, cond) -> torch.Tensor: class D3PM (line 134) | class D3PM(nn.Module): method __init__ (line 135) | def __init__( method _at (line 193) | def _at(self, a, t, x): method q_posterior_logits (line 200) | def q_posterior_logits(self, x_0, x_t, t): method vb (line 237) | def vb(self, dist1, dist2): method q_sample (line 249) | def q_sample(self, x_0, t, noise): method model_predict (line 256) | def model_predict(self, x_0, t, cond): method forward (line 266) | def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch... method p_sample (line 299) | def p_sample(self, x, t, cond, noise): method sample (line 314) | def sample(self, x, cond=None): method sample_with_image_sequence (line 323) | def sample_with_image_sequence(self, x, cond=None, stride=10): FILE: dit.py function modulate (line 10) | def modulate(x, shift, scale): class TimestepEmbedder (line 14) | class TimestepEmbedder(nn.Module): method __init__ (line 15) | def __init__(self, hidden_size, frequency_embedding_size=256): method timestep_embedding (line 25) | def timestep_embedding(t, dim, max_period=10000): method forward (line 38) | def forward(self, t): class LabelEmbedder (line 46) | class LabelEmbedder(nn.Module): method __init__ (line 47) | def __init__(self, num_classes, hidden_size, dropout_prob): method token_drop (line 56) | def token_drop(self, labels, force_drop_ids=None): method forward (line 66) | def forward(self, labels, train, force_drop_ids=None): class Attention (line 74) | class Attention(nn.Module): method __init__ (line 75) | def __init__(self, dim, n_heads): method reshape_for_broadcast (line 91) | def reshape_for_broadcast(freqs_cis, x): method apply_rotary_emb (line 99) | def apply_rotary_emb(xq, xk, freqs_cis): method forward (line 107) | def forward(self, x, freqs_cis): class FeedForward (line 134) | class FeedForward(nn.Module): method __init__ (line 135) | def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=No... method _forward_silu_gating (line 146) | def _forward_silu_gating(self, x1, x3): method forward (line 149) | def forward(self, x): class TransformerBlock (line 153) | class TransformerBlock(nn.Module): method __init__ (line 154) | def __init__( method forward (line 182) | def forward(self, x, freqs_cis, adaln_input=None): class FinalLayer (line 201) | class FinalLayer(nn.Module): method __init__ (line 202) | def __init__(self, hidden_size, patch_size, out_channels): method forward (line 216) | def forward(self, x, c): class DDiT_Llama (line 223) | class DDiT_Llama(nn.Module): method __init__ (line 225) | def __init__( method forward (line 262) | def forward(self, x, t, cond=None): class DiT_Llama (line 281) | class DiT_Llama(nn.Module): method __init__ (line 282) | def __init__( method unpatchify (line 338) | def unpatchify(self, x): method patchify (line 347) | def patchify(self, x): method forward (line 360) | def forward(self, x, t, y): method forward_with_cfg (line 394) | def forward_with_cfg(self, x, t, y, cfg_scale): method precompute_freqs_cis (line 405) | def precompute_freqs_cis(dim, end, theta=10000.0): function DiT_Llama_600M_patch2 (line 413) | def DiT_Llama_600M_patch2(**kwargs): function DiT_Llama_3B_patch2 (line 417) | def DiT_Llama_3B_patch2(**kwargs): FILE: lm.py class WikiTextDataset (line 14) | class WikiTextDataset(Dataset): method __init__ (line 15) | def __init__(self, tokenizer=None, type_path="train", max_length=512, ... method __len__ (line 25) | def __len__(self): method __getitem__ (line 30) | def __getitem__(self, idx): FILE: lm_deepspeed.py class WikiTextDataset (line 23) | class WikiTextDataset(Dataset): method __init__ (line 24) | def __init__( method __len__ (line 38) | def __len__(self): method __getitem__ (line 41) | def __getitem__(self, idx): function _z3_params_to_fetch (line 65) | def _z3_params_to_fetch(param_list): function save_zero_three_model (line 73) | def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0): function set_seed (line 100) | def set_seed(seed=42): function main (line 120) | def main(