SYMBOL INDEX (409 symbols across 25 files) FILE: demo.py function parse_labels_input (line 204) | def parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, ... function parse_examples_input (line 226) | def parse_examples_input(examples_input: str) -> Optional[List[Dict[str,... function format_output (line 257) | def format_output( function format_as_json (line 277) | def format_as_json(results: Union[List[Dict], Dict], hierarchical: bool ... function format_hierarchical_dict (line 297) | def format_hierarchical_dict(d: Dict, indent: int = 0) -> str: function classification (line 313) | def classification( function update_output_visibility (line 821) | def update_output_visibility(hierarchical: bool, fmt: str): function classify_wrapper (line 842) | def classify_wrapper(text, labels, threshold, multi_label, prompt, examp... FILE: gliclass/config.py class GLiClassModelConfig (line 19) | class GLiClassModelConfig(PretrainedConfig): method __init__ (line 23) | def __init__( FILE: gliclass/data_processing.py class AugmentationConfig (line 11) | class AugmentationConfig: class DataAugmenter (line 26) | class DataAugmenter: method __init__ (line 27) | def __init__(self, config, examples, labels, label2description=None): method remove_labels (line 34) | def remove_labels(self, true_labels, all_labels): method add_random_labels (line 42) | def add_random_labels(self, all_labels): method add_random_text (line 51) | def add_random_text(self, text, all_labels): method add_random_synonyms (line 66) | def add_random_synonyms(self, all_labels): method add_random_descriptions (line 86) | def add_random_descriptions(self, item): method add_random_examples (line 111) | def add_random_examples(self, item): method augment (line 145) | def augment(self, item): class GLiClassDataset (line 184) | class GLiClassDataset(Dataset): method __init__ (line 185) | def __init__( method get_diversity (line 222) | def get_diversity(self): method collect_dataset_labels (line 225) | def collect_dataset_labels(self): method prepare_labels (line 231) | def prepare_labels(self, example, label2idx, problem_type): method prepare_prompt (line 243) | def prepare_prompt(self, item, label_token_first=True): method format_examples (line 256) | def format_examples(self, item): method tokenize (line 270) | def tokenize(self, texts): method tokenize_labels (line 274) | def tokenize_labels(self, labels): method tokenize_and_prepare_labels_for_uniencoder (line 278) | def tokenize_and_prepare_labels_for_uniencoder(self, example): method tokenize_and_prepare_labels_for_encoder_decoder (line 295) | def tokenize_and_prepare_labels_for_encoder_decoder(self, example): method tokenize_and_prepare_labels_for_biencoder (line 312) | def tokenize_and_prepare_labels_for_biencoder(self, example): method __len__ (line 346) | def __len__(self): method __getitem__ (line 349) | def __getitem__(self, idx): function pad_2d_tensor (line 365) | def pad_2d_tensor(key_data): class DataCollatorWithPadding (line 395) | class DataCollatorWithPadding: method __init__ (line 396) | def __init__(self, device="cuda:0", config=None): method _resolve_max_num_classes (line 400) | def _resolve_max_num_classes(self, batch): method __call__ (line 415) | def __call__(self, batch): FILE: gliclass/layers.py class LstmSeq2SeqEncoder (line 23) | class LstmSeq2SeqEncoder(nn.Module): method __init__ (line 24) | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0,... method forward (line 35) | def forward(self, x, mask, hidden=None): class FeaturesProjector (line 49) | class FeaturesProjector(nn.Module): method __init__ (line 50) | def __init__(self, config: GLiClassModelConfig): method forward (line 58) | def forward(self, features): class BiEncoderProjector (line 66) | class BiEncoderProjector(nn.Module): method __init__ (line 67) | def __init__(self, config: GLiClassModelConfig): method forward (line 74) | def forward(self, features): class DropoutContext (line 82) | class DropoutContext: method __init__ (line 83) | def __init__(self): function get_mask (line 91) | def get_mask(input, local_context): class XDropout (line 110) | class XDropout(torch.autograd.Function): method forward (line 114) | def forward(ctx, input, local_ctx): method backward (line 124) | def backward(ctx, grad_output): method symbolic (line 132) | def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: floa... class StableDropout (line 150) | class StableDropout(nn.Module): method __init__ (line 158) | def __init__(self, drop_prob): method forward (line 164) | def forward(self, x): method clear_context (line 175) | def clear_context(self): method init_context (line 179) | def init_context(self, reuse_mask=True, scale=1): method get_context (line 187) | def get_context(self): class SelfAttentionBlock (line 199) | class SelfAttentionBlock(nn.Module): method __init__ (line 200) | def __init__(self, d_model, num_heads, dropout=0.1): method forward (line 206) | def forward(self, x, mask=None): class CrossAttentionBlock (line 211) | class CrossAttentionBlock(nn.Module): method __init__ (line 212) | def __init__(self, d_model, num_heads, dropout=0.1): method forward (line 218) | def forward(self, query, key, value, mask=None): class Fuser (line 223) | class Fuser(nn.Module): method __init__ (line 224) | def __init__(self, d_model, num_heads, num_layers, dropout=0.1): method forward (line 237) | def forward(self, query, key, query_mask=None, key_mask=None): class LayerwiseAttention (line 254) | class LayerwiseAttention(nn.Module): method __init__ (line 255) | def __init__(self, num_layers, hidden_size, output_size=None): method forward (line 271) | def forward(self, encoder_outputs): FILE: gliclass/loss_functions.py function sequence_contrastive_loss (line 5) | def sequence_contrastive_loss(embeddings, mask): function focal_loss_with_logits (line 31) | def focal_loss_with_logits( FILE: gliclass/model.py class GLiClassOutput (line 60) | class GLiClassOutput(SequenceClassifierOutput): class GLiClassPreTrainedModel (line 65) | class GLiClassPreTrainedModel(PreTrainedModel): method _initialize_weights (line 72) | def _initialize_weights(self, module, is_remote_code: bool = False): method _init_weights (line 88) | def _init_weights(self, module): class GLiClassBaseModel (line 117) | class GLiClassBaseModel(nn.Module): # ): method __init__ (line 118) | def __init__(self, config: GLiClassModelConfig, device="cpu", **kwargs): method _extract_class_features (line 162) | def _extract_class_features(self, token_embeds, input_ids, attention_m... method _extract_class_features_first (line 224) | def _extract_class_features_first( method _extract_class_features_averaged (line 255) | def _extract_class_features_averaged( method get_loss (line 296) | def get_loss(self, logits, labels, classes_embedding=None, classes_emb... class GLiClassUniEncoder (line 347) | class GLiClassUniEncoder(GLiClassBaseModel): method __init__ (line 348) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False): method _create_segment_ids (line 419) | def _create_segment_ids(self, input_ids): method process_encoder_output (line 444) | def process_encoder_output(self, input_ids, attention_mask, encoder_la... method forward (line 469) | def forward( class GLiClassEncoderDecoder (line 556) | class GLiClassEncoderDecoder(GLiClassBaseModel): method __init__ (line 557) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False): method _make_bidirectional_4d_mask (line 573) | def _make_bidirectional_4d_mask(attention_mask_2d, dtype): method forward (line 593) | def forward( class GLiClassEncoderDecoderCLS (line 674) | class GLiClassEncoderDecoderCLS(GLiClassBaseModel): method __init__ (line 681) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False): method forward (line 696) | def forward( class GLiClassBiEncoder (line 767) | class GLiClassBiEncoder(GLiClassBaseModel): method __init__ (line 768) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False): method pool_outputs (line 790) | def pool_outputs(self, encoder_outputs): method encode_text (line 798) | def encode_text(self, input_ids, attention_mask): method encode_classes (line 803) | def encode_classes(self, class_input_ids, class_attention_mask, labels... method forward (line 835) | def forward( class GLiClassBiEncoderFused (line 871) | class GLiClassBiEncoderFused(GLiClassBiEncoder): method __init__ (line 872) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False): method encode_text (line 875) | def encode_text(self, input_ids, attention_mask, class_embeddings, lab... method forward (line 895) | def forward( class GLiClassModel (line 937) | class GLiClassModel(GLiClassPreTrainedModel): method __init__ (line 938) | def __init__(self, config, from_pretrained=False): method get_input_embeddings (line 952) | def get_input_embeddings(self): method set_input_embeddings (line 960) | def set_input_embeddings(self, value): method tie_weights (line 971) | def tie_weights(self, recompute_mapping=True, missing_keys=None): method resize_token_embeddings (line 1025) | def resize_token_embeddings(self, new_num_tokens: int | None = None, p... method forward (line 1039) | def forward(self, *args, **kwargs): FILE: gliclass/ops.py function attn_padded (line 7) | def attn_padded( FILE: gliclass/pipeline.py function flatten_hierarchical_labels (line 12) | def flatten_hierarchical_labels( function build_hierarchical_output (line 69) | def build_hierarchical_output( function format_examples_prompt (line 135) | def format_examples_prompt( class BaseZeroShotClassificationPipeline (line 174) | class BaseZeroShotClassificationPipeline(ABC): method __init__ (line 175) | def __init__( method _normalize_classification_type (line 216) | def _normalize_classification_type(self, classification_type: str | No... method _normalize_texts (line 227) | def _normalize_texts(self, texts: str | List[str]) -> List[str]: method _normalize_thresholds (line 232) | def _normalize_thresholds(self, threshold: float | List[float], num_te... method _normalize_classification_types (line 239) | def _normalize_classification_types( method _process_labels (line 252) | def _process_labels( method _format_examples_for_input (line 281) | def _format_examples_for_input(self, examples: List[Dict[str, Any]] | ... method _examples_are_per_text (line 290) | def _examples_are_per_text(self, examples) -> bool: method _get_text_examples (line 298) | def _get_text_examples(self, examples, index: int): method _format_prompt (line 306) | def _format_prompt(self, prompt: str | List[str] | None = None, index:... method _resolve_max_num_classes (line 321) | def _resolve_max_num_classes(self, batch_labels, same_labels: bool): method prepare_inputs (line 329) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No... method _get_batch_examples (line 332) | def _get_batch_examples(self, examples, start_idx, batch_size): method _get_batch_prompt (line 340) | def _get_batch_prompt(self, prompt, start_idx, batch_size): method get_embeddings (line 349) | def get_embeddings(self, texts, labels, batch_size=8, examples=None, p... method __call__ (line 397) | def __call__( class UniEncoderZeroShotClassificationPipeline (line 522) | class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificatio... method __init__ (line 523) | def __init__( method prepare_input (line 538) | def prepare_input(self, text, labels, examples=None, prompt=None): method prepare_inputs (line 565) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No... class EncoderDecoderZeroShotClassificationPipeline (line 586) | class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassific... method __init__ (line 587) | def __init__( method prepare_labels_prompt (line 602) | def prepare_labels_prompt(self, labels, prompt=None): method prepare_inputs (line 617) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No... class BiEncoderZeroShotClassificationPipeline (line 650) | class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassification... method __init__ (line 651) | def __init__( method prepare_input (line 667) | def prepare_input(self, text, labels, examples=None, prompt=None): method prepare_inputs (line 687) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No... class ZeroShotClassificationPipeline (line 746) | class ZeroShotClassificationPipeline: method __init__ (line 825) | def __init__( method flatten_labels (line 869) | def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[s... method get_embeddings (line 882) | def get_embeddings(self, *args, **kwargs): method __call__ (line 886) | def __call__( class ZeroShotClassificationWithChunkingPipeline (line 933) | class ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificat... method __init__ (line 936) | def __init__( method chunk_text (line 960) | def chunk_text(self, text, chunk_size=None, overlap=None): method prepare_input (line 984) | def prepare_input(self, text, labels, examples=None, prompt=None): method prepare_inputs (line 1011) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No... method aggregate_chunk_scores (line 1030) | def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, float]],... method process_single_text (line 1041) | def process_single_text(self, text, labels, threshold=0.5, examples=No... method __call__ (line 1095) | def __call__( function parse_hierarchical_prediction (line 1221) | def parse_hierarchical_prediction(prediction: str, separator: str = ".")... function group_predictions_by_hierarchy (line 1230) | def group_predictions_by_hierarchy( function get_best_per_category (line 1250) | def get_best_per_category(predictions: List[Dict[str, Any]], separator: ... FILE: gliclass/poolings.py class GlobalMaxPooling1D (line 5) | class GlobalMaxPooling1D(nn.Module): method forward (line 8) | def forward(self, x: torch.Tensor): class FirstTokenPooling1D (line 12) | class FirstTokenPooling1D(nn.Module): method forward (line 15) | def forward(self, x: torch.Tensor): class LastTokenPooling1D (line 19) | class LastTokenPooling1D(nn.Module): method forward (line 22) | def forward(self, x: torch.Tensor): class GlobalAvgPooling1D (line 26) | class GlobalAvgPooling1D(nn.Module): method forward (line 29) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... class GlobalSumPooling1D (line 38) | class GlobalSumPooling1D(nn.Module): method forward (line 41) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... class GlobalRMSPooling1D (line 47) | class GlobalRMSPooling1D(nn.Module): method forward (line 50) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... class GlobalAbsMaxPooling1D (line 59) | class GlobalAbsMaxPooling1D(nn.Module): method forward (line 62) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... class GlobalAbsAvgPooling1D (line 69) | class GlobalAbsAvgPooling1D(nn.Module): method forward (line 72) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... class PassPooling1D (line 81) | class PassPooling1D(nn.Module): method forward (line 84) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None... FILE: gliclass/scorers.py class ScorerWeightedDot (line 7) | class ScorerWeightedDot(nn.Module): method __init__ (line 8) | def __init__(self, hidden_size, dropout=0.1, **kwargs): method forward (line 21) | def forward(self, text_rep, label_rep, **kwargs): class ScorerDot (line 42) | class ScorerDot(nn.Module): method __init__ (line 43) | def __init__(self, *args, **kwargs): method forward (line 47) | def forward(self, text_rep, label_rep, **kwargs): class MLPScorer (line 53) | class MLPScorer(nn.Module): method __init__ (line 54) | def __init__(self, hidden_size, mlp_hidden_size=256, **kwargs): method forward (line 69) | def forward(self, text_rep, label_rep, **kwargs): class HopfieldScorer (line 81) | class HopfieldScorer(nn.Module): method __init__ (line 82) | def __init__(self, hidden_size, mlp_hidden_size=256, beta=4, num_itera... method forward (line 101) | def forward(self, text_rep, label_rep, **kwargs): class CrossAttnScorer (line 128) | class CrossAttnScorer(nn.Module): method __init__ (line 129) | def __init__(self, hidden_size, num_heads=16, attn_dropout=0.1, scorer... method forward (line 154) | def forward(self, text_rep, label_rep, text_mask=None, **kwargs): FILE: gliclass/serve/__main__.py function main (line 21) | def main(): FILE: gliclass/serve/client.py class GLiClassClient (line 6) | class GLiClassClient: method __init__ (line 9) | def __init__(self, url: str = "http://localhost:8000/gliclass"): method __call__ (line 17) | def __call__( method classify (line 55) | def classify( method health_check (line 80) | def health_check(self) -> bool: FILE: gliclass/serve/config.py class GLiClassServeConfig (line 10) | class GLiClassServeConfig: method __post_init__ (line 63) | def __post_init__(self): method to_env_vars (line 68) | def to_env_vars(self) -> dict: method from_yaml (line 76) | def from_yaml(cls, config_path: str | Path) -> "GLiClassServeConfig": method to_yaml (line 90) | def to_yaml(self, config_path: str | Path) -> None: method update (line 101) | def update(self, **kwargs) -> "GLiClassServeConfig": FILE: gliclass/serve/memory.py function _power_of_two_seq_lens (line 21) | def _power_of_two_seq_lens(max_seq_len: int, min_seq_len: int = 64) -> L... class GLiClassMemoryEstimator (line 32) | class GLiClassMemoryEstimator: method __init__ (line 35) | def __init__( method measure_cuda_context (line 51) | def measure_cuda_context(self) -> None: method measure_model_memory (line 61) | def measure_model_memory(self) -> None: method available_memory (line 73) | def available_memory(self) -> int: method calibrate (line 80) | def calibrate( method _measure_peak (line 106) | def _measure_peak( method _lookup_seq_len (line 124) | def _lookup_seq_len(self, seq_len: int) -> int: method per_sample_at (line 133) | def per_sample_at(self, seq_len: int) -> int: method batch_size_fn (line 138) | def batch_size_fn( FILE: gliclass/serve/server.py class GLiClassServer (line 20) | class GLiClassServer: method __init__ (line 23) | def __init__(self, config: GLiClassServeConfig): method _precompile (line 92) | def _precompile(self) -> None: method _calibrate_memory (line 118) | def _calibrate_memory(self) -> None: method batch_size_fn (line 129) | def batch_size_fn(self, seq_len: int | None = None) -> int: method observed_seq_len (line 149) | def observed_seq_len( method _filter_labels (line 176) | def _filter_labels(self, labels: list[str]) -> list[str]: method _run_batch_internal (line 183) | def _run_batch_internal( method predict (line 220) | def predict( function _build_deployment (line 249) | def _build_deployment(config: GLiClassServeConfig): function serve_gliclass (line 359) | def serve_gliclass( function shutdown (line 396) | def shutdown() -> None: class GLiClassFactory (line 400) | class GLiClassFactory: method __init__ (line 420) | def __init__( method handle (line 441) | def handle(self): method predict (line 445) | def predict( method predict_async (line 472) | async def predict_async( method shutdown (line 501) | def shutdown(self) -> None: method __enter__ (line 516) | def __enter__(self): method __exit__ (line 519) | def __exit__(self, exc_type, exc_val, exc_tb): method __del__ (line 523) | def __del__(self): FILE: gliclass/training.py class EWC (line 32) | class EWC: method __init__ (line 35) | def __init__( method _compute_fisher (line 82) | def _compute_fisher(self, dataset: Dataset) -> Dict[str, torch.Tensor]: method _normalize_fisher (line 187) | def _normalize_fisher(self): method ewc_loss (line 199) | def ewc_loss(self, batch_size: int | None = None) -> torch.Tensor: method get_importance_scores (line 226) | def get_importance_scores(self) -> Dict[str, float]: method update_lambda (line 237) | def update_lambda(self, new_lambda: float): method consolidate (line 245) | def consolidate(self, dataset: Dataset, alpha: float = 0.5): class TrainingArguments (line 273) | class TrainingArguments(transformers.TrainingArguments): class Trainer (line 294) | class Trainer(transformers.Trainer): method __init__ (line 297) | def __init__(self, ewc: EWC | None = None, prev_dataset=None, *args, *... method _maybe_initialize_ewc (line 316) | def _maybe_initialize_ewc(self): method compute_loss (line 352) | def compute_loss(self, model, inputs, return_outputs=False, **kwargs): method train (line 381) | def train(self, *args, **kwargs): method training_step (line 387) | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: method prediction_step (line 437) | def prediction_step( method create_optimizer (line 498) | def create_optimizer(self): class RLTrainerConfig (line 577) | class RLTrainerConfig(TrainingArguments): class RLTrainer (line 607) | class RLTrainer(Trainer): method __init__ (line 608) | def __init__( method _init_metrics (line 625) | def _init_metrics(self): method compute_rewards (line 634) | def compute_rewards( method get_reference_scores (line 646) | def get_reference_scores(self, input_texts, labels_text): method compute_loss (line 671) | def compute_loss( method _inner_training_loop (line 757) | def _inner_training_loop(self, *args, **kwargs): method log_metrics (line 845) | def log_metrics(self): method _save_checkpoint (line 857) | def _save_checkpoint(self, model, step=None): FILE: gliclass/utils.py function is_module_available (line 4) | def is_module_available(module_name): class MissedPackageException (line 21) | class MissedPackageException(Exception): function retrieval_augmented_text (line 27) | def retrieval_augmented_text(text: str, examples: list) -> str: function default_f1_reward (line 64) | def default_f1_reward( FILE: test_gliclass.py class TestModel (line 14) | class TestModel: method __init__ (line 16) | def __init__(self, model, token): method load_model (line 28) | def load_model(self): method prepare_dataset (line 34) | def prepare_dataset(self, dataset, classes=None, text_column='text', l... method prepare_nomapping (line 55) | def prepare_nomapping(self, dataset, classes=None, text_column='text',... method get_gliclass_predictions (line 79) | def get_gliclass_predictions(self, test_texts, classes, batch_size=8): method evaluate (line 84) | def evaluate(self, predicts, true_labels): method process (line 90) | def process(self): FILE: tests/test_data_processing.py class TestPad2DTensor (line 9) | class TestPad2DTensor: method sample_tensors (line 13) | def sample_tensors(self): method test_pads_to_maximum_dimensions (line 21) | def test_pads_to_maximum_dimensions(self, sample_tensors): method test_preserves_original_values (line 28) | def test_preserves_original_values(self, sample_tensors): method test_pads_with_zeros (line 48) | def test_pads_with_zeros(self, sample_tensors): method test_single_tensor (line 58) | def test_single_tensor(self): method test_uniform_size_tensors (line 67) | def test_uniform_size_tensors(self): method test_empty_tensor_handling (line 81) | def test_empty_tensor_handling(self): method test_preserves_dtype (line 94) | def test_preserves_dtype(self): method test_varying_row_counts (line 105) | def test_varying_row_counts(self): method test_varying_column_counts (line 118) | def test_varying_column_counts(self): method test_batch_consistency (line 131) | def test_batch_consistency(self): FILE: tests/test_loss_functions.py class TestSequenceContrastiveLoss (line 9) | class TestSequenceContrastiveLoss: method sample_embeddings (line 13) | def sample_embeddings(self): method sample_mask (line 21) | def sample_mask(self): method test_returns_scalar_loss (line 25) | def test_returns_scalar_loss(self, sample_embeddings, sample_mask): method test_loss_is_positive (line 32) | def test_loss_is_positive(self, sample_embeddings, sample_mask): method test_identical_sequences_low_loss (line 38) | def test_identical_sequences_low_loss(self): method test_handles_masked_positions (line 48) | def test_handles_masked_positions(self): method test_gradient_flows_through_loss (line 58) | def test_gradient_flows_through_loss(self, sample_embeddings, sample_m... class TestFocalLossWithLogits (line 68) | class TestFocalLossWithLogits: method sample_logits (line 72) | def sample_logits(self): method sample_targets (line 77) | def sample_targets(self): method test_returns_tensor_with_reduction_none (line 81) | def test_returns_tensor_with_reduction_none(self, sample_logits, sampl... method test_returns_scalar_with_reduction_mean (line 88) | def test_returns_scalar_with_reduction_mean(self, sample_logits, sampl... method test_loss_is_positive (line 94) | def test_loss_is_positive(self, sample_logits, sample_targets): method test_perfect_predictions_low_loss (line 100) | def test_perfect_predictions_low_loss(self): method test_wrong_predictions_high_loss (line 110) | def test_wrong_predictions_high_loss(self): method test_alpha_parameter_effect (line 120) | def test_alpha_parameter_effect(self, sample_logits, sample_targets): method test_gamma_parameter_effect (line 128) | def test_gamma_parameter_effect(self, sample_logits, sample_targets): method test_reduction_sum (line 137) | def test_reduction_sum(self, sample_logits, sample_targets): method test_reduction_none (line 143) | def test_reduction_none(self, sample_logits, sample_targets): method test_handles_extreme_logits (line 149) | def test_handles_extreme_logits(self): method test_gradient_flows_through_loss (line 159) | def test_gradient_flows_through_loss(self, sample_logits, sample_targe... method test_all_zeros_targets (line 168) | def test_all_zeros_targets(self): method test_all_ones_targets (line 178) | def test_all_ones_targets(self): FILE: tests/test_poolings.py class TestGlobalMaxPooling1D (line 19) | class TestGlobalMaxPooling1D: method pooling_layer (line 23) | def pooling_layer(self): method sample_input (line 28) | def sample_input(self): method test_returns_max_across_sequence (line 32) | def test_returns_max_across_sequence(self, pooling_layer, sample_input): method test_output_shape (line 39) | def test_output_shape(self, pooling_layer, sample_input): class TestGlobalAvgPooling1D (line 46) | class TestGlobalAvgPooling1D: method pooling_layer (line 50) | def pooling_layer(self): method test_returns_average_across_sequence (line 54) | def test_returns_average_across_sequence(self, pooling_layer): method test_handles_attention_mask (line 63) | def test_handles_attention_mask(self, pooling_layer): method test_output_shape (line 73) | def test_output_shape(self, pooling_layer): class TestGlobalSumPooling1D (line 82) | class TestGlobalSumPooling1D: method pooling_layer (line 86) | def pooling_layer(self): method test_returns_sum_across_sequence (line 90) | def test_returns_sum_across_sequence(self, pooling_layer): method test_handles_attention_mask (line 99) | def test_handles_attention_mask(self, pooling_layer): class TestFirstTokenPooling1D (line 110) | class TestFirstTokenPooling1D: method pooling_layer (line 114) | def pooling_layer(self): method test_returns_first_token (line 118) | def test_returns_first_token(self, pooling_layer): method test_works_with_batch (line 127) | def test_works_with_batch(self, pooling_layer): method test_output_shape (line 136) | def test_output_shape(self, pooling_layer): class TestLastTokenPooling1D (line 145) | class TestLastTokenPooling1D: method pooling_layer (line 149) | def pooling_layer(self): method test_returns_last_token (line 153) | def test_returns_last_token(self, pooling_layer): method test_works_with_batch (line 162) | def test_works_with_batch(self, pooling_layer): method test_output_shape (line 171) | def test_output_shape(self, pooling_layer): class TestGlobalRMSPooling1D (line 180) | class TestGlobalRMSPooling1D: method pooling_layer (line 184) | def pooling_layer(self): method test_returns_rms_across_sequence (line 188) | def test_returns_rms_across_sequence(self, pooling_layer): method test_handles_attention_mask (line 197) | def test_handles_attention_mask(self, pooling_layer): method test_output_shape (line 207) | def test_output_shape(self, pooling_layer): class TestGlobalAbsMaxPooling1D (line 216) | class TestGlobalAbsMaxPooling1D: method pooling_layer (line 220) | def pooling_layer(self): method test_returns_abs_max_across_sequence (line 224) | def test_returns_abs_max_across_sequence(self, pooling_layer): method test_handles_attention_mask (line 233) | def test_handles_attention_mask(self, pooling_layer): method test_output_shape (line 243) | def test_output_shape(self, pooling_layer): class TestGlobalAbsAvgPooling1D (line 252) | class TestGlobalAbsAvgPooling1D: method pooling_layer (line 256) | def pooling_layer(self): method test_returns_abs_avg_across_sequence (line 260) | def test_returns_abs_avg_across_sequence(self, pooling_layer): method test_handles_attention_mask (line 269) | def test_handles_attention_mask(self, pooling_layer): method test_output_shape (line 279) | def test_output_shape(self, pooling_layer): class TestPassPooling1D (line 288) | class TestPassPooling1D: method pooling_layer (line 292) | def pooling_layer(self): method test_returns_input_unchanged (line 296) | def test_returns_input_unchanged(self, pooling_layer): method test_ignores_attention_mask (line 304) | def test_ignores_attention_mask(self, pooling_layer): method test_maintains_shape (line 313) | def test_maintains_shape(self, pooling_layer): FILE: tests/test_scorers.py class TestScorerWeightedDot (line 15) | class TestScorerWeightedDot: method scorer (line 17) | def scorer(self): method test_forward_pass (line 20) | def test_forward_pass(self, scorer): method test_gradient_flow (line 29) | def test_gradient_flow(self, scorer): class TestScorerDot (line 41) | class TestScorerDot: method scorer (line 43) | def scorer(self): method test_forward_pass (line 46) | def test_forward_pass(self, scorer): method test_gradient_flow (line 55) | def test_gradient_flow(self, scorer): class TestMLPScorer (line 67) | class TestMLPScorer: method scorer (line 69) | def scorer(self): method test_forward_pass (line 72) | def test_forward_pass(self, scorer): method test_different_batch_sizes (line 81) | def test_different_batch_sizes(self, scorer): method test_gradient_flow (line 90) | def test_gradient_flow(self, scorer): class TestHopfieldScorer (line 102) | class TestHopfieldScorer: method scorer (line 104) | def scorer(self): method test_forward_pass (line 107) | def test_forward_pass(self, scorer): method test_multiple_iterations (line 116) | def test_multiple_iterations(self): method test_gradient_flow (line 125) | def test_gradient_flow(self, scorer): class TestCrossAttnScorer (line 137) | class TestCrossAttnScorer: method scorer (line 139) | def scorer(self): method test_forward_pass_with_text_mask (line 142) | def test_forward_pass_with_text_mask(self, scorer): method test_forward_pass_without_text_mask (line 153) | def test_forward_pass_without_text_mask(self, scorer): method test_different_seq_lengths (line 162) | def test_different_seq_lengths(self, scorer): method test_gradient_flow (line 171) | def test_gradient_flow(self, scorer): method test_eval_mode (line 182) | def test_eval_mode(self, scorer): FILE: tests/test_utils.py class TestIsModuleAvailable (line 9) | class TestIsModuleAvailable: method test_detects_installed_module (line 12) | def test_detects_installed_module(self): method test_detects_missing_module (line 17) | def test_detects_missing_module(self): method test_handles_submodules (line 21) | def test_handles_submodules(self): class TestRetrievalAugmentedText (line 26) | class TestRetrievalAugmentedText: method test_with_structured_examples (line 29) | def test_with_structured_examples(self): method test_empty_examples_returns_original_text (line 42) | def test_empty_examples_returns_original_text(self): method test_includes_true_label_markers (line 51) | def test_includes_true_label_markers(self): class TestDefaultF1Reward (line 62) | class TestDefaultF1Reward: method sample_inputs (line 66) | def sample_inputs(self): method test_returns_tensor (line 77) | def test_returns_tensor(self, sample_inputs): method test_output_shape (line 83) | def test_output_shape(self, sample_inputs): method test_perfect_predictions (line 89) | def test_perfect_predictions(self): method test_zero_f1_for_wrong_predictions (line 100) | def test_zero_f1_for_wrong_predictions(self): method test_handles_valid_mask (line 111) | def test_handles_valid_mask(self): FILE: train.py class CustomTrainer (line 20) | class CustomTrainer(Trainer): method __init__ (line 23) | def __init__(self, *args, use_weighted_sampling=False, **kwargs): method _get_train_sampler (line 27) | def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler: function compute_metrics (line 38) | def compute_metrics(p, problem_type='multi_label_classification'): function load_dataset (line 78) | def load_dataset(data_path: str) -> list: function main (line 92) | def main(args): FILE: train_rl.py function accuracy_reward (line 20) | def accuracy_reward(probs, actions, targets, valid_mask): function recall_reward (line 27) | def recall_reward( function compute_metrics (line 43) | def compute_metrics(p): function main (line 71) | def main(args):