SYMBOL INDEX (288 symbols across 12 files) FILE: data.py function bin_mnist_transform (line 37) | def bin_mnist_transform(x): function bin_mnist_cts_transform (line 41) | def bin_mnist_cts_transform(x): function rgb_image_transform (line 45) | def rgb_image_transform(x, num_bins=256): class MyLambda (line 49) | class MyLambda(torchvision.transforms.Lambda): method __init__ (line 50) | def __init__(self, lambd, arg1): method __call__ (line 54) | def __call__(self, x): class CIFAR10 (line 58) | class CIFAR10(torchvision.datasets.CIFAR10): method __getitem__ (line 59) | def __getitem__(self, idx): class MNIST (line 63) | class MNIST(torchvision.datasets.MNIST): method __getitem__ (line 64) | def __getitem__(self, idx): function make_datasets (line 68) | def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]: function prepare_text8 (line 127) | def prepare_text8(data_dir: pathlib.Path): class Text8Dataset (line 185) | class Text8Dataset(Dataset): method __init__ (line 186) | def __init__(self, data_dir: Union[str, pathlib.Path], split: str, dow... method __getitem__ (line 204) | def __getitem__(self, index) -> torch.Tensor: method __len__ (line 208) | def __len__(self): function char_ids_to_str (line 212) | def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) ... function batch_to_str (line 217) | def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) ... function batch_to_images (line 222) | def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt... FILE: model.py class BayesianFlow (line 42) | class BayesianFlow(nn.Module, ABC): method __init__ (line 43) | def __init__(self): method get_prior_input_params (line 47) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi... method params_to_net_inputs (line 54) | def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor: method get_alpha (line 59) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float: method get_sender_dist (line 66) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap... method update_input_params (line 73) | def update_input_params(self, input_params: tuple[Tensor, ...], y: Ten... method forward (line 79) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]: class Loss (line 87) | class Loss(nn.Module, ABC): method __init__ (line 88) | def __init__(self): method cts_time_loss (line 92) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par... method discrete_time_loss (line 98) | def discrete_time_loss( method reconstruction_loss (line 107) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp... class CtsBayesianFlow (line 116) | class CtsBayesianFlow(BayesianFlow): method __init__ (line 117) | def __init__( method forward (line 125) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]: method params_to_net_inputs (line 137) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor: method get_prior_input_params (line 140) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi... method get_alpha (line 143) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa... method get_sender_dist (line 147) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap... method update_input_params (line 151) | def update_input_params(self, input_params: tuple[Tensor, float], y: T... class CtsBayesianFlowLoss (line 158) | class CtsBayesianFlowLoss(Loss): method __init__ (line 159) | def __init__( method cts_time_loss (line 178) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par... method discrete_time_loss (line 191) | def discrete_time_loss( method reconstruction_loss (line 225) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp... class DiscreteBayesianFlow (line 250) | class DiscreteBayesianFlow(BayesianFlow): method __init__ (line 251) | def __init__( method t_to_sqrt_beta (line 267) | def t_to_sqrt_beta(self, t): method count_dist (line 270) | def count_dist(self, x, beta=None): method count_sample (line 278) | def count_sample(self, x, beta): method get_prior_input_params (line 282) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi... method params_to_net_inputs (line 286) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor: method get_alpha (line 293) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa... method get_sender_dist (line 296) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap... method update_input_params (line 302) | def update_input_params(self, input_params: tuple[Tensor], y: Tensor, ... method forward (line 308) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]: class DiscreteBayesianFlowLoss (line 325) | class DiscreteBayesianFlowLoss(Loss): method __init__ (line 326) | def __init__( method cts_time_loss (line 336) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par... method discrete_time_loss (line 348) | def discrete_time_loss( method reconstruction_loss (line 369) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp... class BFN (line 376) | class BFN(nn.Module): method __init__ (line 377) | def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: ... method sample_t (line 385) | def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor: method forward (line 393) | def forward( method compute_reconstruction_loss (line 420) | def compute_reconstruction_loss(self, data: Tensor) -> Tensor: method sample (line 428) | def sample(self, data_shape: tuple, n_steps: int) -> Tensor: FILE: networks/adapters.py class TextInputAdapter (line 25) | class TextInputAdapter(nn.Module): method __init__ (line 30) | def __init__( method forward (line 46) | def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor: class FourierImageInputAdapter (line 61) | class FourierImageInputAdapter(nn.Module): method __init__ (line 66) | def __init__( method forward (line 122) | def forward(self, img: Tensor, t: Tensor) -> Tensor: class OutputAdapter (line 150) | class OutputAdapter(nn.Module): method __init__ (line 151) | def __init__(self, input_height: int, output_channels: int, output_hei... method forward (line 159) | def forward(self, inp: torch.Tensor) -> torch.Tensor: FILE: networks/transformer.py function gelu (line 37) | def gelu(x): class LayerNorm (line 41) | class LayerNorm(nn.Module): method __init__ (line 44) | def __init__(self, ndim, bias): method forward (line 49) | def forward(self, input): class SelfAttention (line 53) | class SelfAttention(nn.Module): method __init__ (line 54) | def __init__(self, n_head, n_embd, dropout, bias, is_causal): method forward (line 72) | def forward(self, x): class MLP (line 92) | class MLP(nn.Module): method __init__ (line 93) | def __init__(self, n_embd, dropout, bias): method forward (line 99) | def forward(self, x): class Block (line 107) | class Block(nn.Module): method __init__ (line 108) | def __init__(self, n_head, n_embd, dropout, bias, is_causal): method forward (line 115) | def forward(self, x): class GPT (line 121) | class GPT(nn.Module): method __init__ (line 122) | def __init__( method _init_weights (line 169) | def _init_weights(self, module): method forward (line 177) | def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor: method get_optim_groups (line 188) | def get_optim_groups(self, weight_decay: float): FILE: networks/unet_improved.py function convert_module_to_f16 (line 48) | def convert_module_to_f16(module): function convert_module_to_f32 (line 57) | def convert_module_to_f32(module): function make_master_params (line 66) | def make_master_params(model_params): function model_grads_to_master_grads (line 77) | def model_grads_to_master_grads(model_params, master_params): function master_params_to_model_params (line 85) | def master_params_to_model_params(model_params, master_params): function unflatten_master_params (line 97) | def unflatten_master_params(model_params, master_params): function zero_grad (line 104) | def zero_grad(model_params): class SiLU (line 113) | class SiLU(nn.Module): method forward (line 114) | def forward(self, x): class GroupNorm32 (line 118) | class GroupNorm32(nn.GroupNorm): method forward (line 119) | def forward(self, x): function conv_nd (line 123) | def conv_nd(dims, *args, **kwargs): function linear (line 136) | def linear(*args, **kwargs): function avg_pool_nd (line 143) | def avg_pool_nd(dims, *args, **kwargs): function update_ema (line 156) | def update_ema(target_params, source_params, rate=0.99): function zero_module (line 169) | def zero_module(module): function scale_module (line 178) | def scale_module(module, scale): function mean_flat (line 187) | def mean_flat(tensor): function normalization (line 194) | def normalization(channels): function timestep_embedding (line 204) | def timestep_embedding(timesteps, dim, max_period=10000): function checkpoint (line 225) | def checkpoint(func, inputs, params, flag): class CheckpointFunction (line 243) | class CheckpointFunction(th.autograd.Function): method forward (line 245) | def forward(ctx, run_function, length, *args): method backward (line 254) | def backward(ctx, *output_grads): class TimestepBlock (line 274) | class TimestepBlock(nn.Module): method forward (line 280) | def forward(self, x, emb): class TimestepEmbedSequential (line 286) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): method forward (line 292) | def forward(self, x, emb): class Upsample (line 301) | class Upsample(nn.Module): method __init__ (line 311) | def __init__(self, channels, use_conv, dims=2): method forward (line 319) | def forward(self, x): class Downsample (line 330) | class Downsample(nn.Module): method __init__ (line 340) | def __init__(self, channels, use_conv, dims=2): method forward (line 351) | def forward(self, x): class ResBlock (line 356) | class ResBlock(TimestepBlock): method __init__ (line 371) | def __init__( method forward (line 417) | def forward(self, x, emb): method _forward (line 427) | def _forward(self, x, emb): class AttentionBlock (line 443) | class AttentionBlock(nn.Module): method __init__ (line 451) | def __init__(self, channels, num_heads=1, use_checkpoint=False): method forward (line 462) | def forward(self, x): method _forward (line 465) | def _forward(self, x): class QKVAttention (line 476) | class QKVAttention(nn.Module): method forward (line 481) | def forward(self, qkv): method count_flops (line 496) | def count_flops(model, _x, y): class UNetModel (line 519) | class UNetModel(nn.Module): method __init__ (line 542) | def __init__( method convert_to_fp16 (line 682) | def convert_to_fp16(self): method convert_to_fp32 (line 690) | def convert_to_fp32(self): method inner_dtype (line 699) | def inner_dtype(self): method forward (line 705) | def forward( method get_feature_vectors (line 751) | def get_feature_vectors(self, x, timesteps, y=None): FILE: networks/unet_vdm.py function zero_init (line 39) | def zero_init(module: nn.Module) -> nn.Module: class UNetVDM (line 46) | class UNetVDM(nn.Module): method __init__ (line 47) | def __init__( method forward (line 121) | def forward( method maybe_concat_fourier (line 152) | def maybe_concat_fourier(self, z): class ResnetBlock (line 158) | class ResnetBlock(nn.Module): method __init__ (line 159) | def __init__( method forward (line 187) | def forward(self, x, condition): function get_timestep_embedding (line 201) | def get_timestep_embedding( class FourierFeatures (line 223) | class FourierFeatures(nn.Module): method __init__ (line 224) | def __init__(self, first=5.0, last=6.0, step=1.0): method num_features (line 229) | def num_features(self): method forward (line 232) | def forward(self, x): function attention_inner_heads (line 248) | def attention_inner_heads(qkv, num_heads): class Attention (line 285) | class Attention(nn.Module): method __init__ (line 288) | def __init__(self, n_heads): method forward (line 292) | def forward(self, qkv): class AttentionBlock (line 301) | class AttentionBlock(nn.Module): method __init__ (line 304) | def __init__(self, n_heads, n_channels, norm_groups): method forward (line 314) | def forward(self, x): class UpDownBlock (line 318) | class UpDownBlock(nn.Module): method __init__ (line 319) | def __init__(self, resnet_block, attention_block=None): method forward (line 324) | def forward(self, x, cond): FILE: probability.py class CtsDistribution (line 36) | class CtsDistribution: method log_prob (line 38) | def log_prob(self, x): method sample (line 42) | def sample(self): class DiscreteDistribution (line 46) | class DiscreteDistribution: method probs (line 49) | def probs(self): method log_probs (line 53) | def log_probs(self): method mean (line 57) | def mean(self): method mode (line 61) | def mode(self): method log_prob (line 65) | def log_prob(self, x): method sample (line 69) | def sample(self): class DiscretizedDistribution (line 73) | class DiscretizedDistribution(DiscreteDistribution): method __init__ (line 74) | def __init__(self, num_bins, device): method class_centres (line 81) | def class_centres(self): method class_boundaries (line 85) | def class_boundaries(self): method mean (line 89) | def mean(self): method mode (line 93) | def mode(self): class DiscretizedCtsDistribution (line 98) | class DiscretizedCtsDistribution(DiscretizedDistribution): method __init__ (line 99) | def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, ... method probs (line 108) | def probs(self): method prob (line 127) | def prob(self, x): method log_prob (line 145) | def log_prob(self, x): method sample (line 153) | def sample(self, sample_shape=torch.Size([])): class GMM (line 165) | class GMM(MixtureSameFamily): method __init__ (line 166) | def __init__(self, mix_wt_logits, means, std_devs): class DiscretizedGMM (line 172) | class DiscretizedGMM(DiscretizedCtsDistribution): method __init__ (line 173) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max... class DiscretizedNormal (line 191) | class DiscretizedNormal(DiscretizedCtsDistribution): method __init__ (line 192) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max... class Bernoulli (line 210) | class Bernoulli(DiscreteDistribution): method __init__ (line 211) | def __init__(self, logits): method probs (line 215) | def probs(self): method mode (line 220) | def mode(self): method log_prob (line 223) | def log_prob(self, x): method sample (line 226) | def sample(self, sample_shape=torch.Size([])): class DiscretizedBernoulli (line 230) | class DiscretizedBernoulli(DiscretizedDistribution): method __init__ (line 231) | def __init__(self, logits): method probs (line 236) | def probs(self): method mode (line 241) | def mode(self): method log_prob (line 244) | def log_prob(self, x): method sample (line 247) | def sample(self, sample_shape=torch.Size([])): class DeltaDistribution (line 251) | class DeltaDistribution(CtsDistribution): method __init__ (line 252) | def __init__(self, mean, clip_range=1.0): method mode (line 258) | def mode(self): method mean (line 262) | def mean(self): method sample (line 265) | def sample(self, sample_shape=torch.Size([])): class Categorical (line 269) | class Categorical(DiscreteDistribution): method __init__ (line 270) | def __init__(self, logits): method probs (line 275) | def probs(self): method mode (line 279) | def mode(self): method log_prob (line 282) | def log_prob(self, x): method sample (line 285) | def sample(self, sample_shape=torch.Size([])): class DiscretizedCategorical (line 289) | class DiscretizedCategorical(DiscretizedDistribution): method __init__ (line 290) | def __init__(self, logits=None, probs=None): method probs (line 300) | def probs(self): method mode (line 304) | def mode(self): method log_prob (line 307) | def log_prob(self, x): method sample (line 310) | def sample(self, sample_shape=torch.Size([])): class CtsDistributionFactory (line 314) | class CtsDistributionFactory: method get_dist (line 316) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->... class GMMFactory (line 321) | class GMMFactory(CtsDistributionFactory): method __init__ (line 322) | def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True): method get_dist (line 327) | def get_dist(self, params, input_params=None, t=None): class NormalFactory (line 335) | class NormalFactory(CtsDistributionFactory): method __init__ (line 336) | def __init__(self, min_std_dev=1e-3, max_std_dev=10): method get_dist (line 340) | def get_dist(self, params, input_params=None, t=None): class DeltaFactory (line 346) | class DeltaFactory(CtsDistributionFactory): method __init__ (line 347) | def __init__(self, clip_range=1.0): method get_dist (line 350) | def get_dist(self, params, input_params=None, t=None): class DiscreteDistributionFactory (line 354) | class DiscreteDistributionFactory: method get_dist (line 356) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->... class BernoulliFactory (line 361) | class BernoulliFactory(DiscreteDistributionFactory): method get_dist (line 362) | def get_dist(self, params, input_params=None, t=None): class CategoricalFactory (line 366) | class CategoricalFactory(DiscreteDistributionFactory): method get_dist (line 367) | def get_dist(self, params, input_params=None, t=None): class DiscretizedBernoulliFactory (line 371) | class DiscretizedBernoulliFactory(DiscreteDistributionFactory): method get_dist (line 372) | def get_dist(self, params, input_params=None, t=None): class DiscretizedCategoricalFactory (line 376) | class DiscretizedCategoricalFactory(DiscreteDistributionFactory): method get_dist (line 377) | def get_dist(self, params, input_params=None, t=None): class DiscretizedGMMFactory (line 381) | class DiscretizedGMMFactory(DiscreteDistributionFactory): method __init__ (line 382) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=... method get_dist (line 390) | def get_dist(self, params, input_params=None, t=None): class DiscretizedNormalFactory (line 402) | class DiscretizedNormalFactory(DiscreteDistributionFactory): method __init__ (line 403) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=... method get_dist (line 411) | def get_dist(self, params, input_params=None, t=None): function noise_pred_params_to_data_pred_params (line 423) | def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tenso... class PredDistToDataDistFactory (line 459) | class PredDistToDataDistFactory(DiscreteDistributionFactory): method __init__ (line 460) | def __init__(self, data_dist_factory, min_variance, min_t=1e-6): method get_dist (line 466) | def get_dist(self, params, input_params, t): FILE: sample.py function main (line 24) | def main(cfg: DictConfig) -> torch.Tensor: FILE: test.py function setup (line 32) | def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]: function test (line 47) | def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: in... function main (line 73) | def main(cfg: DictConfig) -> tuple[float, float, float, float]: FILE: train.py function setup (line 56) | def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]: function validate (line 70) | def validate( function train (line 116) | def train( function main (line 178) | def main(cfg): FILE: utils_model.py function sandwich (line 28) | def sandwich(x: Tensor): function safe_log (line 32) | def safe_log(data: Tensor): function safe_exp (line 36) | def safe_exp(data: Tensor): function idx_to_float (line 40) | def idx_to_float(idx: np.ndarray, num_bins: int): function float_to_idx (line 45) | def float_to_idx(flt: np.ndarray, num_bins: int): function quantize (line 50) | def quantize(flt, num_bins: int): function pe_encode (line 54) | def pe_encode(sequence_length: int, embedding_size: int) -> Tensor: function pe_encode_float (line 69) | def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> ... FILE: utils_train.py function stringify_unsupported (line 29) | def stringify_unsupported(x): function seed_everything (line 49) | def seed_everything(seed: Optional[int]): function worker_init_function (line 58) | def worker_init_function(worker_id: int) -> None: function init_checkpointing (line 65) | def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: s... function checkpoint_training_state (line 77) | def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, st... function log (line 89) | def log(key_handler, value, step, cond=True): function log_cfg (line 95) | def log_cfg(cfg, run: "neptune.Run"): function update_ema (line 104) | def update_ema(ema_model, model, ema_decay): function ddict (line 110) | def ddict(): function make_infinite (line 115) | def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]: function make_progress_bar (line 121) | def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]... function make_dataloaders (line 132) | def make_dataloaders(cfg: DictConfig): function make_from_cfg (line 149) | def make_from_cfg(module, cfg, **parameters): function make_bfn (line 153) | def make_bfn(cfg: DictConfig): function make_config (line 205) | def make_config(cfg_file: str):