SYMBOL INDEX (204 symbols across 16 files) FILE: dist.py function initialized (line 16) | def initialized(): function initialize (line 20) | def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, t... function get_rank (line 52) | def get_rank(): function get_local_rank (line 56) | def get_local_rank(): function get_world_size (line 60) | def get_world_size(): function get_device (line 64) | def get_device(): function set_gpu_id (line 68) | def set_gpu_id(gpu_id: int): function is_master (line 78) | def is_master(): function is_local_master (line 82) | def is_local_master(): function new_group (line 86) | def new_group(ranks: List[int]): function barrier (line 92) | def barrier(): function allreduce (line 97) | def allreduce(t: torch.Tensor, async_op=False): function allgather (line 109) | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], to... function allgather_diff_shape (line 122) | def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.... function broadcast (line 149) | def broadcast(t: torch.Tensor, src_rank) -> None: function dist_fmt_vals (line 159) | def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[t... function master_only (line 171) | def master_only(func): function local_master_only (line 184) | def local_master_only(func): function for_visualize (line 197) | def for_visualize(func): function finalize (line 209) | def finalize(): FILE: models/__init__.py function build_vae_var (line 9) | def build_vae_var( FILE: models/basic_vae.py function nonlinearity (line 14) | def nonlinearity(x): function Normalize (line 18) | def Normalize(in_channels, num_groups=32): class Upsample2x (line 22) | class Upsample2x(nn.Module): method __init__ (line 23) | def __init__(self, in_channels): method forward (line 27) | def forward(self, x): class Downsample2x (line 31) | class Downsample2x(nn.Module): method __init__ (line 32) | def __init__(self, in_channels): method forward (line 36) | def forward(self, x): class ResnetBlock (line 40) | class ResnetBlock(nn.Module): method __init__ (line 41) | def __init__(self, *, in_channels, out_channels=None, dropout): # conv... method forward (line 57) | def forward(self, x): class AttnBlock (line 63) | class AttnBlock(nn.Module): method __init__ (line 64) | def __init__(self, in_channels): method forward (line 73) | def forward(self, x): function make_attn (line 95) | def make_attn(in_channels, using_sa=True): class Encoder (line 99) | class Encoder(nn.Module): method __init__ (line 100) | def __init__( method forward (line 144) | def forward(self, x): class Decoder (line 163) | class Decoder(nn.Module): method __init__ (line 164) | def __init__( method forward (line 210) | def forward(self, z): FILE: models/basic_var.py function slow_attn (line 27) | def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p... class FFN (line 33) | class FFN(nn.Module): method __init__ (line 34) | def __init__(self, in_features, hidden_features=None, out_features=Non... method forward (line 44) | def forward(self, x): method extra_repr (line 54) | def extra_repr(self) -> str: class SelfAttention (line 58) | class SelfAttention(nn.Module): method __init__ (line 59) | def __init__( method kv_caching (line 87) | def kv_caching(self, enable: bool): self.caching, self.cached_k, self.... method forward (line 90) | def forward(self, x, attn_bias): method extra_repr (line 124) | def extra_repr(self) -> str: class AdaLNSelfAttn (line 128) | class AdaLNSelfAttn(nn.Module): method __init__ (line 129) | def __init__( method forward (line 152) | def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim method extra_repr (line 161) | def extra_repr(self) -> str: class AdaLNBeforeHead (line 165) | class AdaLNBeforeHead(nn.Module): method __init__ (line 166) | def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim method forward (line 172) | def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor): FILE: models/helpers.py function sample_with_top_k_top_p_ (line 6) | def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, t... function gumbel_softmax_with_rng (line 22) | def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: ... function drop_path (line 39) | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by... class DropPath (line 49) | class DropPath(nn.Module): # taken from timm method __init__ (line 50) | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): method forward (line 55) | def forward(self, x): method extra_repr (line 58) | def extra_repr(self): FILE: models/quant.py class VectorQuantizer2 (line 15) | class VectorQuantizer2(nn.Module): method __init__ (line 17) | def __init__( method eini (line 44) | def eini(self, eini): method extra_repr (line 48) | def extra_repr(self) -> str: method forward (line 52) | def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[tor... method embed_to_fhat (line 107) | def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scal... method f_to_idxBl_or_fhat (line 135) | def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_pa... method idxBl_to_var_input (line 169) | def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torc... method get_next_autoregressive_input (line 187) | def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch... class Phi (line 199) | class Phi(nn.Conv2d): method __init__ (line 200) | def __init__(self, embed_dim, quant_resi): method forward (line 205) | def forward(self, h_BChw): class PhiShared (line 209) | class PhiShared(nn.Module): method __init__ (line 210) | def __init__(self, qresi: Phi): method __getitem__ (line 214) | def __getitem__(self, _) -> Phi: class PhiPartiallyShared (line 218) | class PhiPartiallyShared(nn.Module): method __init__ (line 219) | def __init__(self, qresi_ls: nn.ModuleList): method __getitem__ (line 225) | def __getitem__(self, at_from_0_to_1: float) -> Phi: method extra_repr (line 228) | def extra_repr(self) -> str: class PhiNonShared (line 232) | class PhiNonShared(nn.ModuleList): method __init__ (line 233) | def __init__(self, qresi: List): method __getitem__ (line 239) | def __getitem__(self, at_from_0_to_1: float) -> Phi: method extra_repr (line 242) | def extra_repr(self) -> str: FILE: models/var.py class SharedAdaLin (line 15) | class SharedAdaLin(nn.Linear): method forward (line 16) | def forward(self, cond_BD): class VAR (line 21) | class VAR(nn.Module): method __init__ (line 22) | def __init__( method get_logits (line 118) | def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[to... method autoregressive_infer_cfg (line 127) | def autoregressive_infer_cfg( method forward (line 192) | def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.... method init_weights (line 236) | def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_hea... method extra_repr (line 288) | def extra_repr(self): class VARHF (line 292) | class VARHF(VAR, PyTorchModelHubMixin): method __init__ (line 295) | def __init__( FILE: models/vqvae.py class VQVAE (line 16) | class VQVAE(nn.Module): method __init__ (line 17) | def __init__( method forward (line 56) | def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss method fhat_to_img (line 62) | def fhat_to_img(self, f_hat: torch.Tensor): method img_to_idxBl (line 65) | def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Op... method idxBl_to_img (line 69) | def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool... method embed_to_img (line 78) | def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale... method img_to_reconstructed_img (line 84) | def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[... method load_state_dict (line 92) | def load_state_dict(self, state_dict: Dict[str, Any], strict=True, ass... FILE: train.py function build_everything (line 19) | def build_everything(args: arg_util.Args): function main_training (line 171) | def main_training(): function train_one_ep (line 253) | def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_ut... class NullDDP (line 320) | class NullDDP(torch.nn.Module): method __init__ (line 321) | def __init__(self, module, *args, **kwargs): method forward (line 326) | def forward(self, *args, **kwargs): FILE: trainer.py class VARTrainer (line 20) | class VARTrainer(object): method __init__ (line 21) | def __init__( method eval_ep (line 55) | def eval_ep(self, ld_val: DataLoader): method train_step (line 86) | def train_step( method get_config (line 162) | def get_config(self): method state_dict (line 169) | def state_dict(self): method load_state_dict (line 179) | def load_state_dict(self, state, strict=True, skip_vae=False): FILE: utils/amp_sc.py class NullCtx (line 7) | class NullCtx: method __enter__ (line 8) | def __enter__(self): method __exit__ (line 11) | def __exit__(self, exc_type, exc_val, exc_tb): class AmpOptimizer (line 15) | class AmpOptimizer: method __init__ (line 16) | def __init__( method backward_clip_step (line 39) | def backward_clip_step( method state_dict (line 77) | def state_dict(self): method load_state_dict (line 85) | def load_state_dict(self, state, strict=True): FILE: utils/arg_util.py class Args (line 25) | class Args(Tap): method seed_everything (line 113) | def seed_everything(self, benchmark: bool): method get_different_generator_for_each_rank (line 129) | def get_different_generator_for_each_rank(self) -> Optional[torch.Gene... method compile_model (line 139) | def compile_model(self, m, fast): method state_dict (line 148) | def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]: method load_state_dict (line 156) | def load_state_dict(self, d: Union[OrderedDict, dict, str]): method set_tf32 (line 167) | def set_tf32(tf32: bool): method dump_log (line 177) | def dump_log(self): method __str__ (line 198) | def __str__(self): function init_dist_and_get_args (line 207) | def init_dist_and_get_args(): FILE: utils/data.py function normalize_01_into_pm1 (line 8) | def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (... function build_dataset (line 12) | def build_dataset( function pil_loader (line 41) | def pil_loader(path): function print_aug (line 47) | def print_aug(transform, label): FILE: utils/data_sampler.py class EvalDistributedSampler (line 6) | class EvalDistributedSampler(Sampler): method __init__ (line 7) | def __init__(self, dataset, num_replicas, rank): method __iter__ (line 13) | def __iter__(self): method __len__ (line 16) | def __len__(self) -> int: class InfiniteBatchSampler (line 20) | class InfiniteBatchSampler(Sampler): method __init__ (line 21) | def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_... method gener_indices (line 33) | def gener_indices(self): method __iter__ (line 51) | def __iter__(self): method __len__ (line 63) | def __len__(self): class DistInfiniteBatchSampler (line 67) | class DistInfiniteBatchSampler(InfiniteBatchSampler): method __init__ (line 68) | def __init__(self, world_size, rank, dataset_len, glb_batch_size, same... method gener_indices (line 84) | def gener_indices(self): FILE: utils/lr_control.py function lr_wd_annealing (line 10) | def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_... function filter_params (line 68) | def filter_params(model, nowd_keys=()) -> Tuple[ FILE: utils/misc.py function echo (line 20) | def echo(info): function os_system_get_stdout (line 22) | def os_system_get_stdout(cmd): function os_system_get_stdout_stderr (line 24) | def os_system_get_stdout_stderr(cmd): function time_str (line 36) | def time_str(fmt='[%m-%d %H:%M:%S]'): function init_distributed_mode (line 40) | def init_distributed_mode(local_out_path, only_sync_master=False, timeou... function _change_builtin_print (line 54) | def _change_builtin_print(is_master): class SyncPrint (line 78) | class SyncPrint(object): method __init__ (line 79) | def __init__(self, local_output_dir, sync_stdout=True): method write (line 90) | def write(self, message): method flush (line 94) | def flush(self): method close (line 98) | def close(self): method __del__ (line 111) | def __del__(self): class DistLogger (line 115) | class DistLogger(object): method __init__ (line 116) | def __init__(self, lg, verbose): method do_nothing (line 120) | def do_nothing(*args, **kwargs): method __getattr__ (line 123) | def __getattr__(self, attr: str): class TensorboardLogger (line 127) | class TensorboardLogger(object): method __init__ (line 128) | def __init__(self, log_dir, filename_suffix): method set_step (line 135) | def set_step(self, step=None): method update (line 141) | def update(self, head='scalar', step=None, **kwargs): method log_tensor_as_distri (line 155) | def log_tensor_as_distri(self, tag, tensor1d, step=None): method log_image (line 167) | def log_image(self, tag, img_chw, step=None): method flush (line 176) | def flush(self): method close (line 179) | def close(self): class SmoothedValue (line 183) | class SmoothedValue(object): method __init__ (line 188) | def __init__(self, window_size=30, fmt=None): method update (line 196) | def update(self, value, n=1): method synchronize_between_processes (line 201) | def synchronize_between_processes(self): method median (line 213) | def median(self): method avg (line 217) | def avg(self): method global_avg (line 221) | def global_avg(self): method max (line 225) | def max(self): method value (line 229) | def value(self): method time_preds (line 232) | def time_preds(self, counts) -> Tuple[float, str, str]: method __str__ (line 236) | def __str__(self): class MetricLogger (line 245) | class MetricLogger(object): method __init__ (line 246) | def __init__(self, delimiter=' '): method update (line 252) | def update(self, **kwargs): method __getattr__ (line 261) | def __getattr__(self, attr): method __str__ (line 269) | def __str__(self): method synchronize_between_processes (line 278) | def synchronize_between_processes(self): method add_meter (line 282) | def add_meter(self, name, meter): method log_every (line 285) | def log_every(self, start_it, max_iters, itrt, print_freq, header=None): function glob_with_latest_modified_first (line 340) | def glob_with_latest_modified_first(pattern, recursive=False): function auto_resume (line 344) | def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[... function create_npz_from_sample_folder (line 360) | def create_npz_from_sample_folder(sample_folder: str):