SYMBOL INDEX (434 symbols across 36 files) FILE: caduceus/configuration_caduceus.py class CaduceusConfig (line 10) | class CaduceusConfig(PretrainedConfig): method __init__ (line 14) | def __init__( FILE: caduceus/modeling_caduceus.py function create_block (line 33) | def create_block( class BiMambaWrapper (line 87) | class BiMambaWrapper(nn.Module): method __init__ (line 90) | def __init__( method forward (line 122) | def forward(self, hidden_states, inference_params=None): class CaduceusEmbeddings (line 143) | class CaduceusEmbeddings(nn.Module): method __init__ (line 144) | def __init__( method forward (line 159) | def forward(self, input_ids): class CaduceusMixerModel (line 166) | class CaduceusMixerModel(nn.Module): method __init__ (line 167) | def __init__( method forward (line 216) | def forward(self, input_ids, inputs_embeds=None, output_hidden_states=... function cross_entropy (line 279) | def cross_entropy(logits, y, ignore_index=-100): function weighted_cross_entropy (line 286) | def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100): class CaduceusPreTrainedModel (line 297) | class CaduceusPreTrainedModel(PreTrainedModel): method _init_weights (line 304) | def _init_weights( class Caduceus (line 344) | class Caduceus(CaduceusPreTrainedModel): method __init__ (line 346) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **... method forward (line 363) | def forward( class CaduceusForMaskedLM (line 392) | class CaduceusForMaskedLM(CaduceusPreTrainedModel): method __init__ (line 395) | def __init__(self, config: CaduceusConfig, device=None, dtype=None, **... method get_input_embeddings (line 417) | def get_input_embeddings(self): method set_input_embeddings (line 420) | def set_input_embeddings(self, value): method get_output_embeddings (line 425) | def get_output_embeddings(self): method set_output_embeddings (line 428) | def set_output_embeddings(self, new_embeddings): method tie_weights (line 434) | def tie_weights(self): method get_decoder (line 441) | def get_decoder(self): method set_decoder (line 445) | def set_decoder(self, decoder): method forward (line 449) | def forward( class CaduceusForSequenceClassification (line 495) | class CaduceusForSequenceClassification(CaduceusPreTrainedModel): method __init__ (line 496) | def __init__( method init_scorer (line 521) | def init_scorer(self, initializer_range=0.02): method get_input_embeddings (line 526) | def get_input_embeddings(self): method set_input_embeddings (line 529) | def set_input_embeddings(self, value): method pool_hidden_states (line 534) | def pool_hidden_states(self, hidden_states, sequence_length_dim=1): method forward (line 545) | def forward( FILE: caduceus/modeling_rcps.py class RCPSEmbedding (line 21) | class RCPSEmbedding(nn.Module): method __init__ (line 23) | def __init__(self, vocab_size: int, d_model: int, complement_map: dict... method weight (line 38) | def weight(self): method set_weight (line 42) | def set_weight(self, value): method rc (line 46) | def rc(self, x): method forward (line 54) | def forward(self, input_ids): class RCPSWrapper (line 70) | class RCPSWrapper(nn.Module): method __init__ (line 76) | def __init__(self, submodule: nn.Module): method rc (line 81) | def rc(x): method forward (line 85) | def forward(self, x, **kwargs): class RCPSAddNormWrapper (line 102) | class RCPSAddNormWrapper(RCPSWrapper): method __init__ (line 104) | def __init__(self, submodule: nn.Module): method forward (line 107) | def forward(self, x, residual=None, prenorm=False): class RCPSMambaBlock (line 133) | class RCPSMambaBlock(nn.Module): method __init__ (line 134) | def __init__( method forward (line 160) | def forward( method allocate_inference_cache (line 201) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... class RCPSLMHead (line 209) | class RCPSLMHead(nn.Module): method __init__ (line 211) | def __init__(self, true_dim: int, vocab_size: int, complement_map: dic... method weight (line 225) | def weight(self): method set_weight (line 229) | def set_weight(self, value): method forward (line 233) | def forward(self, x): FILE: caduceus/tests/test_rcps.py function test_rcps_embedding (line 31) | def test_rcps_embedding(batch_size, seq_len, d_model, dtype): function test_rcps_wrapper (line 80) | def test_rcps_wrapper(batch_size, seq_len, d_model, dtype): function test_rcps_add_norm_wrapper (line 116) | def test_rcps_add_norm_wrapper(batch_size, seq_len, d_model, prenorm, dt... function test_rcps_mamba_block_wrapper (line 155) | def test_rcps_mamba_block_wrapper(batch_size, seq_len, d_model, bidirect... function test_rcps_lm_head (line 209) | def test_rcps_lm_head(batch_size, seq_len, d_model, dtype): function test_rcps_backbone (line 271) | def test_rcps_backbone(batch_size, seq_len, n_layer, d_model, dtype, fus... function test_rcps_mamba_lm (line 348) | def test_rcps_mamba_lm(batch_size, seq_len, n_layer, d_model, dtype, bid... function test_collapse_invariance (line 429) | def test_collapse_invariance(batch_size, seq_len, n_layer, d_model, dtyp... FILE: caduceus/tokenization_caduceus.py class CaduceusTokenizer (line 10) | class CaduceusTokenizer(PreTrainedTokenizer): method __init__ (line 13) | def __init__(self, method vocab_size (line 83) | def vocab_size(self) -> int: method complement_map (line 87) | def complement_map(self) -> Dict[int, int]: method _tokenize (line 90) | def _tokenize(self, text: str, **kwargs) -> List[str]: method _convert_token_to_id (line 93) | def _convert_token_to_id(self, token: str) -> int: method _convert_id_to_token (line 96) | def _convert_id_to_token(self, index: int) -> str: method convert_tokens_to_string (line 99) | def convert_tokens_to_string(self, tokens): method get_special_tokens_mask (line 102) | def get_special_tokens_mask( method build_inputs_with_special_tokens (line 120) | def build_inputs_with_special_tokens( method get_vocab (line 130) | def get_vocab(self) -> Dict[str, int]: method save_vocabulary (line 134) | def save_vocabulary(self, save_directory: str, filename_prefix: Option... FILE: src/callbacks/params.py class ParamsLog (line 10) | class ParamsLog(pl.Callback): method __init__ (line 12) | def __init__( method on_fit_start (line 28) | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningMod... FILE: src/callbacks/timer.py class Timer (line 18) | class Timer(Callback): method __init__ (line 21) | def __init__( method on_train_start (line 36) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule)... method on_train_epoch_start (line 39) | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningM... method on_train_batch_start (line 44) | def on_train_batch_start( method on_train_batch_end (line 65) | def on_train_batch_end( method on_train_epoch_end (line 86) | def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningMod... method on_validation_epoch_start (line 92) | def on_validation_epoch_start(self, trainer: Trainer, pl_module: Light... method on_validation_epoch_end (line 96) | def on_validation_epoch_end(self, trainer: Trainer, pl_module: Lightni... method _should_log (line 103) | def _should_log(trainer) -> bool: FILE: src/callbacks/validation.py class ValEveryNGlobalSteps (line 13) | class ValEveryNGlobalSteps(Callback): method __init__ (line 15) | def __init__(self, every_n): method on_train_batch_end (line 19) | def on_train_batch_end(self, trainer, *_: Any): FILE: src/dataloaders/base.py class DefaultCollateMixin (line 20) | class DefaultCollateMixin: method _collate_callback (line 30) | def _collate_callback(cls, x, *args, **kwargs): method _return_callback (line 39) | def _return_callback(cls, return_value, *args, **kwargs): method _collate (line 50) | def _collate(cls, batch, *args, **kwargs): method _collate_fn (line 71) | def _collate_fn(cls, batch, *args, **kwargs): method _dataloader (line 92) | def _dataloader(self, dataset, **loader_args): class SequenceDataset (line 106) | class SequenceDataset(DefaultCollateMixin): method init_defaults (line 115) | def init_defaults(self): method __init_subclass__ (line 119) | def __init_subclass__(cls, **kwargs): method __init__ (line 123) | def __init__(self, _name_, data_dir=None, **dataset_cfg): method init (line 138) | def init(self): method setup (line 142) | def setup(self): method split_train_val (line 146) | def split_train_val(self, val_split): method train_dataloader (line 159) | def train_dataloader(self, **kwargs): method _train_dataloader (line 163) | def _train_dataloader(self, dataset, **kwargs): method val_dataloader (line 169) | def val_dataloader(self, **kwargs): method test_dataloader (line 173) | def test_dataloader(self, **kwargs): method _eval_dataloader (line 177) | def _eval_dataloader(self, dataset, **kwargs): method __str__ (line 183) | def __str__(self): FILE: src/dataloaders/datasets/genomic_bench_dataset.py class GenomicBenchmarkDataset (line 15) | class GenomicBenchmarkDataset(torch.utils.data.Dataset): method __init__ (line 21) | def __init__( method __len__ (line 80) | def __len__(self): method __getitem__ (line 83) | def __getitem__(self, idx): FILE: src/dataloaders/datasets/hg38_char_tokenizer.py class CharacterTokenizer (line 15) | class CharacterTokenizer(PreTrainedTokenizer): method __init__ (line 16) | def __init__(self, characters: Sequence[str], model_max_length: int, p... method vocab_size (line 77) | def vocab_size(self) -> int: method _tokenize (line 80) | def _tokenize(self, text: str) -> List[str]: method _convert_token_to_id (line 83) | def _convert_token_to_id(self, token: str) -> int: method _convert_id_to_token (line 86) | def _convert_id_to_token(self, index: int) -> str: method convert_tokens_to_string (line 89) | def convert_tokens_to_string(self, tokens): method build_inputs_with_special_tokens (line 92) | def build_inputs_with_special_tokens( method get_special_tokens_mask (line 102) | def get_special_tokens_mask( method get_vocab (line 120) | def get_vocab(self) -> Dict[str, int]: method create_token_type_ids_from_sequences (line 123) | def create_token_type_ids_from_sequences( method get_config (line 134) | def get_config(self) -> Dict: method from_config (line 141) | def from_config(cls, config: Dict) -> "CharacterTokenizer": method save_pretrained (line 147) | def save_pretrained(self, save_directory: Union[str, os.PathLike], **k... method from_pretrained (line 154) | def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kw... FILE: src/dataloaders/datasets/hg38_dataset.py class FastaInterval (line 18) | class FastaInterval: method __init__ (line 20) | def __init__( method _compute_interval (line 41) | def _compute_interval(start, end, max_length, i_shift): method __call__ (line 50) | def __call__( class HG38Dataset (line 92) | class HG38Dataset(torch.utils.data.Dataset): method __init__ (line 95) | def __init__( method replace_value (line 153) | def replace_value(x, old_value, new_value): method __len__ (line 157) | def __len__(self): method __getitem__ (line 160) | def __getitem__(self, idx): FILE: src/dataloaders/datasets/nucleotide_transformer_dataset.py class NucleotideTransformerDataset (line 12) | class NucleotideTransformerDataset(torch.utils.data.Dataset): method __init__ (line 19) | def __init__( method __len__ (line 59) | def __len__(self): method __getitem__ (line 62) | def __getitem__(self, idx): FILE: src/dataloaders/fault_tolerant_sampler.py class RandomFaultTolerantSampler (line 9) | class RandomFaultTolerantSampler(RandomSampler): method __init__ (line 11) | def __init__(self, *args, generator=None, **kwargs): method state_dict (line 26) | def state_dict(self): method load_state_dict (line 29) | def load_state_dict(self, state_dict): method __iter__ (line 43) | def __iter__(self) -> Iterator[int]: class FaultTolerantDistributedSampler (line 64) | class FaultTolerantDistributedSampler(DistributedSampler): method __init__ (line 66) | def __init__(self, *args, **kwargs): method state_dict (line 72) | def state_dict(self): method load_state_dict (line 75) | def load_state_dict(self, state_dict): method __iter__ (line 86) | def __iter__(self): FILE: src/dataloaders/genomics.py class HG38 (line 29) | class HG38(SequenceDataset): method __init__ (line 45) | def __init__(self, bed_file, fasta_file, tokenizer_name=None, dataset_... method setup (line 97) | def setup(self, stage=None): method init_datasets (line 119) | def init_datasets(self): method train_dataloader (line 152) | def train_dataloader(self, **kwargs: Any) -> DataLoader: method val_dataloader (line 177) | def val_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Data... method test_dataloader (line 182) | def test_dataloader(self, **kwargs: Any) -> Union[DataLoader, List[Dat... method _data_loader (line 189) | def _data_loader(dataset: Dataset, batch_size: int, shuffle: bool = Fa... method load_state_dict (line 198) | def load_state_dict(self, checkpoint): class GenomicBenchmark (line 208) | class GenomicBenchmark(HG38): method __init__ (line 212) | def __init__( method setup (line 262) | def setup(self, stage=None): class NucleotideTransformer (line 308) | class NucleotideTransformer(HG38): method __init__ (line 312) | def __init__(self, dataset_name, train_val_split_seed, method setup (line 357) | def setup(self, stage=None): FILE: src/dataloaders/utils/mlm.py function mlm_getitem (line 4) | def mlm_getitem(seq, mlm_probability=0.15, contains_eos=False, tokenizer... FILE: src/dataloaders/utils/rc.py function coin_flip (line 12) | def coin_flip(p=0.5): function string_reverse_complement (line 17) | def string_reverse_complement(seq): FILE: src/models/baseline/genomics_benchmark_cnn.py class GenomicsBenchmarkCNN (line 10) | class GenomicsBenchmarkCNN(nn.Module): method __init__ (line 11) | def __init__(self, number_of_classes, vocab_size, input_len, embedding... method count_flatten_size (line 42) | def count_flatten_size(self, input_len): method forward (line 49) | def forward(self, x, state=None): # Adding `state` to be consistent w... FILE: src/models/nn/activation.py function Activation (line 9) | def Activation(activation=None, size=None, dim=-1): class GLU (line 45) | class GLU(nn.Module): method __init__ (line 46) | def __init__(self, dim=-1, activation='sigmoid'): method forward (line 52) | def forward(self, x): class ModReLU (line 57) | class ModReLU(nn.Module): method __init__ (line 60) | def __init__(self, features): method reset_parameters (line 67) | def reset_parameters(self): method forward (line 70) | def forward(self, inputs): class SquaredReLU (line 79) | class SquaredReLU(nn.Module): method forward (line 80) | def forward(self, x): function laplace (line 85) | def laplace(x, mu=0.707107, sigma=0.282095): class Laplace (line 90) | class Laplace(nn.Module): method __init__ (line 91) | def __init__(self, mu=0.707107, sigma=0.282095): method forward (line 96) | def forward(self, x): FILE: src/models/nn/adaptive_softmax.py class OptionalParameterList (line 23) | class OptionalParameterList(nn.ParameterList): method extra_repr (line 24) | def extra_repr(self): class ProjectedAdaptiveLogSoftmax (line 37) | class ProjectedAdaptiveLogSoftmax(nn.Module): method __init__ (line 38) | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, method _compute_logit (line 128) | def _compute_logit(self, hidden, weight, bias, proj): method get_out_proj (line 142) | def get_out_proj(self, i): method forward (line 153) | def forward(self, hidden, target, keep_order=False, key_padding_mask=N... method compute_logits (line 237) | def compute_logits(self, hidden): class AdaptiveEmbedding (line 300) | class AdaptiveEmbedding(nn.Module): method __init__ (line 305) | def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_... method forward (line 342) | def forward(self, inp): function _init_weight (line 395) | def _init_weight(weight, d : int, init_scale : Optional[float], default=... FILE: src/models/nn/utils.py function wrap_kwargs (line 8) | def wrap_kwargs(f): function discard_kwargs (line 84) | def discard_kwargs(f): function PassthroughSequential (line 92) | def PassthroughSequential(*modules): FILE: src/models/sequence/dna_embedding.py class DNAEmbeddingModel (line 27) | class DNAEmbeddingModel(nn.Module, GenerationMixin): method __init__ (line 34) | def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_siz... method forward (line 82) | def forward(self, input_ids, position_ids=None, inference_params=None,... method d_output (line 90) | def d_output(self): class DNAEmbeddingModelMamba (line 99) | class DNAEmbeddingModelMamba(DNAEmbeddingModel): method __init__ (line 102) | def __init__( method forward (line 149) | def forward(self, input_ids, position_ids=None, inference_params=None,... class DNAEmbeddingModelCaduceus (line 156) | class DNAEmbeddingModelCaduceus(DNAEmbeddingModel): method __init__ (line 159) | def __init__( method forward (line 179) | def forward(self, input_ids, position_ids=None, inference_params=None,... function load_backbone (line 198) | def load_backbone(model, state_dict, freeze_backbone=False, ignore_head=... FILE: src/models/sequence/hyena.py class FFTConvFuncv2 (line 25) | class FFTConvFuncv2(torch.autograd.Function): method forward (line 27) | def forward(ctx, u, k): method backward (line 40) | def backward(ctx, dout): function fftconv_ref (line 55) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): function mul_sum (line 79) | def mul_sum(q, y): class Sin (line 83) | class Sin(nn.Module): method __init__ (line 84) | def __init__(self, dim, w=10, train_freq=True): method forward (line 92) | def forward(self, x): class PositionalEmbedding (line 96) | class PositionalEmbedding(OptimModule): method __init__ (line 97) | def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-... method forward (line 117) | def forward(self, L): class ExponentialModulation (line 121) | class ExponentialModulation(OptimModule): method __init__ (line 122) | def __init__( method forward (line 139) | def forward(self, t, x): class HyenaFilter (line 145) | class HyenaFilter(OptimModule): method __init__ (line 146) | def __init__( method filter (line 214) | def filter(self, L, *args, **kwargs): method forward (line 225) | def forward(self, x, L, k=None, bias=None, *args, **kwargs): class HyenaOperator (line 255) | class HyenaOperator(nn.Module): method __init__ (line 256) | def __init__( method setup_projections (line 330) | def setup_projections(self, fused_bias_fc, inner_factor): method setup_filters (line 343) | def setup_filters(self, filter_cls, filter_args): method recurrence (line 369) | def recurrence(self, u, state): method forward (line 373) | def forward(self, u, *args, **kwargs): method d_output (line 432) | def d_output(self): FILE: src/models/sequence/long_conv_lm.py class CheckpointedModule (line 33) | class CheckpointedModule(torch.nn.Module): method __init__ (line 34) | def __init__(self, layer): method forward (line 38) | def forward(self, x): function create_mixer_cls (line 42) | def create_mixer_cls( function create_mlp_cls (line 93) | def create_mlp_cls( function create_block (line 130) | def create_block( function _init_weights (line 195) | def _init_weights( class LMBackbone (line 240) | class LMBackbone(nn.Module): method __init__ (line 241) | def __init__( method tie_weights (line 344) | def tie_weights(self): method forward (line 348) | def forward(self, input_ids, position_ids=None, inference_params=None): class ConvLMHeadModel (line 391) | class ConvLMHeadModel(nn.Module, GenerationMixin): method __init__ (line 392) | def __init__( method tie_weights (line 473) | def tie_weights(self): method forward (line 478) | def forward( FILE: src/ops/fftconv.py function _mul_sum (line 11) | def _mul_sum(y, q): function fftconv_ref (line 15) | def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): function fftconv_h3_ref (line 38) | def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=No... class FFTConvFunc (line 58) | class FFTConvFunc(torch.autograd.Function): method forward (line 61) | def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_out... method backward (line 88) | def backward(ctx, dout): function fftconv_func (line 105) | def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_outpu... FILE: src/tasks/decoders.py class Decoder (line 16) | class Decoder(nn.Module): method forward (line 21) | def forward(self, x, **kwargs): method step (line 33) | def step(self, x): class SequenceDecoder (line 40) | class SequenceDecoder(Decoder): method __init__ (line 41) | def __init__( method forward (line 70) | def forward(self, x, state=None, lengths=None, l_output=None): method step (line 156) | def step(self, x, state=None): function _instantiate (line 197) | def _instantiate(decoder, model=None, dataset=None): function instantiate (line 217) | def instantiate(decoder, model=None, dataset=None): FILE: src/tasks/encoders.py class Encoder (line 7) | class Encoder(nn.Module): method forward (line 16) | def forward(self, x, **kwargs): function _instantiate (line 64) | def _instantiate(encoder, dataset=None, model=None): function instantiate (line 84) | def instantiate(encoder, dataset=None, model=None): FILE: src/tasks/metrics.py class CorrectAggregatedMetric (line 13) | class CorrectAggregatedMetric(Metric): method __init__ (line 16) | def __init__(self, class_idx: int, dist_sync_on_step=False): method _update (line 25) | def _update(self, numerator, denominator, preds, y) -> tuple: method update (line 28) | def update(self, logits: torch.Tensor, y: torch.Tensor): method compute (line 36) | def compute(self): method reset (line 41) | def reset(self): class AccuracyPerClass (line 45) | class AccuracyPerClass(CorrectAggregatedMetric): method _update (line 48) | def _update(self, numerator, denominator, preds, y) -> tuple: class PrecisionPerClass (line 59) | class PrecisionPerClass(CorrectAggregatedMetric): method _update (line 62) | def _update(self, numerator, denominator, preds, y) -> tuple: class RecallPerClass (line 71) | class RecallPerClass(CorrectAggregatedMetric): method _update (line 74) | def _update(self, numerator, denominator, preds, y) -> tuple: function mcc (line 83) | def mcc(logits, y): function last_k_ppl (line 90) | def last_k_ppl(logits, y, seq_len=1024, k=None): function _student_t_map (line 122) | def _student_t_map(mu, sigma, nu): function student_t_loss (line 127) | def student_t_loss(outs, y): function gaussian_ll_loss (line 144) | def gaussian_ll_loss(outs, y): function binary_cross_entropy (line 155) | def binary_cross_entropy(logits, y): function binary_accuracy (line 161) | def binary_accuracy(logits, y): function padded_cross_entropy (line 164) | def padded_cross_entropy(logits, y, pad_mask, pad_value=-1): function cross_entropy (line 181) | def cross_entropy(logits, y, ignore_index=-100): function soft_cross_entropy (line 187) | def soft_cross_entropy(logits, y, label_smoothing=0.0): function accuracy (line 193) | def accuracy(logits, y): function accuracy_ignore_index (line 203) | def accuracy_ignore_index(logits, y, ignore_index=-100): function accuracy_at_k (line 212) | def accuracy_at_k(logits, y, k=1): function f1_binary (line 221) | def f1_binary(logits, y): function f1_macro (line 228) | def f1_macro(logits, y): function f1_micro (line 235) | def f1_micro(logits, y): function roc_auc_macro (line 242) | def roc_auc_macro(logits, y): function roc_auc_micro (line 252) | def roc_auc_micro(logits, y): function mse (line 260) | def mse(outs, y, len_batch=None): function forecast_rmse (line 278) | def forecast_rmse(outs, y, len_batch=None): function mae (line 282) | def mae(outs, y, len_batch=None): function loss (line 301) | def loss(x, y, loss_fn): function bpb (line 306) | def bpb(x, y, loss_fn): function ppl (line 311) | def ppl(x, y, loss_fn): FILE: src/tasks/tasks.py class BaseTask (line 16) | class BaseTask: method __init__ (line 27) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,... method _init_torchmetrics (line 55) | def _init_torchmetrics(self): method _reset_torchmetrics (line 83) | def _reset_torchmetrics(self, prefix=None): method get_torchmetrics (line 96) | def get_torchmetrics(self, prefix): method torchmetrics (line 105) | def torchmetrics(self, x, y, prefix, loss=None): method get_torchmetrics (line 124) | def get_torchmetrics(self, prefix): method metrics (line 127) | def metrics(self, x, y, **kwargs): method forward (line 143) | def forward(self, batch, encoder, model, decoder, _state): class Scalar (line 161) | class Scalar(nn.Module): method __init__ (line 162) | def __init__(self, c=1): method forward (line 166) | def forward(self, x): class LMTask (line 170) | class LMTask(BaseTask): method forward (line 171) | def forward(self, batch, encoder, model, decoder, _state): class MultiClass (line 199) | class MultiClass(BaseTask): method __init__ (line 201) | def __init__(self, *args, **kwargs): method metrics (line 211) | def metrics(self, x, y, **kwargs): method _reset_torchmetrics (line 236) | def _reset_torchmetrics(self, prefix=None): class HG38Task (line 244) | class HG38Task(LMTask): method __init__ (line 246) | def __init__(self, dataset=None, model=None, loss=None, loss_val=None,... method metrics (line 303) | def metrics(self, x, y, **kwargs): class AdaptiveLMTask (line 335) | class AdaptiveLMTask(BaseTask): method __init__ (line 336) | def __init__( FILE: src/tasks/torchmetrics.py class Perplexity (line 24) | class Perplexity(Metric): method __init__ (line 46) | def __init__(self, **kwargs: Dict[str, Any]): method update (line 54) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]... method compute (line 68) | def compute(self) -> Tensor: class NumTokens (line 75) | class NumTokens(Metric): method __init__ (line 88) | def __init__(self, **kwargs: Dict[str, Any]): method update (line 97) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]... method compute (line 100) | def compute(self) -> Tensor: method reset (line 103) | def reset(self): method _forward_reduce_state_update (line 109) | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: FILE: src/utils/config.py function is_list (line 13) | def is_list(x): function is_dict (line 17) | def is_dict(x): function to_dict (line 21) | def to_dict(x, recursive=True): function to_list (line 37) | def to_list(x, recursive=False): function extract_attrs_from_obj (line 56) | def extract_attrs_from_obj(obj, *attrs): function auto_assign_attrs (line 63) | def auto_assign_attrs(cls, **kwargs): function instantiate (line 68) | def instantiate(registry, config, *args, partial=False, wrap=None, **kwa... function get_class (line 112) | def get_class(registry, _name_): function omegaconf_filter_keys (line 116) | def omegaconf_filter_keys(d, fn=None): FILE: src/utils/optim/schedulers.py class CosineWarmup (line 11) | class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): method __init__ (line 13) | def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): method get_lr (line 19) | def get_lr(self): function InvSqrt (line 40) | def InvSqrt(optimizer, warmup_step): function Constant (line 54) | def Constant(optimizer, warmup_step): class TimmCosineLRScheduler (line 65) | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler.... method __init__ (line 70) | def __init__(self, *args, **kwargs): method step (line 75) | def step(self, epoch=None): FILE: src/utils/optim_groups.py function add_optimizer_hooks (line 14) | def add_optimizer_hooks( function group_parameters_for_optimizer (line 41) | def group_parameters_for_optimizer( FILE: src/utils/train.py class LoggingContext (line 20) | class LoggingContext: method __init__ (line 21) | def __init__(self, logger, level=None, handler=None, close=True): method __enter__ (line 27) | def __enter__(self): method __exit__ (line 34) | def __exit__(self, et, ev, tb): function get_logger (line 44) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: function process_config (line 58) | def process_config(config: DictConfig) -> DictConfig: # TODO because of... function print_config (line 101) | def print_config( function log_optimizer (line 143) | def log_optimizer(logger, optimizer, keys): class OptimModule (line 154) | class OptimModule(nn.Module): method register (line 157) | def register(self, name, tensor, lr=None, wd=0.0): FILE: train.py class DummyExperiment (line 43) | class DummyExperiment: method nop (line 46) | def nop(self, *args, **kw): method __getattr__ (line 49) | def __getattr__(self, _): method __getitem__ (line 52) | def __getitem__(self, idx) -> "DummyExperiment": method __setitem__ (line 56) | def __setitem__(self, *args, **kwargs) -> None: function rank_zero_experiment (line 60) | def rank_zero_experiment(fn: Callable) -> Callable: class CustomWandbLogger (line 74) | class CustomWandbLogger(WandbLogger): method __init__ (line 76) | def __init__(self, *args, **kwargs): method experiment (line 83) | def experiment(self): class SequenceLightningModule (line 126) | class SequenceLightningModule(pl.LightningModule): method __init__ (line 127) | def __init__(self, config): method setup (line 159) | def setup(self, stage=None): method load_state_dict (line 240) | def load_state_dict(self, state_dict, strict=False): method _check_config (line 255) | def _check_config(self): method _initialize_state (line 268) | def _initialize_state(self): method _reset_state (line 273) | def _reset_state(self, batch, device=None): method _detach_state (line 278) | def _detach_state(self, state): method _process_state (line 292) | def _process_state(self, batch, batch_idx, training=True): method forward (line 326) | def forward(self, batch): method step (line 329) | def step(self, x_t): method _shared_step (line 336) | def _shared_step(self, batch, batch_idx, prefix="train"): method on_train_epoch_start (line 379) | def on_train_epoch_start(self): method training_epoch_end (line 383) | def training_epoch_end(self, outputs): method on_validation_epoch_start (line 387) | def on_validation_epoch_start(self): method validation_epoch_end (line 392) | def validation_epoch_end(self, outputs): method on_test_epoch_start (line 396) | def on_test_epoch_start(self): method test_epoch_end (line 401) | def test_epoch_end(self, outputs): method training_step (line 405) | def training_step(self, batch, batch_idx, dataloader_idx=0): method validation_step (line 438) | def validation_step(self, batch, batch_idx, dataloader_idx=0): method test_step (line 455) | def test_step(self, batch, batch_idx, dataloader_idx=0): method configure_optimizers (line 460) | def configure_optimizers(self): method train_dataloader (line 543) | def train_dataloader(self): method _eval_dataloaders_names (line 546) | def _eval_dataloaders_names(self, loaders, prefix): method _eval_dataloaders (line 557) | def _eval_dataloaders(self): method val_dataloader (line 584) | def val_dataloader(self): method test_dataloader (line 589) | def test_dataloader(self): function create_trainer (line 596) | def create_trainer(config, **kwargs): function fsspec_exists (line 649) | def fsspec_exists(filename): function train (line 654) | def train(config): function main (line 701) | def main(config: OmegaConf): FILE: vep_embeddings.py class DNAEmbeddingModel (line 30) | class DNAEmbeddingModel(nn.Module): method __init__ (line 36) | def __init__( method forward (line 56) | def forward(self, input_ids): class EnformerTokenizer (line 63) | class EnformerTokenizer: method encode (line 70) | def encode( method batch_encode_plus (line 83) | def batch_encode_plus( function setup_distributed (line 92) | def setup_distributed(): function cleanup_distributed (line 97) | def cleanup_distributed(): function fsspec_exists (line 102) | def fsspec_exists(filename): function fsspec_listdir (line 108) | def fsspec_listdir(dirname): function recast_chromosome_tissue_dist2TSS (line 115) | def recast_chromosome_tissue_dist2TSS(examples): function tokenize_variants (line 124) | def tokenize_variants(examples, tokenizer, max_length: int): function find_variant_idx (line 172) | def find_variant_idx(examples): function prepare_dataset (line 198) | def prepare_dataset(args, tokenizer): function get_backbone_model (line 260) | def get_backbone_model(args, device): function concat_storage_dict_values (line 270) | def concat_storage_dict_values(storage_dict): function dump_embeddings (line 275) | def dump_embeddings(args, dataset, model, device): function combine_embeddings (line 407) | def combine_embeddings(embeds_path): function main (line 433) | def main(args):