SYMBOL INDEX (1124 symbols across 124 files) FILE: docs/_assets/js/google-analytics.js function gtag (line 8) | function gtag() { FILE: expts/dataset_benchmark.py function benchmark (line 27) | def benchmark(fn, *args, message="", log2wandb=False, **kwargs): function benchmark_dataloader (line 37) | def benchmark_dataloader(dataloader, name, n_epochs=5, log2wandb=False): function main (line 81) | def main( FILE: expts/debug_yaml.py function get_anchors_and_aliases (line 41) | def get_anchors_and_aliases(filepath): FILE: expts/main_run_get_fingerprints.py function main (line 27) | def main() -> None: FILE: expts/main_run_multitask.py function main (line 6) | def main(cfg: DictConfig) -> None: FILE: expts/main_run_predict.py function main (line 31) | def main(cfg: DictConfig) -> None: FILE: expts/main_run_test.py function main (line 25) | def main(cfg: DictConfig) -> None: FILE: expts/run_validation_test.py function main (line 37) | def main(cfg: DictConfig) -> None: FILE: graphium/cli/data.py function download (line 19) | def download(name: str, output: str, progress: bool = True): function list (line 34) | def list(): function prepare_data (line 40) | def prepare_data(overrides: List[str]) -> None: FILE: graphium/cli/finetune_utils.py function benchmark_tdc_admet_cli (line 28) | def benchmark_tdc_admet_cli( function get_fingerprints_from_model (line 80) | def get_fingerprints_from_model( function get_tdc_task_specific (line 139) | def get_tdc_task_specific(task: str, output: Literal["name", "mode", "la... FILE: graphium/cli/fingerprints.py function get_fingerprints_from_model (line 5) | def get_fingerprints_from_model(): ... FILE: graphium/cli/parameters.py function infer_parameter_count (line 26) | def infer_parameter_count(overrides: List[str] = []) -> int: FILE: graphium/cli/train_finetune_test.py function cli (line 48) | def cli(cfg: DictConfig) -> None: function get_replication_factor (line 55) | def get_replication_factor(cfg): function get_gradient_accumulation_factor (line 72) | def get_gradient_accumulation_factor(cfg): function get_training_batch_size (line 90) | def get_training_batch_size(cfg): function get_training_device_iterations (line 108) | def get_training_device_iterations(cfg): function run_training_finetuning_testing (line 125) | def run_training_finetuning_testing(cfg: DictConfig) -> None: FILE: graphium/config/_load.py function load_config (line 18) | def load_config(name: str): FILE: graphium/config/_loader.py function get_accelerator (line 49) | def get_accelerator( function _get_ipu_opts (line 82) | def _get_ipu_opts(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -... function load_datamodule (line 98) | def load_datamodule( function load_metrics (line 167) | def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->... function load_architecture (line 192) | def load_architecture( function load_predictor (line 297) | def load_predictor( function load_mup (line 366) | def load_mup(mup_base_path: str, predictor: PredictorModule) -> Predicto... function load_trainer (line 388) | def load_trainer( function save_params_to_wandb (line 469) | def save_params_to_wandb( function load_accelerator (line 515) | def load_accelerator(config: Union[omegaconf.DictConfig, Dict[str, Any]]... function load_config_override (line 532) | def load_config_override( function load_yaml_config (line 546) | def load_yaml_config( function merge_dicts (line 579) | def merge_dicts( function get_checkpoint_path (line 626) | def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, An... FILE: graphium/config/config_convert.py function recursive_config_reformating (line 17) | def recursive_config_reformating(configs): FILE: graphium/data/collate.py function graphium_collate_fn (line 30) | def graphium_collate_fn( function collage_pyg_graph (line 127) | def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_siz... function pad_to_expected_label_size (line 184) | def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[in... function collate_pyg_graph_labels (line 205) | def collate_pyg_graph_labels(pyg_labels: List[Data]): function get_expected_label_size (line 228) | def get_expected_label_size(label_data: Data, task: str, label_size: Lis... function collate_labels (line 243) | def collate_labels( function pad_nodepairs (line 305) | def pad_nodepairs(pe: torch.Tensor, num_nodes: int, max_num_nodes_per_gr... FILE: graphium/data/datamodule.py class BaseDataModule (line 104) | class BaseDataModule(lightning.LightningDataModule): method __init__ (line 105) | def __init__( method prepare_data (line 157) | def prepare_data(self): method setup (line 160) | def setup(self): method train_dataloader (line 163) | def train_dataloader(self, **kwargs): method val_dataloader (line 174) | def val_dataloader(self, **kwargs): method test_dataloader (line 185) | def test_dataloader(self, **kwargs): method predict_dataloader (line 196) | def predict_dataloader(self, **kwargs): method get_collate_fn (line 207) | def get_collate_fn(self, collate_fn): method is_prepared (line 218) | def is_prepared(self): method is_setup (line 222) | def is_setup(self): method num_node_feats (line 226) | def num_node_feats(self): method num_edge_feats (line 230) | def num_edge_feats(self): method predict_ds (line 234) | def predict_ds(self): method get_num_workers (line 242) | def get_num_workers(self): method predict_ds (line 254) | def predict_ds(self, value): method get_fake_graph (line 258) | def get_fake_graph(self): method _read_csv (line 264) | def _read_csv( method _get_data_file_type (line 294) | def _get_data_file_type(path): method _get_table_columns (line 319) | def _get_table_columns(path: str) -> List[str]: method _read_parquet (line 348) | def _read_parquet(path, **kwargs): method _read_sdf (line 398) | def _read_sdf(path: str, mol_col_name: str = "_rdkit_molecule_obj", **... method _glob (line 445) | def _glob(path: str) -> List[str]: method _read_table (line 457) | def _read_table(self, path: str, **kwargs) -> pd.DataFrame: method get_dataloader_kwargs (line 487) | def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **... method get_dataloader (line 524) | def get_dataloader(self, dataset: Dataset, shuffle: bool, stage: Runni... method _dataloader (line 539) | def _dataloader(self, dataset: Dataset, **kwargs) -> DataLoader: method get_max_num_nodes_datamodule (line 556) | def get_max_num_nodes_datamodule(self, stages: Optional[List[str]] = N... method get_max_num_edges_datamodule (line 610) | def get_max_num_edges_datamodule(self, stages: Optional[List[str]] = N... class DatasetProcessingParams (line 667) | class DatasetProcessingParams: method __init__ (line 668) | def __init__( class IPUDataModuleModifier (line 731) | class IPUDataModuleModifier: method __init__ (line 732) | def __init__( method _dataloader (line 763) | def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoa... class MultitaskFromSmilesDataModule (line 788) | class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier): method __init__ (line 789) | def __init__( method _parse_caching_args (line 924) | def _parse_caching_args(self, processed_graph_data_path, dataloading_f... method _get_task_key (line 945) | def _get_task_key(self, task_level: str, task: str): method get_task_levels (line 951) | def get_task_levels(self): method prepare_data (line 961) | def prepare_data(self, save_smiles_and_ids: bool = False): method setup (line 1163) | def setup( method _make_multitask_dataset (line 1219) | def _make_multitask_dataset( method _ready_to_load_all_from_file (line 1276) | def _ready_to_load_all_from_file(self) -> bool: method _path_to_load_from_file (line 1286) | def _path_to_load_from_file(self, stage: Literal["train", "val", "test... method _data_ready_at_path (line 1294) | def _data_ready_at_path(self, path: str) -> bool: method _save_data_to_files (line 1302) | def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None: method get_folder_size (line 1329) | def get_folder_size(self, path): method calculate_statistics (line 1333) | def calculate_statistics(self, dataset: Datasets.MultitaskDataset, tra... method get_label_statistics (line 1360) | def get_label_statistics( method normalize_label (line 1391) | def normalize_label(self, dataset: Datasets.MultitaskDataset, stage) -... method save_featurized_data (line 1411) | def save_featurized_data(self, dataset: Datasets.MultitaskDataset, pro... method process_func (line 1425) | def process_func(self, param): method get_dataloader_kwargs (line 1435) | def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **... method get_dataloader (line 1471) | def get_dataloader( method get_collate_fn (line 1504) | def get_collate_fn(self, collate_fn): method _featurize_molecules (line 1517) | def _featurize_molecules(self, smiles: Iterable[str]) -> Tuple[List, L... method _filter_none_molecules (line 1571) | def _filter_none_molecules( method _parse_label_cols (line 1620) | def _parse_label_cols( method is_prepared (line 1670) | def is_prepared(self): method is_setup (line 1676) | def is_setup(self): method num_node_feats (line 1682) | def num_node_feats(self): method in_dims (line 1689) | def in_dims(self): method num_edge_feats (line 1711) | def num_edge_feats(self): method get_fake_graph (line 1720) | def get_fake_graph(self): method _save_to_cache (line 1737) | def _save_to_cache(self): method _load_from_cache (line 1740) | def _load_from_cache(self): method _extract_smiles_labels (line 1743) | def _extract_smiles_labels( method _get_split_indices (line 1841) | def _get_split_indices( method _sub_sample_df (line 1928) | def _sub_sample_df( method get_data_hash (line 1954) | def get_data_hash(self): method get_data_cache_fullname (line 1983) | def get_data_cache_fullname(self, compress: bool = False) -> str: method load_data_from_cache (line 2000) | def load_data_from_cache(self, verbose: bool = True, compress: bool = ... method get_subsets_of_datasets (line 2056) | def get_subsets_of_datasets( method __len__ (line 2085) | def __len__(self) -> int: method to_dict (line 2100) | def to_dict(self) -> Dict[str, Any]: method __repr__ (line 2125) | def __repr__(self) -> str: class GraphOGBDataModule (line 2135) | class GraphOGBDataModule(MultitaskFromSmilesDataModule): method __init__ (line 2136) | def __init__( method to_dict (line 2234) | def to_dict(self) -> Dict[str, Any]: method _load_dataset (line 2248) | def _load_dataset( method _get_dataset_metadata (line 2340) | def _get_dataset_metadata(self, dataset_name: str) -> Dict[str, Any]: method _get_ogb_metadata (line 2347) | def _get_ogb_metadata(self): class ADMETBenchmarkDataModule (line 2367) | class ADMETBenchmarkDataModule(MultitaskFromSmilesDataModule): method __init__ (line 2395) | def __init__( method _get_task_specific_arguments (line 2478) | def _get_task_specific_arguments(self, name: str, seed: int, cache_dir... class FakeDataModule (line 2536) | class FakeDataModule(MultitaskFromSmilesDataModule): method __init__ (line 2544) | def __init__( method generate_data (line 2576) | def generate_data(self, label_cols: List[str], smiles_col: str): method prepare_data (line 2598) | def prepare_data(self): method setup (line 2700) | def setup(self, stage=None): method get_fake_graph (line 2730) | def get_fake_graph(self): FILE: graphium/data/dataset.py class SingleTaskDataset (line 32) | class SingleTaskDataset(Dataset): method __init__ (line 33) | def __init__( method __len__ (line 87) | def __len__(self): method __getitem__ (line 95) | def __getitem__(self, idx): method __getstate__ (line 128) | def __getstate__(self): method __setstate__ (line 140) | def __setstate__(self, state: dict): class MultitaskDataset (line 149) | class MultitaskDataset(Dataset): method __init__ (line 152) | def __init__( method transfer_from_disk_to_ram (line 231) | def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False): method save_metadata (line 271) | def save_metadata(self, directory: str): method _load_metadata (line 290) | def _load_metadata(self): method __len__ (line 327) | def __len__(self): method num_nodes_list (line 334) | def num_nodes_list(self): method num_edges_list (line 346) | def num_edges_list(self): method num_graphs_total (line 358) | def num_graphs_total(self): method num_nodes_total (line 365) | def num_nodes_total(self): method max_num_nodes_per_graph (line 372) | def max_num_nodes_per_graph(self): method std_num_nodes_per_graph (line 379) | def std_num_nodes_per_graph(self): method min_num_nodes_per_graph (line 386) | def min_num_nodes_per_graph(self): method mean_num_nodes_per_graph (line 393) | def mean_num_nodes_per_graph(self): method num_edges_total (line 400) | def num_edges_total(self): method max_num_edges_per_graph (line 407) | def max_num_edges_per_graph(self): method min_num_edges_per_graph (line 414) | def min_num_edges_per_graph(self): method std_num_edges_per_graph (line 421) | def std_num_edges_per_graph(self): method mean_num_edges_per_graph (line 428) | def mean_num_edges_per_graph(self): method __getitem__ (line 434) | def __getitem__(self, idx): method load_graph_from_index (line 464) | def load_graph_from_index(self, data_idx): method merge (line 479) | def merge( method _get_all_lists_ids (line 531) | def _get_all_lists_ids(self, datasets: Dict[str, SingleTaskDataset]) -... method _get_inv_of_mol_ids (line 578) | def _get_inv_of_mol_ids(self, all_mol_ids): method _find_valid_label (line 582) | def _find_valid_label(self, task, ds): method set_label_size_dict (line 597) | def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]): method set_label_dtype_dict (line 615) | def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]): method __repr__ (line 630) | def __repr__(self) -> str: class FakeDataset (line 673) | class FakeDataset(MultitaskDataset): method __init__ (line 678) | def __init__( method _get_inv_of_mol_ids (line 715) | def _get_inv_of_mol_ids(self, all_mol_ids): method deepcopy_mol (line 722) | def deepcopy_mol(self, mol_ids, labels, smiles, features=None): method __len__ (line 747) | def __len__(self): method __getitem__ (line 753) | def __getitem__(self, idx): function get_num_nodes_per_graph (line 774) | def get_num_nodes_per_graph(graphs): function get_num_edges_per_graph (line 784) | def get_num_edges_per_graph(graphs): FILE: graphium/data/multilevel_utils.py function extract_labels (line 22) | def extract_labels(df: pd.DataFrame, task_level: str, label_cols: List[s... FILE: graphium/data/normalization.py class LabelNormalization (line 21) | class LabelNormalization: method __init__ (line 22) | def __init__( method calculate_statistics (line 59) | def calculate_statistics(self, array): method normalize (line 74) | def normalize(self, input): method denormalize (line 100) | def denormalize(self, input): FILE: graphium/data/sampler.py class DatasetSubSampler (line 24) | class DatasetSubSampler(data_utils.Sampler): method __init__ (line 25) | def __init__( method __iter__ (line 60) | def __iter__(self): method __len__ (line 69) | def __len__(self): method check_sampling_required (line 73) | def check_sampling_required(cls, sampler_task_dict): FILE: graphium/data/sdf2csv.py function extract_zip (line 24) | def extract_zip(fname): function extract_mols_from_sdf (line 32) | def extract_mols_from_sdf(fname): function mols2cxs (line 41) | def mols2cxs(mols): function write_csv (line 51) | def write_csv(cxs: any, homos: any, fname: str): function sdf2csv (line 66) | def sdf2csv(sdf_name: str = "pcqm4m-v2-train", outname: str = "pcqm4m-v2... FILE: graphium/data/smiles_transform.py function smiles_to_unique_mol_id (line 20) | def smiles_to_unique_mol_id(smiles: str) -> Optional[str]: function did_featurization_fail (line 40) | def did_featurization_fail(features: Any) -> bool: class BatchingSmilesTransform (line 47) | class BatchingSmilesTransform: method __init__ (line 52) | def __init__(self, transform: Callable): method __call__ (line 59) | def __call__(self, smiles_list: Iterable[str]) -> Any: method parse_batch_size (line 69) | def parse_batch_size(numel: int, desired_batch_size: int, n_jobs: int)... function smiles_to_unique_mol_ids (line 91) | def smiles_to_unique_mol_ids( FILE: graphium/data/utils.py function load_micro_zinc (line 36) | def load_micro_zinc() -> pd.DataFrame: function load_tiny_zinc (line 49) | def load_tiny_zinc() -> pd.DataFrame: function graphium_package_path (line 62) | def graphium_package_path(graphium_path: str) -> str: function list_graphium_datasets (line 76) | def list_graphium_datasets() -> set: function download_graphium_dataset (line 85) | def download_graphium_dataset( function get_keys (line 124) | def get_keys(pyg_data): function found_size_mismatch (line 131) | def found_size_mismatch(task: str, features: Union[Data, GraphDict], lab... FILE: graphium/features/commute.py function compute_commute_distances (line 22) | def compute_commute_distances( FILE: graphium/features/electrostatic.py function compute_electrostatic_interactions (line 22) | def compute_electrostatic_interactions( FILE: graphium/features/featurizer.py function to_dense_array (line 33) | def to_dense_array(array: np.ndarray, dtype: str = None) -> np.ndarray: function to_dense_tensor (line 53) | def to_dense_tensor(tensor: Tensor, dtype: str = None) -> Tensor: function _mask_nans_inf (line 70) | def _mask_nans_inf(mask_nan: Optional[str], array: np.ndarray, array_nam... function get_mol_atomic_features_onehot (line 103) | def get_mol_atomic_features_onehot(mol: dm.Mol, property_list: List[str]... function get_mol_conformer_features (line 194) | def get_mol_conformer_features( function get_mol_atomic_features_float (line 247) | def get_mol_atomic_features_float( function get_simple_mol_conformer (line 444) | def get_simple_mol_conformer(mol: dm.Mol) -> Union[Chem.rdchem.Conformer... function get_estimated_bond_length (line 481) | def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: dm.Mol) -> fl... function get_mol_edge_features (line 544) | def get_mol_edge_features( function mol_to_adj_and_features (line 634) | def mol_to_adj_and_features( function mol_to_adjacency_matrix (line 791) | def mol_to_adjacency_matrix( class GraphDict (line 856) | class GraphDict(dict): method __init__ (line 857) | def __init__( method keys (line 903) | def keys(self): method values (line 907) | def values(self): method make_pyg_graph (line 910) | def make_pyg_graph(self, **kwargs) -> Data: method adj (line 947) | def adj(self): method dtype (line 951) | def dtype(self): method mask_nan (line 955) | def mask_nan(self): method num_nodes (line 959) | def num_nodes(self) -> int: method num_edges (line 963) | def num_edges(self) -> int: function mol_to_graph_dict (line 970) | def mol_to_graph_dict( function mol_to_pyggraph (line 1138) | def mol_to_pyggraph( function mol_to_graph_signature (line 1252) | def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Di... FILE: graphium/features/graphormer.py function compute_graphormer_distances (line 22) | def compute_graphormer_distances( FILE: graphium/features/nmp.py function float_or_none (line 29) | def float_or_none(string: str) -> Union[float, None]: FILE: graphium/features/positional_encoding.py function get_all_positional_encodings (line 29) | def get_all_positional_encodings( function graph_positional_encoder (line 76) | def graph_positional_encoder( FILE: graphium/features/properties.py function get_prop_or_none (line 23) | def get_prop_or_none( function get_props_from_mol (line 43) | def get_props_from_mol( FILE: graphium/features/rw.py function compute_rwse (line 25) | def compute_rwse( function get_Pks (line 122) | def get_Pks( FILE: graphium/features/spectral.py function compute_laplacian_pe (line 24) | def compute_laplacian_pe( function _get_positional_eigvecs (line 122) | def _get_positional_eigvecs( function normalize_matrix (line 158) | def normalize_matrix( FILE: graphium/features/transfer_pos_level.py function transfer_pos_level (line 23) | def transfer_pos_level( function node_to_edge (line 110) | def node_to_edge( function node_to_nodepair (line 150) | def node_to_nodepair(pe: np.ndarray, num_nodes: int) -> np.ndarray: function node_to_graph (line 174) | def node_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: function edge_to_node (line 190) | def edge_to_node(pe: np.ndarray, adj: Union[np.ndarray, spmatrix]) -> np... function edge_to_nodepair (line 206) | def edge_to_nodepair( function edge_to_graph (line 246) | def edge_to_graph(pe: np.ndarray) -> np.ndarray: function nodepair_to_node (line 260) | def nodepair_to_node(pe: np.ndarray, stats_list: List = [np.min, np.mean... function nodepair_to_edge (line 286) | def nodepair_to_edge( function nodepair_to_graph (line 325) | def nodepair_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray: function graph_to_node (line 341) | def graph_to_node( FILE: graphium/finetuning/finetuning.py class GraphFinetuning (line 25) | class GraphFinetuning(BaseFinetuning): method __init__ (line 26) | def __init__( method freeze_before_training (line 55) | def freeze_before_training(self, pl_module: pl.LightningModule): method freeze_module (line 73) | def freeze_module(self, pl_module, module_name: str, module_map: Dict[... method finetune_function (line 96) | def finetune_function(self, pl_module: pl.LightningModule, epoch: int,... FILE: graphium/finetuning/finetuning_architecture.py class FullGraphFinetuningNetwork (line 26) | class FullGraphFinetuningNetwork(nn.Module, MupMixin): method __init__ (line 27) | def __init__( method forward (line 98) | def forward(self, g: Batch) -> Tensor: method make_mup_base_kwargs (line 139) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str... method set_max_num_nodes_edges_per_graph (line 173) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ... class PretrainedModel (line 188) | class PretrainedModel(nn.Module, MupMixin): method __init__ (line 189) | def __init__( method forward (line 226) | def forward(self, g: Union[torch.Tensor, Batch]): method overwrite_with_pretrained (line 231) | def overwrite_with_pretrained( method make_mup_base_kwargs (line 281) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str... class FinetuningHead (line 299) | class FinetuningHead(nn.Module, MupMixin): method __init__ (line 300) | def __init__(self, finetuning_head_kwargs: Dict[str, Any]): method forward (line 320) | def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torc... method make_mup_base_kwargs (line 332) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... FILE: graphium/finetuning/fingerprinting.py class Fingerprinter (line 25) | class Fingerprinter: method __init__ (line 89) | def __init__( method setup (line 133) | def setup(self): method get_fingerprints_for_batch (line 140) | def get_fingerprints_for_batch(self, batch): method get_fingerprints_for_dataset (line 165) | def get_fingerprints_for_dataset(self, dataloader): method teardown (line 181) | def teardown(self): method __enter__ (line 190) | def __enter__(self): method __exit__ (line 194) | def __exit__(self, exc_type, exc_val, exc_tb): method _convert_output_type (line 200) | def _convert_output_type(self, feats: torch.Tensor): FILE: graphium/finetuning/utils.py function filter_cfg_based_on_admet_benchmark_name (line 24) | def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], nam... function modify_cfg_for_finetuning (line 59) | def modify_cfg_for_finetuning(cfg: Dict[str, Any]): function update_cfg_arch_for_module (line 179) | def update_cfg_arch_for_module( FILE: graphium/hyper_param_search/results.py function extract_main_metric_for_hparam_search (line 17) | def extract_main_metric_for_hparam_search(results: dict, cfg: dict): FILE: graphium/ipu/ipu_dataloader.py class IPUDataloaderOptions (line 37) | class IPUDataloaderOptions: method set_kwargs (line 56) | def set_kwargs(self): class CombinedBatchingCollator (line 96) | class CombinedBatchingCollator: method __init__ (line 106) | def __init__( method __call__ (line 132) | def __call__( function create_ipu_dataloader (line 195) | def create_ipu_dataloader( class Pad (line 296) | class Pad(BaseTransform): method __init__ (line 301) | def __init__( method validate (line 333) | def validate(self, data): method __call__ (line 357) | def __call__(self, batch: Batch) -> Batch: method forward (line 360) | def forward(self, batch: Batch) -> Batch: method _call (line 363) | def _call(self, batch: Batch) -> Batch: method __repr__ (line 427) | def __repr__(self) -> str: FILE: graphium/ipu/ipu_losses.py class BCEWithLogitsLossIPU (line 22) | class BCEWithLogitsLossIPU(BCEWithLogitsLoss): method forward (line 29) | def forward(self, input: Tensor, target: Tensor) -> Tensor: class BCELossIPU (line 67) | class BCELossIPU(BCELoss): method forward (line 74) | def forward(self, input: Tensor, target: Tensor) -> Tensor: class MSELossIPU (line 112) | class MSELossIPU(MSELoss): method forward (line 120) | def forward(self, input: Tensor, target: Tensor) -> Tensor: class L1LossIPU (line 140) | class L1LossIPU(L1Loss): method forward (line 148) | def forward(self, input: Tensor, target: Tensor) -> Tensor: class HybridCELossIPU (line 167) | class HybridCELossIPU(HybridCELoss): method __init__ (line 168) | def __init__( method forward (line 180) | def forward(self, input: Tensor, target: Tensor) -> Tensor: FILE: graphium/ipu/ipu_metrics.py function auroc_ipu (line 36) | def auroc_ipu( function average_precision_ipu (line 81) | def average_precision_ipu( function precision_ipu (line 128) | def precision_ipu( function recall_ipu (line 162) | def recall_ipu( function accuracy_ipu (line 195) | def accuracy_ipu( function get_confusion_matrix (line 321) | def get_confusion_matrix( class NaNTensor (line 451) | class NaNTensor(Tensor): method get_nans (line 461) | def get_nans(self) -> BoolTensor: method sum (line 474) | def sum(self, *args, **kwargs) -> Tensor: method mean (line 486) | def mean(self, *args, **kwargs) -> Tensor: method numel (line 494) | def numel(self) -> int: method min (line 500) | def min(self, *args, **kwargs) -> Tensor: method max (line 508) | def max(self, *args, **kwargs) -> Tensor: method argsort (line 516) | def argsort(self, dim=-1, descending=False) -> IntTensor: method size (line 527) | def size(self, dim) -> Tensor: method __lt__ (line 534) | def __lt__(self, other) -> Tensor: method __torch_function__ (line 546) | def __torch_function__(cls, func, types, args=(), kwargs=None): function pearson_ipu (line 569) | def pearson_ipu(preds, target): function spearman_ipu (line 585) | def spearman_ipu(preds, target): function _rank_data (line 605) | def _rank_data(data: Tensor) -> Tensor: function r2_score_ipu (line 626) | def r2_score_ipu(preds, target, *args, **kwargs) -> Tensor: function fbeta_score_ipu (line 659) | def fbeta_score_ipu( function f1_score_ipu (line 801) | def f1_score_ipu( function mean_squared_error_ipu (line 848) | def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool)... function mean_absolute_error_ipu (line 882) | def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor: FILE: graphium/ipu/ipu_simple_lightning.py class SimpleTorchModel (line 35) | class SimpleTorchModel(torch.nn.Module): method __init__ (line 36) | def __init__(self, in_dim, hidden_dim, kernel_size, num_classes): method make_mup_base_kwargs (line 60) | def make_mup_base_kwargs(self, divide_factor: float = 2.0): method forward (line 68) | def forward(self, x): class SimpleLightning (line 75) | class SimpleLightning(lightning.LightningModule): method __init__ (line 76) | def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu): method training_step (line 83) | def training_step(self, batch, _): method validation_step (line 89) | def validation_step(self, batch, _): method on_train_batch_end (line 99) | def on_train_batch_end(self, outputs, batch, batch_idx): method validation_epoch_end (line 102) | def validation_epoch_end(self, outputs): method configure_optimizers (line 109) | def configure_optimizers(self): FILE: graphium/ipu/ipu_utils.py function import_poptorch (line 23) | def import_poptorch(raise_error=True) -> Optional[ModuleType]: function is_running_on_ipu (line 47) | def is_running_on_ipu() -> bool: function load_ipu_options (line 57) | def load_ipu_options( function ipu_options_list_to_file (line 145) | def ipu_options_list_to_file(ipu_opts: Optional[List[str]]) -> tempfile.... FILE: graphium/ipu/ipu_wrapper.py class PyGArgsParser (line 36) | class PyGArgsParser(poptorch.ICustomArgParser): method sortedTensorKeys (line 45) | def sortedTensorKeys(struct: BaseData) -> Iterable[str]: method yieldTensors (line 57) | def yieldTensors(self, struct: BaseData): method reconstruct (line 64) | def reconstruct(self, original_structure: BaseData, tensor_iterator: I... class PredictorModuleIPU (line 92) | class PredictorModuleIPU(PredictorModule): method __init__ (line 97) | def __init__(self, *args, **kwargs): method compute_loss (line 103) | def compute_loss( method on_train_batch_end (line 115) | def on_train_batch_end(self, outputs, batch, batch_idx): method training_step (line 120) | def training_step(self, batch, batch_idx) -> Dict[str, Any]: method validation_step (line 130) | def validation_step(self, batch, batch_idx) -> Dict[str, Any]: method test_step (line 138) | def test_step(self, batch, batch_idx) -> Dict[str, Any]: method predict_step (line 147) | def predict_step(self, **inputs) -> Dict[str, Any]: method on_validation_batch_end (line 154) | def on_validation_batch_end( method evaluation_epoch_end (line 161) | def evaluation_epoch_end(self, outputs: Any): method on_test_batch_end (line 165) | def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, ... method configure_optimizers (line 169) | def configure_optimizers(self, impl=None): method squeeze_input_dims (line 180) | def squeeze_input_dims(self, features, labels): method convert_from_fp16 (line 190) | def convert_from_fp16(self, data: Any) -> Any: method _convert_features_dtype (line 204) | def _convert_features_dtype(self, feats): method precision_to_dtype (line 222) | def precision_to_dtype(self, precision): method get_num_graphs (line 225) | def get_num_graphs(self, data: Batch): FILE: graphium/ipu/to_dense_batch.py function to_sparse_batch (line 21) | def to_sparse_batch(x: Tensor, mask_idx: Tensor): function to_sparse_batch_from_packed (line 28) | def to_sparse_batch_from_packed(x: Tensor, pack_from_node_idx: Tensor): function to_dense_batch (line 35) | def to_dense_batch( function to_packed_dense_batch (line 143) | def to_packed_dense_batch( FILE: graphium/nn/architectures/encoder_manager.py class EncoderManager (line 42) | class EncoderManager(nn.Module): method __init__ (line 43) | def __init__( method _initialize_positional_encoders (line 77) | def _initialize_positional_encoders(self, pe_encoders_kwargs: Dict[str... method forward (line 156) | def forward(self, g: Batch) -> Batch: method forward_positional_encoding (line 193) | def forward_positional_encoding(self, g: Batch) -> Dict[str, Tensor]: method forward_simple_pooling (line 231) | def forward_simple_pooling(self, h: Tensor, pooling: str, dim: int) ->... method make_mup_base_kwargs (line 253) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str... method input_keys (line 281) | def input_keys(self) -> Iterable[str]: method in_dims (line 294) | def in_dims(self) -> Iterable[int]: method out_dim (line 307) | def out_dim(self) -> int: FILE: graphium/nn/architectures/global_architectures.py class FeedForwardNN (line 49) | class FeedForwardNN(nn.Module, MupMixin): method __init__ (line 50) | def __init__( method _parse_layers (line 187) | def _parse_layers(self, layer_type, residual_type): method _check_bad_arguments (line 194) | def _check_bad_arguments(self): method _parse_class_from_dict (line 204) | def _parse_class_from_dict( method _create_residual_connection (line 221) | def _create_residual_connection(self, out_dims: List[int]) -> Tuple[Re... method _create_layers (line 248) | def _create_layers(self): method cache_readouts (line 295) | def cache_readouts(self) -> bool: method _enable_readout_cache (line 299) | def _enable_readout_cache(self): method _disable_readout_cache (line 307) | def _disable_readout_cache(self): method drop_layers (line 311) | def drop_layers(self, depth: int) -> None: method add_layers (line 322) | def add_layers(self, layers: int) -> None: method forward (line 333) | def forward(self, h: torch.Tensor) -> torch.Tensor: method get_init_kwargs (line 367) | def get_init_kwargs(self) -> Dict[str, Any]: method make_mup_base_kwargs (line 393) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... method __repr__ (line 411) | def __repr__(self): class EnsembleFeedForwardNN (line 421) | class EnsembleFeedForwardNN(FeedForwardNN): method __init__ (line 422) | def __init__( method _create_layers (line 583) | def _create_layers(self): method _parse_num_ensemble (line 587) | def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int: method _parse_reduction (line 616) | def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) ... method _parse_subset_in_dim (line 652) | def _parse_subset_in_dim( method _parse_layers (line 703) | def _parse_layers(self, layer_type, residual_type): method forward (line 710) | def forward(self, h: torch.Tensor) -> torch.Tensor: method get_init_kwargs (line 748) | def get_init_kwargs(self) -> Dict[str, Any]: method __repr__ (line 757) | def __repr__(self): class FeedForwardGraph (line 767) | class FeedForwardGraph(FeedForwardNN): method __init__ (line 768) | def __init__( method _check_bad_arguments (line 960) | def _check_bad_arguments(self): method get_nested_key (line 970) | def get_nested_key(self, d, target_key): method _create_layers (line 990) | def _create_layers(self): method _graph_layer_forward (line 1089) | def _graph_layer_forward( method _parse_virtual_node_class (line 1177) | def _parse_virtual_node_class(self) -> type: method _virtual_node_forward (line 1180) | def _virtual_node_forward( method forward (line 1228) | def forward(self, g: Batch) -> torch.Tensor: method get_init_kwargs (line 1290) | def get_init_kwargs(self) -> Dict[str, Any]: method make_mup_base_kwargs (line 1305) | def make_mup_base_kwargs( method __repr__ (line 1345) | def __repr__(self): class FullGraphMultiTaskNetwork (line 1355) | class FullGraphMultiTaskNetwork(nn.Module, MupMixin): method __init__ (line 1356) | def __init__( method _parse_feed_forward_gnn (line 1482) | def _parse_feed_forward_gnn( method _check_bad_arguments (line 1496) | def _check_bad_arguments(self): method _apply_ipu_options (line 1529) | def _apply_ipu_options(self, ipu_kwargs): method _apply_ipu_pipeline_split (line 1533) | def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu): method _enable_readout_cache (line 1569) | def _enable_readout_cache(self, module_filter: Optional[Union[str, Lis... method _disable_readout_cache (line 1596) | def _disable_readout_cache(self): method create_module_map (line 1604) | def create_module_map(self, level: Union[Literal["layers"], Literal["m... method forward (line 1650) | def forward(self, g: Batch) -> Tensor: method make_mup_base_kwargs (line 1713) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str... method set_max_num_nodes_edges_per_graph (line 1781) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ... method __repr__ (line 1806) | def __repr__(self) -> str: method in_dim (line 1828) | def in_dim(self) -> int: method out_dim (line 1838) | def out_dim(self) -> int: method out_dim_edges (line 1845) | def out_dim_edges(self) -> int: method in_dim_edges (line 1855) | def in_dim_edges(self) -> int: class GraphOutputNN (line 1862) | class GraphOutputNN(nn.Module, MupMixin): method __init__ (line 1863) | def __init__( method forward (line 1921) | def forward(self, g: Batch): method _parse_pooling_layer (line 1953) | def _parse_pooling_layer( method _pool_layer_forward (line 1986) | def _pool_layer_forward( method compute_nodepairs (line 2017) | def compute_nodepairs( method make_mup_base_kwargs (line 2059) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... method drop_graph_output_nn_layers (line 2082) | def drop_graph_output_nn_layers(self, num_layers_to_drop: int) -> None: method extend_graph_output_nn_layers (line 2095) | def extend_graph_output_nn_layers(self, layers: nn.ModuleList): method set_max_num_nodes_edges_per_graph (line 2108) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ... method concat_last_layers (line 2123) | def concat_last_layers(self) -> Optional[Iterable[int]]: method concat_last_layers (line 2136) | def concat_last_layers(self, value: Union[Type[None], int, Iterable[in... method out_dim (line 2155) | def out_dim(self) -> int: class TaskHeads (line 2162) | class TaskHeads(nn.Module, MupMixin): method __init__ (line 2163) | def __init__( method forward (line 2215) | def forward(self, g: Batch) -> Dict[str, torch.Tensor]: method make_mup_base_kwargs (line 2234) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... method set_max_num_nodes_edges_per_graph (line 2269) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ... method out_dim (line 2292) | def out_dim(self) -> Dict[str, int]: method __repr__ (line 2298) | def __repr__(self): method _check_bad_arguments (line 2307) | def _check_bad_arguments(self): FILE: graphium/nn/architectures/pyg_architectures.py class FeedForwardPyg (line 25) | class FeedForwardPyg(FeedForwardGraph): method _graph_layer_forward (line 26) | def _graph_layer_forward( method _parse_virtual_node_class (line 114) | def _parse_virtual_node_class(self) -> type: method _parse_pooling_layer (line 117) | def _parse_pooling_layer( FILE: graphium/nn/base_graph_layer.py class BaseGraphStructure (line 28) | class BaseGraphStructure: method __init__ (line 29) | def __init__( method _initialize_activation_dropout_norm (line 92) | def _initialize_activation_dropout_norm(self): method _parse_dropout (line 105) | def _parse_dropout(self, dropout): method _parse_droppath (line 112) | def _parse_droppath(self, droppath_rate): method _parse_norm (line 119) | def _parse_norm(self, normalization, dim=None): method apply_norm_activation_dropout (line 136) | def apply_norm_activation_dropout( method layer_supports_edges (line 191) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 205) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 221) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 237) | def out_dim_factor(self) -> int: method max_num_nodes_per_graph (line 258) | def max_num_nodes_per_graph(self) -> Optional[int]: method max_num_nodes_per_graph (line 265) | def max_num_nodes_per_graph(self, value: Optional[int]): method max_num_edges_per_graph (line 276) | def max_num_edges_per_graph(self) -> Optional[int]: method max_num_edges_per_graph (line 283) | def max_num_edges_per_graph(self, value: Optional[int]): method __repr__ (line 293) | def __repr__(self): class BaseGraphModule (line 302) | class BaseGraphModule(BaseGraphStructure, nn.Module): method __init__ (line 303) | def __init__( function check_intpus_allow_int (line 366) | def check_intpus_allow_int(obj, edge_index, size): FILE: graphium/nn/base_layers.py function get_activation (line 44) | def get_activation(activation: Union[type(None), str, Callable]) -> Opti... function get_activation_str (line 72) | def get_activation_str(activation: Union[type(None), str, Callable]) -> ... function get_norm (line 96) | def get_norm(normalization: Union[Type[None], str, Callable], dim: Optio... class MultiheadAttentionMup (line 124) | class MultiheadAttentionMup(nn.MultiheadAttention): method __init__ (line 132) | def __init__(self, biased_attention, **kwargs): method _reset_parameters (line 136) | def _reset_parameters(self): method forward (line 153) | def forward( class TransformerEncoderLayerMup (line 208) | class TransformerEncoderLayerMup(nn.TransformerEncoderLayer): method __init__ (line 215) | def __init__(self, biased_attention, *args, **kwargs) -> None: class MuReadoutGraphium (line 238) | class MuReadoutGraphium(MuReadout): method __init__ (line 250) | def __init__(self, in_features, *args, **kwargs): method absolute_width (line 255) | def absolute_width(self): method base_width (line 259) | def base_width(self): method base_width (line 263) | def base_width(self, val): method width_mult (line 271) | def width_mult(self): class FCLayer (line 275) | class FCLayer(nn.Module): method __init__ (line 276) | def __init__( method reset_parameters (line 380) | def reset_parameters(self, init_fn=None): method forward (line 391) | def forward(self, h: torch.Tensor) -> torch.Tensor: method in_channels (line 435) | def in_channels(self) -> int: method out_channels (line 442) | def out_channels(self) -> int: method __repr__ (line 448) | def __repr__(self): class MLP (line 452) | class MLP(nn.Module): method __init__ (line 453) | def __init__( method forward (line 585) | def forward(self, h: torch.Tensor) -> torch.Tensor: method in_features (line 609) | def in_features(self): method __getitem__ (line 612) | def __getitem__(self, idx: int) -> nn.Module: method __repr__ (line 615) | def __repr__(self): class GRU (line 622) | class GRU(nn.Module): method __init__ (line 623) | def __init__(self, in_dim: int, hidden_dim: int): method forward (line 639) | def forward(self, x, y): class DropPath (line 673) | class DropPath(nn.Module): method __init__ (line 674) | def __init__(self, drop_rate: float): method get_stochastic_drop_rate (line 690) | def get_stochastic_drop_rate( method forward (line 717) | def forward( FILE: graphium/nn/encoders/base_encoder.py class BaseEncoder (line 11) | class BaseEncoder(torch.nn.Module, MupMixin): method __init__ (line 12) | def __init__( method parse_input_keys_with_prefix (line 52) | def parse_input_keys_with_prefix(self, key_prefix): method forward (line 66) | def forward(self, graph: Batch, key_prefix=None) -> Dict[str, torch.Te... method parse_input_keys (line 76) | def parse_input_keys(self, input_keys: List[str]) -> List[str]: method parse_output_keys (line 85) | def parse_output_keys(self, output_keys: List[str]) -> List[str]: method make_mup_base_kwargs (line 93) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... FILE: graphium/nn/encoders/bessel_pos_encoder.py class BesselSphericalPosEncoder (line 16) | class BesselSphericalPosEncoder(BaseEncoder): method __init__ (line 17) | def __init__( method forward (line 88) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D... method make_mup_base_kwargs (line 148) | def make_mup_base_kwargs( method parse_input_keys (line 176) | def parse_input_keys( method parse_output_keys (line 198) | def parse_output_keys( FILE: graphium/nn/encoders/gaussian_kernel_pos_encoder.py class GaussianKernelPosEncoder (line 9) | class GaussianKernelPosEncoder(BaseEncoder): method __init__ (line 10) | def __init__( method parse_input_keys (line 66) | def parse_input_keys( method parse_output_keys (line 92) | def parse_output_keys( method forward (line 108) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D... method make_mup_base_kwargs (line 141) | def make_mup_base_kwargs( FILE: graphium/nn/encoders/laplace_pos_encoder.py class LapPENodeEncoder (line 10) | class LapPENodeEncoder(BaseEncoder): method __init__ (line 11) | def __init__( method parse_input_keys (line 120) | def parse_input_keys( method parse_output_keys (line 145) | def parse_output_keys( method forward (line 168) | def forward( method make_mup_base_kwargs (line 228) | def make_mup_base_kwargs( FILE: graphium/nn/encoders/mlp_encoder.py class MLPEncoder (line 10) | class MLPEncoder(BaseEncoder): method __init__ (line 11) | def __init__( method parse_input_keys (line 72) | def parse_input_keys( method parse_output_keys (line 85) | def parse_output_keys( method forward (line 114) | def forward( method make_mup_base_kwargs (line 141) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... class CatMLPEncoder (line 164) | class CatMLPEncoder(BaseEncoder): method __init__ (line 165) | def __init__( method parse_input_keys (line 230) | def parse_input_keys( method parse_output_keys (line 243) | def parse_output_keys( method forward (line 257) | def forward( method make_mup_base_kwargs (line 284) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... FILE: graphium/nn/encoders/signnet_pos_encoder.py class SimpleGIN (line 18) | class SimpleGIN(nn.Module): method __init__ (line 19) | def __init__( method forward (line 71) | def forward(self, x, edge_index): class GINDeepSigns (line 77) | class GINDeepSigns(nn.Module): method __init__ (line 82) | def __init__( method forward (line 117) | def forward(self, x, edge_index, batch_index): class MaskedGINDeepSigns (line 126) | class MaskedGINDeepSigns(nn.Module): method __init__ (line 131) | def __init__( method batched_n_nodes (line 164) | def batched_n_nodes(self, batch_index): method forward (line 173) | def forward(self, x, edge_index, batch_index): class SignNetNodeEncoder (line 189) | class SignNetNodeEncoder(BaseEncoder): method __init__ (line 207) | def __init__( method parse_input_keys (line 283) | def parse_input_keys(self, input_keys): method parse_output_keys (line 295) | def parse_output_keys(self, output_keys): method forward (line 305) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D... method make_mup_base_kwargs (line 320) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... FILE: graphium/nn/ensemble_layers.py class EnsembleLinear (line 26) | class EnsembleLinear(nn.Module): method __init__ (line 27) | def __init__( method reset_parameters (line 61) | def reset_parameters(self): method forward (line 73) | def forward(self, h: torch.Tensor) -> torch.Tensor: class EnsembleFCLayer (line 98) | class EnsembleFCLayer(FCLayer): method __init__ (line 99) | def __init__( method reset_parameters (line 189) | def reset_parameters(self, init_fn=None): method __repr__ (line 196) | def __repr__(self): class EnsembleMuReadoutGraphium (line 202) | class EnsembleMuReadoutGraphium(EnsembleLinear): method __init__ (line 208) | def __init__( method reset_parameters (line 230) | def reset_parameters(self) -> None: method width_mult (line 238) | def width_mult(self): method _rescale_parameters (line 246) | def _rescale_parameters(self): method forward (line 266) | def forward(self, x): method absolute_width (line 270) | def absolute_width(self): method base_width (line 274) | def base_width(self): method base_width (line 278) | def base_width(self, val): method width_mult (line 286) | def width_mult(self): class EnsembleMLP (line 290) | class EnsembleMLP(MLP): method __init__ (line 291) | def __init__( method _parse_reduction (line 392) | def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) ... method forward (line 428) | def forward(self, h: torch.Tensor) -> torch.Tensor: method __repr__ (line 452) | def __repr__(self): FILE: graphium/nn/pyg_layers/dimenet_pyg.py class ResidualLayer (line 30) | class ResidualLayer(torch.nn.Module): method __init__ (line 36) | def __init__(self, hidden_channels: int, activation: Union[Callable, s... method forward (line 41) | def forward(self, x: Tensor) -> Tensor: class OutputBlock (line 45) | class OutputBlock(torch.nn.Module): method __init__ (line 51) | def __init__( method forward (line 67) | def forward(self, x: Tensor, rbf: Tensor, i: Tensor, num_nodes: Option... class InteractionBlock (line 75) | class InteractionBlock(nn.Module): method __init__ (line 81) | def __init__( method forward (line 113) | def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,... class DimeNetPyg (line 141) | class DimeNetPyg(BaseGraphModule): method __init__ (line 142) | def __init__( method forward (line 255) | def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]: method layer_supports_edges (line 287) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 299) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 314) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 329) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/gated_gcn_pyg.py class GatedGCNPyg (line 28) | class GatedGCNPyg(MessagePassing, BaseGraphStructure): method __init__ (line 29) | def __init__( method forward (line 106) | def forward( method message (line 143) | def message(self, Dx_i: torch.Tensor, Ex_j: torch.Tensor, Ce: torch.Te... method aggregate (line 159) | def aggregate( method update (line 196) | def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor): method layer_supports_edges (line 212) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 223) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 237) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 251) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/gcn_pyg.py class GCNConvPyg (line 23) | class GCNConvPyg(BaseGraphModule): method __init__ (line 24) | def __init__( method forward (line 47) | def forward( method layer_supports_edges (line 64) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 76) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 91) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 106) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/gin_pyg.py class GINConvPyg (line 24) | class GINConvPyg(BaseGraphModule): method __init__ (line 25) | def __init__( method forward (line 94) | def forward( method layer_supports_edges (line 111) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 123) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 138) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 153) | def out_dim_factor(self) -> int: class GINEConvPyg (line 173) | class GINEConvPyg(BaseGraphModule): method __init__ (line 174) | def __init__( method forward (line 244) | def forward( method layer_supports_edges (line 261) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 273) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 288) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 303) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/gps_pyg.py class GPSLayerPyg (line 53) | class GPSLayerPyg(BaseGraphModule): method __init__ (line 54) | def __init__( method residual_add (line 217) | def residual_add(self, feature: Tensor, input_feature: Tensor) -> Tensor: method scale_activations (line 230) | def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> ... method forward (line 245) | def forward(self, batch: Batch) -> Batch: method _parse_mpnn_layer (line 299) | def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) ->... method _parse_attn_layer (line 324) | def _parse_attn_layer( method _use_packing (line 365) | def _use_packing(self, batch: Batch) -> bool: method _to_dense_batch (line 372) | def _to_dense_batch( method _to_sparse_batch (line 406) | def _to_sparse_batch(self, batch: Batch, h_dense: Tensor, idx: Tensor)... method _self_attention_block (line 422) | def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: ... method _sa_block (line 472) | def _sa_block( method layer_supports_edges (line 498) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 510) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 525) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 542) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/mpnn_pyg.py class MPNNPlusPyg (line 25) | class MPNNPlusPyg(BaseGraphModule): method __init__ (line 26) | def __init__( method gather_features (line 188) | def gather_features( method aggregate_features (line 235) | def aggregate_features( method forward (line 297) | def forward(self, batch: Batch) -> Batch: method layer_supports_edges (line 343) | def layer_supports_edges(cls) -> bool: method layer_inputs_edges (line 355) | def layer_inputs_edges(self) -> bool: method layer_outputs_edges (line 370) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 385) | def out_dim_factor(self) -> int: FILE: graphium/nn/pyg_layers/pna_pyg.py class PNAMessagePassingPyg (line 30) | class PNAMessagePassingPyg(MessagePassing, BaseGraphStructure): method __init__ (line 31) | def __init__( method forward (line 167) | def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]: method message (line 182) | def message(self, x_i: Tensor, x_j: Tensor, edge_feat: OptTensor) -> T... method aggregate (line 202) | def aggregate( method layer_outputs_edges (line 263) | def layer_outputs_edges(self) -> bool: method out_dim_factor (line 278) | def out_dim_factor(self) -> int: method layer_inputs_edges (line 298) | def layer_inputs_edges(self) -> bool: method layer_supports_edges (line 313) | def layer_supports_edges(cls) -> bool: FILE: graphium/nn/pyg_layers/pooling_pyg.py function scatter_logsum_pool (line 30) | def scatter_logsum_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_... function scatter_std_pool (line 62) | def scatter_std_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_siz... class PoolingWrapperPyg (line 88) | class PoolingWrapperPyg(ModuleWrap): method __init__ (line 89) | def __init__(self, func, feat_type, *args, **kwargs) -> None: method forward (line 93) | def forward(self, g: Batch, feature: Tensor, *args, **kwargs): function parse_pooling_layer_pyg (line 112) | def parse_pooling_layer_pyg(in_dim: int, pooling: Union[str, List[str]],... class VirtualNodePyg (line 164) | class VirtualNodePyg(nn.Module): method __init__ (line 165) | def __init__( method forward (line 286) | def forward( FILE: graphium/nn/pyg_layers/utils.py class PreprocessPositions (line 27) | class PreprocessPositions(nn.Module): method __init__ (line 32) | def __init__( method forward (line 75) | def forward( class GaussianLayer (line 160) | class GaussianLayer(nn.Module): method __init__ (line 161) | def __init__(self, num_kernels=128, in_dim=3): method forward (line 175) | def forward(self, input: Tensor) -> Tensor: function triplets (line 190) | def triplets( FILE: graphium/nn/residual_connections.py class ResidualConnectionBase (line 27) | class ResidualConnectionBase(nn.Module): method __init__ (line 28) | def __init__(self, skip_steps: int = 1): method _bool_apply_skip_step (line 50) | def _bool_apply_skip_step(self, step_idx: int): method __repr__ (line 63) | def __repr__(self): method h_dim_increase_type (line 71) | def h_dim_increase_type(cls): method get_true_out_dims (line 90) | def get_true_out_dims(self, out_dims: List) -> List: method has_weights (line 127) | def has_weights(cls): class ResidualConnectionNone (line 138) | class ResidualConnectionNone(ResidualConnectionBase): method __init__ (line 144) | def __init__(self, skip_steps: int = 1): method __repr__ (line 147) | def __repr__(self): method h_dim_increase_type (line 154) | def h_dim_increase_type(cls): method has_weights (line 165) | def has_weights(cls): method forward (line 175) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): class ResidualConnectionSimple (line 191) | class ResidualConnectionSimple(ResidualConnectionBase): method __init__ (line 192) | def __init__(self, skip_steps: int = 1): method h_dim_increase_type (line 208) | def h_dim_increase_type(cls): method has_weights (line 219) | def has_weights(cls): method forward (line 229) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): class ResidualConnectionWeighted (line 265) | class ResidualConnectionWeighted(ResidualConnectionBase): method __init__ (line 266) | def __init__( method h_dim_increase_type (line 333) | def h_dim_increase_type(cls): method has_weights (line 343) | def has_weights(cls): method forward (line 353) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): method _bool_apply_skip_step (line 390) | def _bool_apply_skip_step(self, step_idx: int): class ResidualConnectionConcat (line 394) | class ResidualConnectionConcat(ResidualConnectionBase): method __init__ (line 395) | def __init__(self, skip_steps: int = 1): method h_dim_increase_type (line 412) | def h_dim_increase_type(cls): method has_weights (line 423) | def has_weights(cls): method forward (line 433) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): class ResidualConnectionDenseNet (line 471) | class ResidualConnectionDenseNet(ResidualConnectionBase): method __init__ (line 472) | def __init__(self, skip_steps: int = 1): method h_dim_increase_type (line 489) | def h_dim_increase_type(cls): method has_weights (line 500) | def has_weights(cls): method forward (line 510) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): class ResidualConnectionRandom (line 548) | class ResidualConnectionRandom(ResidualConnectionBase): method __init__ (line 549) | def __init__(self, skip_steps=1, out_dims: List[int] = None, num_layer... method h_dim_increase_type (line 580) | def h_dim_increase_type(cls): method has_weights (line 590) | def has_weights(cls): method forward (line 598) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int): FILE: graphium/nn/utils.py class MupMixin (line 20) | class MupMixin(abc.ABC): method make_mup_base_kwargs (line 22) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d... method scale_kwargs (line 38) | def scale_kwargs(self, scale_factor: Real, scale_in_dim: bool = False): FILE: graphium/trainer/losses.py class HybridCELoss (line 22) | class HybridCELoss(_WeightedLoss): method __init__ (line 23) | def __init__( method forward (line 68) | def forward(self, input: Tensor, target: Tensor, nan_targets: Tensor =... FILE: graphium/trainer/metrics.py class Thresholder (line 36) | class Thresholder: method __init__ (line 37) | def __init__( method compute (line 50) | def compute(self, preds: Tensor, target: Tensor): method __call__ (line 61) | def __call__(self, preds: Tensor, target: Tensor): method __repr__ (line 64) | def __repr__(self): method _get_operator (line 72) | def _get_operator(operator): method __getstate__ (line 97) | def __getstate__(self): method __setstate__ (line 115) | def __setstate__(self, state: dict): method __eq__ (line 120) | def __eq__(self, obj) -> bool: class MetricWrapper (line 131) | class MetricWrapper: method __init__ (line 137) | def __init__( method _parse_target_nan_mask (line 201) | def _parse_target_nan_mask(target_nan_mask): method _parse_multitask_handling (line 227) | def _parse_multitask_handling(multitask_handling, target_nan_mask): method _get_metric (line 256) | def _get_metric(metric): method compute (line 267) | def compute(self, preds: Tensor, target: Tensor) -> Tensor: method _filter_nans (line 343) | def _filter_nans(self, preds: Tensor, target: Tensor): method __call__ (line 359) | def __call__(self, preds: Tensor, target: Tensor) -> Tensor: method __repr__ (line 365) | def __repr__(self): method __eq__ (line 375) | def __eq__(self, obj) -> bool: method __getstate__ (line 388) | def __getstate__(self): method __setstate__ (line 405) | def __setstate__(self, state: dict): FILE: graphium/trainer/predictor.py class PredictorModule (line 42) | class PredictorModule(lightning.LightningModule): method __init__ (line 43) | def __init__( method forward (line 197) | def forward( method _convert_features_dtype (line 221) | def _convert_features_dtype(self, feats): method _get_task_key (line 231) | def _get_task_key(self, task_level: str, task: str): method configure_optimizers (line 237) | def configure_optimizers(self, impl=None): method compute_loss (line 254) | def compute_loss( method _general_step (line 327) | def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_c... method flag_step (line 402) | def flag_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: ... method on_train_batch_start (line 471) | def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional... method on_train_batch_end (line 478) | def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> N... method training_step (line 537) | def training_step(self, batch: Dict[str, Tensor], to_cpu: bool = True)... method get_gradient_norm (line 554) | def get_gradient_norm(self): method validation_step (line 564) | def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = Tru... method test_step (line 567) | def test_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> ... method _general_epoch_end (line 570) | def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, ... method on_train_epoch_start (line 607) | def on_train_epoch_start(self) -> None: method on_train_epoch_end (line 610) | def on_train_epoch_end(self) -> None: method on_validation_epoch_start (line 618) | def on_validation_epoch_start(self) -> None: method on_validation_batch_start (line 623) | def on_validation_batch_start(self, batch: Any, batch_idx: int, datalo... method on_validation_batch_end (line 627) | def on_validation_batch_end( method on_validation_epoch_end (line 637) | def on_validation_epoch_end(self) -> None: method on_test_batch_end (line 651) | def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, ... method on_test_epoch_end (line 654) | def on_test_epoch_end(self) -> None: method on_train_start (line 665) | def on_train_start(self): method get_progress_bar_dict (line 671) | def get_progress_bar_dict(self) -> Dict[str, float]: method __repr__ (line 682) | def __repr__(self) -> str: method list_pretrained_models (line 692) | def list_pretrained_models(): method load_pretrained_model (line 697) | def load_pretrained_model(name_or_path: str, device: str = None): method set_max_nodes_edges_per_graph (line 720) | def set_max_nodes_edges_per_graph(self, datamodule: BaseDataModule, st... method get_num_graphs (line 728) | def get_num_graphs(self, data: Batch): FILE: graphium/trainer/predictor_options.py class ModelOptions (line 37) | class ModelOptions: class OptimOptions (line 54) | class OptimOptions: method set_kwargs (line 93) | def set_kwargs(self): class EvalOptions (line 137) | class EvalOptions: method check_metrics_validity (line 167) | def check_metrics_validity(self): method parse_loss_fun (line 185) | def parse_loss_fun(loss_fun: Union[str, Dict, Callable]) -> Callable: class FlagOptions (line 226) | class FlagOptions: method set_kwargs (line 245) | def set_kwargs(self): FILE: graphium/trainer/predictor_summaries.py class SummaryInterface (line 26) | class SummaryInterface(object): method set_results (line 31) | def set_results(self, **kwargs): method get_dict_summary (line 34) | def get_dict_summary(self): method update_predictor_state (line 37) | def update_predictor_state(self, **kwargs): method get_metrics_logs (line 40) | def get_metrics_logs(self, **kwargs): class Summary (line 44) | class Summary(SummaryInterface): method __init__ (line 46) | def __init__( method update_predictor_state (line 103) | def update_predictor_state( method set_results (line 121) | def set_results( method is_best_epoch (line 145) | def is_best_epoch(self, step_name: str, loss: Tensor, metrics: Dict[st... method get_results (line 173) | def get_results( method get_best_results (line 186) | def get_best_results( method get_results_on_progress_bar (line 199) | def get_results_on_progress_bar( method get_dict_summary (line 220) | def get_dict_summary(self) -> Dict[str, Any]: method get_metrics_logs (line 241) | def get_metrics_logs(self) -> Dict[str, Any]: method metric_log_name (line 295) | def metric_log_name(self, task_name, metric_name, step_name): class Results (line 301) | class Results: method __init__ (line 302) | def __init__( class TaskSummaries (line 334) | class TaskSummaries(SummaryInterface): method __init__ (line 335) | def __init__( method update_predictor_state (line 380) | def update_predictor_state( method set_results (line 410) | def set_results(self, task_metrics: Dict[str, Dict[str, Tensor]]): method get_results (line 425) | def get_results( method get_best_results (line 441) | def get_best_results( method get_results_on_progress_bar (line 457) | def get_results_on_progress_bar( method get_dict_summary (line 475) | def get_dict_summary( method get_metrics_logs (line 488) | def get_metrics_logs( method concatenate_metrics_logs (line 513) | def concatenate_metrics_logs( method metric_log_name (line 530) | def metric_log_name( FILE: graphium/utils/arg_checker.py function _parse_type (line 32) | def _parse_type(type_to_validate, accepted_types): function _enforce_iter_type (line 50) | def _enforce_iter_type(arg, enforce_type): function check_arg_iterator (line 62) | def check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cas... function check_list1_in_list2 (line 149) | def check_list1_in_list2(list1, list2, throw_error=True): function check_columns_choice (line 184) | def check_columns_choice(dataframe, columns_choice, extra_accepted_cols=... FILE: graphium/utils/command_line_utils.py function get_anchors_and_aliases (line 19) | def get_anchors_and_aliases(filepath): function update_config (line 67) | def update_config(cfg: Dict, unknown: List, anchors: List): FILE: graphium/utils/custom_lr.py class WarmUpLinearLR (line 18) | class WarmUpLinearLR(_LRScheduler): method __init__ (line 31) | def __init__(self, optimizer, max_num_epochs, warmup_epochs=0, min_lr=... method get_lr (line 38) | def get_lr(self): method _get_closed_form_lr (line 55) | def _get_closed_form_lr(self): FILE: graphium/utils/decorators.py class classproperty (line 14) | class classproperty(property): method __get__ (line 28) | def __get__(self, cls, owner): FILE: graphium/utils/fs.py function get_cache_dir (line 26) | def get_cache_dir(suffix: str = None, create: bool = True) -> pathlib.Path: function get_mapper (line 42) | def get_mapper(path: Union[str, os.PathLike]): function get_basename (line 50) | def get_basename(path: Union[str, os.PathLike]): function get_extension (line 61) | def get_extension(path: Union[str, os.PathLike]): function exists (line 70) | def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]): function exists_and_not_empty (line 88) | def exists_and_not_empty(path: Union[str, os.PathLike]): function mkdir (line 99) | def mkdir(path: Union[str, os.PathLike], exist_ok: bool = True): function rm (line 105) | def rm(path: Union[str, os.PathLike], recursive=False, maxdepth=None): function join (line 111) | def join(*paths): function get_size (line 124) | def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFi... function copy (line 146) | def copy( FILE: graphium/utils/hashing.py function get_md5_hash (line 19) | def get_md5_hash(object: Any) -> str: FILE: graphium/utils/moving_average_tracker.py class MovingAverageTracker (line 18) | class MovingAverageTracker: method update (line 22) | def update(self, value: float): method reset (line 28) | def reset(self): FILE: graphium/utils/mup.py function apply_infshapes (line 24) | def apply_infshapes(model, infshapes): function set_base_shapes (line 36) | def set_base_shapes(model, base, rescale_params=True, delta=None, savefi... FILE: graphium/utils/packing.py class MolPack (line 19) | class MolPack: method __init__ (line 26) | def __init__(self): method add_mol (line 32) | def add_mol(self, num_nodes: int, idx: int) -> "MolPack": method expected_atoms (line 47) | def expected_atoms(self, remaining_mean_num_nodes: float, batch_size: ... method __repr__ (line 64) | def __repr__(self) -> str: function smart_packing (line 71) | def smart_packing(num_nodes: List[int], batch_size: int) -> List[List[in... function fast_packing (line 128) | def fast_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]: function hybrid_packing (line 161) | def hybrid_packing(num_nodes: List[int], batch_size: int) -> List[List[i... function get_pack_sizes (line 227) | def get_pack_sizes(packed_indices, num_nodes): function estimate_max_pack_node_size (line 240) | def estimate_max_pack_node_size(num_nodes: Iterable[int], batch_size: in... function node_to_pack_indices_mask (line 268) | def node_to_pack_indices_mask( FILE: graphium/utils/safe_run.py class SafeRun (line 18) | class SafeRun: method __init__ (line 19) | def __init__(self, name: str, raise_error: bool = True, verbose: int =... method __enter__ (line 42) | def __enter__(self): method __exit__ (line 49) | def __exit__(self, type, value, traceback): FILE: graphium/utils/tensor.py function save_im (line 29) | def save_im(im_dir, im_name: str, ext: List[str] = ["svg", "png"], dpi: ... function is_dtype_torch_tensor (line 43) | def is_dtype_torch_tensor(dtype: Union[np.dtype, torch.dtype]) -> bool: function is_dtype_numpy_array (line 57) | def is_dtype_numpy_array(dtype: Union[np.dtype, torch.dtype]) -> bool: function one_of_k_encoding (line 78) | def one_of_k_encoding(val: Any, classes: Iterable[Any]) -> List[int]: function is_device_cuda (line 103) | def is_device_cuda(device: torch.device, ignore_errors: bool = False) ->... function nan_mean (line 127) | def nan_mean(input: Tensor, *args, **kwargs) -> Tensor: function nan_median (line 154) | def nan_median(input: Tensor, **kwargs) -> Tensor: function nan_var (line 202) | def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: function nan_std (line 242) | def nan_std(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor: function nan_mad (line 270) | def nan_mad(input: Tensor, normal: bool = True, **kwargs) -> Tensor: class ModuleWrap (line 304) | class ModuleWrap(torch.nn.Module): method __init__ (line 312) | def __init__(self, func, *args, **kwargs) -> None: method forward (line 319) | def forward(self, *args, **kwargs): method __repr__ (line 325) | def __repr__(self): class ModuleListConcat (line 329) | class ModuleListConcat(torch.nn.ModuleList): method __init__ (line 339) | def __init__(self, dim: int = -1): method forward (line 343) | def forward(self, *args, **kwargs) -> Tensor: function parse_valid_args (line 355) | def parse_valid_args(param_dict, fn): function arg_in_func (line 382) | def arg_in_func(fn, arg): function tensor_fp16_to_fp32 (line 402) | def tensor_fp16_to_fp32(tensor: Tensor) -> Tensor: function dict_tensor_fp16_to_fp32 (line 416) | def dict_tensor_fp16_to_fp32( FILE: profiling/profile_mol_to_graph.py function main (line 23) | def main(): FILE: profiling/profile_one_of_k_encoding.py function main (line 18) | def main(): FILE: profiling/profile_predictor.py function main (line 29) | def main(): FILE: tests/conftest.py function datadir (line 22) | def datadir(request): FILE: tests/test_architectures.py class test_FeedForwardNN (line 40) | class test_FeedForwardNN(ut.TestCase): method test_forward_no_residual (line 52) | def test_forward_no_residual(self): method test_forward_simple_residual_1 (line 78) | def test_forward_simple_residual_1(self): method test_forward_norms (line 104) | def test_forward_norms(self): method test_forward_simple_residual_2 (line 134) | def test_forward_simple_residual_2(self): method test_forward_concat_residual_1 (line 162) | def test_forward_concat_residual_1(self): method test_forward_concat_residual_2 (line 190) | def test_forward_concat_residual_2(self): method test_forward_densenet_residual_1 (line 218) | def test_forward_densenet_residual_1(self): method test_forward_densenet_residual_2 (line 246) | def test_forward_densenet_residual_2(self): method test_forward_weighted_residual_1 (line 274) | def test_forward_weighted_residual_1(self): method test_forward_weighted_residual_2 (line 304) | def test_forward_weighted_residual_2(self): class test_FeedForwardGraph (line 335) | class test_FeedForwardGraph(ut.TestCase): method test_forward_no_residual (line 379) | def test_forward_no_residual(self): method test_forward_simple_residual (line 437) | def test_forward_simple_residual(self): method test_forward_weighted_residual (line 494) | def test_forward_weighted_residual(self): method test_forward_concat_residual (line 551) | def test_forward_concat_residual(self): method test_forward_densenet_residual (line 609) | def test_forward_densenet_residual(self): FILE: tests/test_attention.py function seed_everything (line 29) | def seed_everything(seed: int): class test_MultiHeadAttention (line 40) | class test_MultiHeadAttention(ut.TestCase): method test_attention_class (line 58) | def test_attention_class(self): FILE: tests/test_base_layers.py class test_Base_Layers (line 27) | class test_Base_Layers(ut.TestCase): method test_droppath_layer_0p5 (line 48) | def test_droppath_layer_0p5(self): method test_droppath_layer_1p0 (line 56) | def test_droppath_layer_1p0(self): method test_droppath_layer_0p0 (line 65) | def test_droppath_layer_0p0(self): method test_transformer_encoder_layer_mup (line 73) | def test_transformer_encoder_layer_mup(self): FILE: tests/test_collate.py class test_Collate (line 27) | class test_Collate(ut.TestCase): method test_collate_labels (line 28) | def test_collate_labels(self): FILE: tests/test_data_utils.py class TestDataUtils (line 20) | class TestDataUtils(ut.TestCase): method test_list_datasets (line 21) | def test_list_datasets( method test_download_datasets (line 28) | def test_download_datasets(self): FILE: tests/test_datamodule.py class Test_DataModule (line 27) | class Test_DataModule(ut.TestCase): method test_ogb_datamodule (line 28) | def test_ogb_datamodule(self): method test_none_filtering (line 90) | def test_none_filtering(self): method test_caching (line 181) | def test_caching(self): method test_datamodule_with_none_molecules (line 336) | def test_datamodule_with_none_molecules(self): method test_datamodule_multiple_data_files (line 444) | def test_datamodule_multiple_data_files(self): method test_splits_file (line 497) | def test_splits_file(self): FILE: tests/test_dataset.py class Test_Multitask_Dataset (line 22) | class Test_Multitask_Dataset(ut.TestCase): method test_multitask_dataset_case_1 (line 29) | def test_multitask_dataset_case_1(self): method test_multitask_dataset_case_2 (line 89) | def test_multitask_dataset_case_2(self): method test_multitask_dataset_case_3 (line 168) | def test_multitask_dataset_case_3(self): FILE: tests/test_ensemble_layers.py class test_Ensemble_Layers (line 33) | class test_Ensemble_Layers(ut.TestCase): method check_ensemble_linear (line 34) | def check_ensemble_linear( method test_ensemble_linear (line 99) | def test_ensemble_linear(self): method test_ensemble_mureadout_graphium (line 115) | def test_ensemble_mureadout_graphium(self): method check_ensemble_fclayer (line 150) | def check_ensemble_fclayer( method test_ensemble_fclayer (line 209) | def test_ensemble_fclayer(self): method check_ensemble_mlp (line 236) | def check_ensemble_mlp( method test_ensemble_mlp (line 302) | def test_ensemble_mlp(self): method check_ensemble_feedforwardnn (line 329) | def check_ensemble_feedforwardnn( method check_ensemble_feedforwardnn_mean (line 404) | def check_ensemble_feedforwardnn_mean( method check_ensemble_feedforwardnn_simple (line 476) | def check_ensemble_feedforwardnn_simple( method test_ensemble_feedforwardnn (line 507) | def test_ensemble_feedforwardnn(self): FILE: tests/test_featurizer.py class test_featurizer (line 33) | class test_featurizer(ut.TestCase): method test_get_mol_atomic_features_onehot (line 99) | def test_get_mol_atomic_features_onehot(self): method test_get_mol_atomic_features_float (line 123) | def test_get_mol_atomic_features_float(self): method test_get_mol_atomic_features_float_nan_mask (line 145) | def test_get_mol_atomic_features_float_nan_mask(self): method test_get_mol_edge_features (line 178) | def test_get_mol_edge_features(self): method test_mol_to_adj_and_features (line 199) | def test_mol_to_adj_and_features(self): method test_mol_to_pyggraph (line 236) | def test_mol_to_pyggraph(self): FILE: tests/test_finetuning.py class Test_Finetuning (line 42) | class Test_Finetuning(ut.TestCase): method test_finetuning_from_task_head (line 43) | def test_finetuning_from_task_head(self): method test_finetuning_from_gnn (line 232) | def test_finetuning_from_gnn(self): FILE: tests/test_ipu_dataloader.py function random_packing (line 31) | def random_packing(num_nodes, batch_size): function global_batch_collator (line 39) | def global_batch_collator(batch_size, batches): class test_DataLoading (line 49) | class test_DataLoading(ut.TestCase): class TestSimpleLightning (line 50) | class TestSimpleLightning(LightningModule): method __init__ (line 52) | def __init__(self, batch_size, node_feat_size, edge_feat_size, num_b... method validation_step (line 61) | def validation_step(self, batch, batch_idx): method training_step (line 66) | def training_step(self, batch, batch_idx): method forward (line 71) | def forward(self, batch): method assert_shapes (line 76) | def assert_shapes(self, batch, batch_idx, step): method configure_optimizers (line 98) | def configure_optimizers(self): class TestDataset (line 101) | class TestDataset(torch.utils.data.Dataset): method __init__ (line 103) | def __init__(self, labels, node_features, edge_features): method __len__ (line 108) | def __len__(self): method __getitem__ (line 111) | def __getitem__(self, idx): method test_poptorch_simple_deviceiterations_gradient_accumulation (line 116) | def test_poptorch_simple_deviceiterations_gradient_accumulation(self): method test_poptorch_graphium_deviceiterations_gradient_accumulation_full (line 193) | def test_poptorch_graphium_deviceiterations_gradient_accumulation_full... FILE: tests/test_ipu_losses.py class test_Losses (line 25) | class test_Losses(ut.TestCase): method test_bce (line 40) | def test_bce(self): method test_mse (line 81) | def test_mse(self): method test_l1 (line 104) | def test_l1(self): method test_bce_logits (line 127) | def test_bce_logits(self): FILE: tests/test_ipu_metrics.py class test_Metrics (line 50) | class test_Metrics(ut.TestCase): method test_auroc (line 65) | def test_auroc(self): method test_average_precision (line 116) | def test_average_precision(self): # TODO: Make work with multi-class method test_precision (line 148) | def test_precision(self): method test_accuracy (line 254) | def test_accuracy(self): method test_recall (line 341) | def test_recall(self): method test_pearsonr (line 429) | def test_pearsonr(self): method test_spearmanr (line 451) | def test_spearmanr(self): method test_r2_score (line 473) | def test_r2_score(self): method test_fbeta_score (line 498) | def test_fbeta_score(self): method test_f1_score (line 602) | def test_f1_score(self): method test_mse (line 695) | def test_mse(self): method test_mae (line 745) | def test_mae(self): FILE: tests/test_ipu_options.py function test_ipu_options (line 65) | def test_ipu_options(): function test_ipu_options_list_to_file (line 122) | def test_ipu_options_list_to_file(): FILE: tests/test_ipu_poptorch.py function test_poptorch (line 18) | def test_poptorch(): FILE: tests/test_ipu_to_dense_batch.py class TestIPUBatch (line 39) | class TestIPUBatch: method setup_class (line 41) | def setup_class(self): method test_ipu_to_dense_batch (line 63) | def test_ipu_to_dense_batch(self, max_num_nodes_per_graph, batch_size): method test_ipu_to_dense_batch_no_batch_no_max_nodes (line 106) | def test_ipu_to_dense_batch_no_batch_no_max_nodes(self): method test_ipu_to_dense_batch_no_batch (line 119) | def test_ipu_to_dense_batch_no_batch(self): method test_ipu_to_dense_batch_drop_last (line 133) | def test_ipu_to_dense_batch_drop_last(self): FILE: tests/test_loaders.py class TestLoader (line 19) | class TestLoader(ut.TestCase): method test_merge_dicts (line 20) | def test_merge_dicts(self): FILE: tests/test_losses.py function _parse (line 26) | def _parse(loss_fun): class test_HybridCELoss (line 31) | class test_HybridCELoss(ut.TestCase): method test_pure_ce_loss (line 38) | def test_pure_ce_loss(self): method test_pure_mae_loss (line 46) | def test_pure_mae_loss(self): method test_pure_mse_loss (line 61) | def test_pure_mse_loss(self): method test_hybrid_loss (line 77) | def test_hybrid_loss(self): method test_loss_parser (line 88) | def test_loss_parser(self): class test_BCELoss (line 112) | class test_BCELoss(ut.TestCase): method test_loss_parser (line 113) | def test_loss_parser(self): FILE: tests/test_metrics.py class test_Metrics (line 32) | class test_Metrics(ut.TestCase): method test_thresholder (line 33) | def test_thresholder(self): class test_MetricWrapper (line 75) | class test_MetricWrapper(ut.TestCase): method test_target_nan_mask (line 76) | def test_target_nan_mask(self): method test_pickling (line 142) | def test_pickling(self): method test_classifigression_target_squeezing (line 206) | def test_classifigression_target_squeezing(self): FILE: tests/test_mtl_architecture.py function toy_test_data (line 133) | def toy_test_data(in_dim=7, in_dim_edges=3): class test_GraphOutputNN (line 153) | class test_GraphOutputNN(ut.TestCase): method generate_test_data (line 154) | def generate_test_data(self): method test_nodepair_max_num_nodes_not_set (line 207) | def test_nodepair_max_num_nodes_not_set(self): method test_nodepair_with_max_num_nodes (line 231) | def test_nodepair_with_max_num_nodes(self): class test_TaskHeads (line 257) | class test_TaskHeads(ut.TestCase): method test_task_heads_forward (line 258) | def test_task_heads_forward(self): method test_task_heads_non_supported_level (line 342) | def test_task_heads_non_supported_level(self): class test_Multitask_NN (line 370) | class test_Multitask_NN(ut.TestCase): method test_full_graph_multitask_forward (line 396) | def test_full_graph_multitask_forward(self): method test_full_graph_multi_task_from_config (line 559) | def test_full_graph_multi_task_from_config(self): method test_full_graph_multi_task_set_max_num_nodes (line 591) | def test_full_graph_multi_task_set_max_num_nodes(self): FILE: tests/test_multitask_datamodule.py class Test_Multitask_DataModule (line 25) | class Test_Multitask_DataModule(ut.TestCase): method setUp (line 26) | def setUp(self): method tearDown (line 30) | def tearDown(self): method test_multitask_fromsmiles_dm (line 34) | def test_multitask_fromsmiles_dm( method test_multitask_fromsmiles_from_config (line 148) | def test_multitask_fromsmiles_from_config(self): method test_multitask_fromsmiles_from_config_csv (line 203) | def test_multitask_fromsmiles_from_config_csv(self): method test_multitask_fromsmiles_from_config_parquet (line 230) | def test_multitask_fromsmiles_from_config_parquet(self): method test_multitask_with_missing_fromsmiles_from_config_parquet (line 258) | def test_multitask_with_missing_fromsmiles_from_config_parquet(self): method test_extract_graph_level_singletask (line 286) | def test_extract_graph_level_singletask(self): method test_extract_graph_level_multitask (line 297) | def test_extract_graph_level_multitask(self): method test_extract_graph_level_multitask_missing_cols (line 308) | def test_extract_graph_level_multitask_missing_cols(self): method test_non_graph_level_extract_labels (line 325) | def test_non_graph_level_extract_labels(self): method test_non_graph_level_extract_labels_missing_cols (line 336) | def test_non_graph_level_extract_labels_missing_cols(self): method test_tdc_admet_benchmark_data_module (line 357) | def test_tdc_admet_benchmark_data_module(self): FILE: tests/test_mup.py function get_pyg_graphs (line 29) | def get_pyg_graphs(in_dim, in_dim_edges): class test_mup (line 43) | class test_mup(ut.TestCase): method test_feedforwardnn_mup (line 63) | def test_feedforwardnn_mup(self): method test_feedforwardgraph_mup (line 106) | def test_feedforwardgraph_mup(self): method test_fullgraphmultitasknetwork (line 151) | def test_fullgraphmultitasknetwork(self): FILE: tests/test_packing.py function random_packing (line 31) | def random_packing(num_nodes, batch_size): class test_Packing (line 39) | class test_Packing(ut.TestCase): method test_smart_packing (line 40) | def test_smart_packing(self): method test_fast_packing (line 79) | def test_fast_packing(self): method test_hybrid_packing (line 119) | def test_hybrid_packing(self): method test_node_to_pack_indices_mask (line 158) | def test_node_to_pack_indices_mask(self): FILE: tests/test_pe_nodepair.py class test_positional_encodings (line 27) | class test_positional_encodings(ut.TestCase): method test_dimensions (line 63) | def test_dimensions(self): method test_symmetry (line 74) | def test_symmetry(self): method test_max_dist (line 82) | def test_max_dist(self): FILE: tests/test_pe_rw.py class test_pe_spectral (line 25) | class test_pe_spectral(ut.TestCase): method test_caching_and_outputs (line 26) | def test_caching_and_outputs(self): FILE: tests/test_pe_spectral.py class test_pe_spectral (line 25) | class test_pe_spectral(ut.TestCase): method test_for_connected_vs_disconnected_graph (line 35) | def test_for_connected_vs_disconnected_graph(self): FILE: tests/test_pos_transfer_funcs.py class test_pos_transfer_funcs (line 33) | class test_pos_transfer_funcs(ut.TestCase): method test_different_pathways_from_node_to_edge (line 40) | def test_different_pathways_from_node_to_edge(self): FILE: tests/test_positional_encoders.py class test_positional_encoder (line 33) | class test_positional_encoder(ut.TestCase): method test_laplacian_eigvec_eigval (line 45) | def test_laplacian_eigvec_eigval(self): method test_rwse (line 89) | def test_rwse(self): method test_laplacian_eigvec_with_encoder (line 106) | def test_laplacian_eigvec_with_encoder(self): FILE: tests/test_positional_encodings.py class test_positional_encodings (line 29) | class test_positional_encodings(ut.TestCase): method test_dimensions (line 65) | def test_dimensions(self): method test_symmetry (line 76) | def test_symmetry(self): method test_max_dist (line 84) | def test_max_dist(self): FILE: tests/test_predictor.py class test_Predictor (line 25) | class test_Predictor(ut.TestCase): method test_parse_loss_fun (line 26) | def test_parse_loss_fun(self): FILE: tests/test_pyg_layers.py class test_Pyg_Layers (line 44) | class test_Pyg_Layers(ut.TestCase): method test_gpslayer (line 71) | def test_gpslayer(self): method test_ginlayer (line 98) | def test_ginlayer(self): method test_ginelayer (line 114) | def test_ginelayer(self): method test_mpnnlayer (line 135) | def test_mpnnlayer(self): method test_gatedgcnlayer (line 161) | def test_gatedgcnlayer(self): method test_pnamessagepassinglayer (line 187) | def test_pnamessagepassinglayer(self): method test_dimenetlayer (line 234) | def test_dimenetlayer(self): method test_preprocess3Dfeaturelayer (line 297) | def test_preprocess3Dfeaturelayer(self): method test_gaussianlayer (line 319) | def test_gaussianlayer(self): method test_pooling_virtual_node (line 327) | def test_pooling_virtual_node(self): FILE: tests/test_residual_connections.py class test_ResidualConnectionNone (line 31) | class test_ResidualConnectionNone(ut.TestCase): method test_get_true_out_dims_none (line 32) | def test_get_true_out_dims_none(self): method test_forward_none (line 41) | def test_forward_none(self): class test_ResidualConnectionSimple (line 54) | class test_ResidualConnectionSimple(ut.TestCase): method test_get_true_out_dims_simple (line 55) | def test_get_true_out_dims_simple(self): method test_forward_simple (line 64) | def test_forward_simple(self): class test_ResidualConnectionRandom (line 93) | class test_ResidualConnectionRandom(ut.TestCase): method test_get_true_out_dims_random (line 94) | def test_get_true_out_dims_random(self): method test_forward_random (line 112) | def test_forward_random(self): class test_ResidualConnectionWeighted (line 132) | class test_ResidualConnectionWeighted(ut.TestCase): method test_get_true_out_dims_weighted (line 133) | def test_get_true_out_dims_weighted(self): method test_forward_weighted (line 142) | def test_forward_weighted(self): class test_ResidualConnectionConcat (line 184) | class test_ResidualConnectionConcat(ut.TestCase): method test_get_true_out_dims_concat (line 185) | def test_get_true_out_dims_concat(self): method test_forward_concat (line 207) | def test_forward_concat(self): class test_ResidualConnectionDenseNet (line 236) | class test_ResidualConnectionDenseNet(ut.TestCase): method test_get_true_out_dims_densenet (line 237) | def test_get_true_out_dims_densenet(self): method test_forward_densenet (line 259) | def test_forward_densenet(self): FILE: tests/test_training.py class TestCLITraining (line 22) | class TestCLITraining: method setup_class (line 24) | def setup_class(cls): method call_cli_with_overrides (line 51) | def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_t... method test_cpu_cli_training (line 96) | def test_cpu_cli_training(self, load_type): method test_ipu_cli_training (line 102) | def test_ipu_cli_training(self, load_type): FILE: tests/test_utils.py class test_nan_statistics (line 37) | class test_nan_statistics(ut.TestCase): method test_nan_mean (line 57) | def test_nan_mean(self): method test_nan_std_var (line 77) | def test_nan_std_var(self): method test_nan_median (line 103) | def test_nan_median(self): method test_nan_mad (line 128) | def test_nan_mad(self): class test_SafeRun (line 152) | class test_SafeRun(ut.TestCase): method test_safe_run (line 153) | def test_safe_run(self): class TestTensorFp16ToFp32 (line 181) | class TestTensorFp16ToFp32(ut.TestCase): method test_tensor_fp16_to_fp32 (line 182) | def test_tensor_fp16_to_fp32(self): method test_dict_tensor_fp16_to_fp32 (line 200) | def test_dict_tensor_fp16_to_fp32(self):