Full Code of datamol-io/graphium for AI

main f1c180938718 cached
369 files
4.8 MB
1.3M tokens
1124 symbols
1 requests
Download .txt
Showing preview only (5,101K chars total). Download the full file or copy to clipboard to get everything.
Repository: datamol-io/graphium
Branch: main
Commit: f1c180938718
Files: 369
Total size: 4.8 MB

Directory structure:
gitextract_455isr46/

├── .github/
│   ├── CODEOWNERS
│   ├── CODE_OF_CONDUCT.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows/
│       ├── code-check.yml
│       ├── doc.yml
│       ├── release.yml
│       ├── test.yml
│       └── test_ipu.yml
├── .gitignore
├── LICENSE
├── README.md
├── cleanup_files.sh
├── codecov.yml
├── docs/
│   ├── _assets/
│   │   ├── css/
│   │   │   ├── custom-graphium.css
│   │   │   └── custom.css
│   │   └── js/
│   │       └── google-analytics.js
│   ├── api/
│   │   ├── graphium.config.md
│   │   ├── graphium.data.md
│   │   ├── graphium.features.md
│   │   ├── graphium.finetuning.md
│   │   ├── graphium.ipu.md
│   │   ├── graphium.nn/
│   │   │   ├── architectures.md
│   │   │   ├── encoders.md
│   │   │   ├── graphium.nn.md
│   │   │   └── pyg_layers.md
│   │   ├── graphium.trainer.md
│   │   └── graphium.utils.md
│   ├── baseline.md
│   ├── cli/
│   │   ├── graphium-train.md
│   │   ├── graphium.md
│   │   └── reference.md
│   ├── contribute.md
│   ├── datasets.md
│   ├── design.md
│   ├── index.md
│   ├── license.md
│   ├── pretrained_models.md
│   └── tutorials/
│       ├── feature_processing/
│       │   ├── add_new_positional_encoding.ipynb
│       │   ├── choosing_parallelization.ipynb
│       │   ├── csv_to_parquet.ipynb
│       │   └── timing_parallel.ipynb
│       ├── gnn/
│       │   ├── add_new_gnn_layers.ipynb
│       │   ├── making_gnn_networks.ipynb
│       │   └── using_gnn_layers.ipynb
│       └── model_training/
│           └── simple-molecular-model.ipynb
├── enable_ipu.sh
├── env.yml
├── expts/
│   ├── __init__.py
│   ├── configs/
│   │   ├── config_gps_10M_pcqm4m.yaml
│   │   ├── config_gps_10M_pcqm4m_mod.yaml
│   │   ├── config_mpnn_10M_b3lyp.yaml
│   │   └── config_mpnn_pcqm4m.yaml
│   ├── data/
│   │   ├── micro_zinc_splits.csv
│   │   └── tiny_zinc_splits.csv
│   ├── dataset_benchmark.py
│   ├── debug_yaml.py
│   ├── hydra-configs/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── accelerator/
│   │   │   ├── cpu.yaml
│   │   │   ├── gpu.yaml
│   │   │   ├── ipu.yaml
│   │   │   └── ipu_pipeline.yaml
│   │   ├── architecture/
│   │   │   ├── largemix.yaml
│   │   │   ├── pcqm4m.yaml
│   │   │   └── toymix.yaml
│   │   ├── experiment/
│   │   │   └── toymix_mpnn.yaml
│   │   ├── finetuning/
│   │   │   ├── admet.yaml
│   │   │   └── admet_baseline.yaml
│   │   ├── hparam_search/
│   │   │   └── optuna.yaml
│   │   ├── main.yaml
│   │   ├── model/
│   │   │   ├── gated_gcn.yaml
│   │   │   ├── gcn.yaml
│   │   │   ├── gin.yaml
│   │   │   ├── gine.yaml
│   │   │   ├── gpspp.yaml
│   │   │   └── mpnn.yaml
│   │   ├── tasks/
│   │   │   ├── admet.yaml
│   │   │   ├── l1000_mcf7.yaml
│   │   │   ├── l1000_vcap.yaml
│   │   │   ├── largemix.yaml
│   │   │   ├── loss_metrics_datamodule/
│   │   │   │   ├── admet.yaml
│   │   │   │   ├── l1000_mcf7.yaml
│   │   │   │   ├── l1000_vcap.yaml
│   │   │   │   ├── largemix.yaml
│   │   │   │   ├── pcba_1328.yaml
│   │   │   │   ├── pcqm4m.yaml
│   │   │   │   ├── pcqm4m_g25.yaml
│   │   │   │   ├── pcqm4m_n4.yaml
│   │   │   │   └── toymix.yaml
│   │   │   ├── pcba_1328.yaml
│   │   │   ├── pcqm4m.yaml
│   │   │   ├── pcqm4m_g25.yaml
│   │   │   ├── pcqm4m_n4.yaml
│   │   │   ├── task_heads/
│   │   │   │   ├── admet.yaml
│   │   │   │   ├── l1000_mcf7.yaml
│   │   │   │   ├── l1000_vcap.yaml
│   │   │   │   ├── largemix.yaml
│   │   │   │   ├── pcba_1328.yaml
│   │   │   │   ├── pcqm4m.yaml
│   │   │   │   ├── pcqm4m_g25.yaml
│   │   │   │   ├── pcqm4m_n4.yaml
│   │   │   │   └── toymix.yaml
│   │   │   └── toymix.yaml
│   │   └── training/
│   │       ├── accelerator/
│   │       │   ├── largemix_cpu.yaml
│   │       │   ├── largemix_gpu.yaml
│   │       │   ├── largemix_ipu.yaml
│   │       │   ├── pcqm4m_ipu.yaml
│   │       │   ├── toymix_cpu.yaml
│   │       │   ├── toymix_gpu.yaml
│   │       │   └── toymix_ipu.yaml
│   │       ├── largemix.yaml
│   │       ├── model/
│   │       │   ├── largemix_gated_gcn.yaml
│   │       │   ├── largemix_gcn.yaml
│   │       │   ├── largemix_gin.yaml
│   │       │   ├── largemix_gine.yaml
│   │       │   ├── largemix_mpnn.yaml
│   │       │   ├── pcqm4m_gpspp.yaml
│   │       │   ├── pcqm4m_mpnn.yaml
│   │       │   ├── toymix_gcn.yaml
│   │       │   └── toymix_gin.yaml
│   │       ├── pcqm4m.yaml
│   │       └── toymix.yaml
│   ├── main_run_get_fingerprints.py
│   ├── main_run_multitask.py
│   ├── main_run_predict.py
│   ├── main_run_test.py
│   ├── neurips2023_configs/
│   │   ├── base_config/
│   │   │   ├── large.yaml
│   │   │   ├── large_pcba.yaml
│   │   │   ├── large_pcqm_g25.yaml
│   │   │   ├── large_pcqm_n4.yaml
│   │   │   └── small.yaml
│   │   ├── baseline/
│   │   │   ├── config_small_gcn_baseline.yaml
│   │   │   ├── config_small_gin_baseline.yaml
│   │   │   └── config_small_gine_baseline.yaml
│   │   ├── config_classifigression_l1000.yaml
│   │   ├── config_large_gcn.yaml
│   │   ├── config_large_gcn_g25.yaml
│   │   ├── config_large_gcn_gpu.yaml
│   │   ├── config_large_gcn_n4.yaml
│   │   ├── config_large_gcn_pcba.yaml
│   │   ├── config_large_gin.yaml
│   │   ├── config_large_gin_g25.yaml
│   │   ├── config_large_gin_n4.yaml
│   │   ├── config_large_gin_pcba.yaml
│   │   ├── config_large_gine.yaml
│   │   ├── config_large_gine_g25.yaml
│   │   ├── config_large_gine_n4.yaml
│   │   ├── config_large_gine_pcba.yaml
│   │   ├── config_large_mpnn.yaml
│   │   ├── config_luis_jama.yaml
│   │   ├── config_small_gated_gcn.yaml
│   │   ├── config_small_gcn.yaml
│   │   ├── config_small_gcn_gpu.yaml
│   │   ├── config_small_gin.yaml
│   │   ├── config_small_gine.yaml
│   │   ├── config_small_mpnn.yaml
│   │   ├── single_task_gcn/
│   │   │   ├── config_large_gcn_mcf7.yaml
│   │   │   ├── config_large_gcn_pcba.yaml
│   │   │   └── config_large_gcn_vcap.yaml
│   │   ├── single_task_gin/
│   │   │   ├── config_large_gin_g25.yaml
│   │   │   ├── config_large_gin_mcf7.yaml
│   │   │   ├── config_large_gin_n4.yaml
│   │   │   ├── config_large_gin_pcba.yaml
│   │   │   ├── config_large_gin_pcq.yaml
│   │   │   └── config_large_gin_vcap.yaml
│   │   └── single_task_gine/
│   │       ├── config_large_gine_g25.yaml
│   │       ├── config_large_gine_mcf7.yaml
│   │       ├── config_large_gine_n4.yaml
│   │       ├── config_large_gine_pcba.yaml
│   │       ├── config_large_gine_pcq.yaml
│   │       └── config_large_gine_vcap.yaml
│   └── run_validation_test.py
├── graphium/
│   ├── __init__.py
│   ├── _version.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── data.py
│   │   ├── finetune_utils.py
│   │   ├── fingerprints.py
│   │   ├── main.py
│   │   ├── parameters.py
│   │   └── train_finetune_test.py
│   ├── config/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _load.py
│   │   ├── _loader.py
│   │   ├── config_convert.py
│   │   ├── dummy_finetuning_from_gnn.yaml
│   │   ├── dummy_finetuning_from_task_head.yaml
│   │   ├── fake_and_missing_multilevel_multitask_pyg.yaml
│   │   ├── fake_multilevel_multitask_pyg.yaml
│   │   └── zinc_default_multitask_pyg.yaml
│   ├── data/
│   │   ├── L1000/
│   │   │   └── datasets_goli-L1000_parse_gctx_to_csv.ipynb
│   │   ├── QM9/
│   │   │   ├── micro_qm9.csv
│   │   │   ├── micro_qm9.parquet
│   │   │   └── norm_micro_qm9.csv
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── collate.py
│   │   ├── datamodule.py
│   │   ├── dataset.py
│   │   ├── make_data_splits/
│   │   │   ├── make_train_val_test_splits_large.ipynb
│   │   │   └── make_train_val_test_splits_small.ipynb
│   │   ├── micro_ZINC/
│   │   │   ├── __init__.py
│   │   │   └── micro_ZINC.csv
│   │   ├── multilevel_utils.py
│   │   ├── multitask/
│   │   │   ├── __init__.py
│   │   │   ├── tiny_ZINC_SA.csv
│   │   │   ├── tiny_ZINC_logp.csv
│   │   │   └── tiny_ZINC_score.csv
│   │   ├── normalization.py
│   │   ├── sampler.py
│   │   ├── sdf2csv.py
│   │   ├── single_atom_dataset/
│   │   │   └── single_atom_dataset.csv
│   │   ├── smiles_transform.py
│   │   └── utils.py
│   ├── expts/
│   │   └── pyg_batching_sparse.ipynb
│   ├── features/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── commute.py
│   │   ├── electrostatic.py
│   │   ├── featurizer.py
│   │   ├── graphormer.py
│   │   ├── nmp.py
│   │   ├── periodic_table.csv
│   │   ├── positional_encoding.py
│   │   ├── properties.py
│   │   ├── rw.py
│   │   ├── spectral.py
│   │   └── transfer_pos_level.py
│   ├── finetuning/
│   │   ├── __init__.py
│   │   ├── finetuning.py
│   │   ├── finetuning_architecture.py
│   │   ├── fingerprinting.py
│   │   └── utils.py
│   ├── hyper_param_search/
│   │   ├── __init__.py
│   │   └── results.py
│   ├── ipu/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── ipu_dataloader.py
│   │   ├── ipu_losses.py
│   │   ├── ipu_metrics.py
│   │   ├── ipu_simple_lightning.py
│   │   ├── ipu_utils.py
│   │   ├── ipu_wrapper.py
│   │   └── to_dense_batch.py
│   ├── nn/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── encoder_manager.py
│   │   │   ├── global_architectures.py
│   │   │   └── pyg_architectures.py
│   │   ├── base_graph_layer.py
│   │   ├── base_layers.py
│   │   ├── encoders/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── base_encoder.py
│   │   │   ├── bessel_pos_encoder.py
│   │   │   ├── gaussian_kernel_pos_encoder.py
│   │   │   ├── laplace_pos_encoder.py
│   │   │   ├── mlp_encoder.py
│   │   │   └── signnet_pos_encoder.py
│   │   ├── ensemble_layers.py
│   │   ├── pyg_layers/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── dimenet_pyg.py
│   │   │   ├── gated_gcn_pyg.py
│   │   │   ├── gcn_pyg.py
│   │   │   ├── gin_pyg.py
│   │   │   ├── gps_pyg.py
│   │   │   ├── mpnn_pyg.py
│   │   │   ├── pna_pyg.py
│   │   │   ├── pooling_pyg.py
│   │   │   └── utils.py
│   │   ├── residual_connections.py
│   │   └── utils.py
│   ├── trainer/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── predictor.py
│   │   ├── predictor_options.py
│   │   └── predictor_summaries.py
│   └── utils/
│       ├── README.md
│       ├── __init__.py
│       ├── arg_checker.py
│       ├── command_line_utils.py
│       ├── custom_lr.py
│       ├── decorators.py
│       ├── fs.py
│       ├── hashing.py
│       ├── moving_average_tracker.py
│       ├── mup.py
│       ├── packing.py
│       ├── safe_run.py
│       ├── spaces.py
│       └── tensor.py
├── install_ipu.sh
├── mkdocs.yml
├── notebooks/
│   ├── compare-pretraining-finetuning-performance.ipynb
│   ├── dev-datamodule-invalidate-cache.ipynb
│   ├── dev-datamodule-ogb.ipynb
│   ├── dev-datamodule.ipynb
│   ├── dev-pretrained.ipynb
│   ├── dev-training-loop.ipynb
│   ├── dev.ipynb
│   ├── finetuning-on-tdc-admet-benchmark.ipynb
│   ├── running-fingerprints-from-pretrained-model.ipynb
│   └── running-model-from-config.ipynb
├── profiling/
│   ├── configs_profiling.yaml
│   ├── profile_mol_to_graph.py
│   ├── profile_one_of_k_encoding.py
│   └── profile_predictor.py
├── pyproject.toml
├── scripts/
│   ├── balance_params_and_train.sh
│   ├── convert_yml.py
│   ├── ipu_start.sh
│   ├── ipu_venv.sh
│   └── scale_mpnn.sh
└── tests/
    ├── .gitignore
    ├── __init__.py
    ├── config_test_ipu_dataloader.yaml
    ├── config_test_ipu_dataloader_multitask.yaml
    ├── conftest.py
    ├── converted_fake_multilevel_data.parquet
    ├── data/
    │   ├── config_micro_ZINC.yaml
    │   ├── micro_ZINC.csv
    │   ├── micro_ZINC_corrupt.csv
    │   ├── micro_ZINC_shard_1.csv
    │   ├── micro_ZINC_shard_1.parquet
    │   ├── micro_ZINC_shard_2.csv
    │   ├── micro_ZINC_shard_2.parquet
    │   └── pcqm4mv2-2k.csv
    ├── fake_and_missing_multilevel_data.parquet
    ├── test_architectures.py
    ├── test_attention.py
    ├── test_base_layers.py
    ├── test_collate.py
    ├── test_data_utils.py
    ├── test_datamodule.py
    ├── test_dataset.py
    ├── test_ensemble_layers.py
    ├── test_featurizer.py
    ├── test_finetuning.py
    ├── test_ipu_dataloader.py
    ├── test_ipu_losses.py
    ├── test_ipu_metrics.py
    ├── test_ipu_options.py
    ├── test_ipu_poptorch.py
    ├── test_ipu_to_dense_batch.py
    ├── test_loaders.py
    ├── test_losses.py
    ├── test_metrics.py
    ├── test_mtl_architecture.py
    ├── test_multitask_datamodule.py
    ├── test_mup.py
    ├── test_packing.py
    ├── test_pe_nodepair.py
    ├── test_pe_rw.py
    ├── test_pe_spectral.py
    ├── test_pos_transfer_funcs.py
    ├── test_positional_encoders.py
    ├── test_positional_encodings.py
    ├── test_predictor.py
    ├── test_pyg_layers.py
    ├── test_residual_connections.py
    ├── test_training.py
    └── test_utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/CODEOWNERS
================================================
* @DomInvivo


================================================
FILE: .github/CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
## Changelogs

- _enumerate the changes of that PR._

---

_Checklist:_

- [ ] _Was this PR discussed in an issue? It is recommended to first discuss a new feature into a GitHub issue before opening a PR._
- [ ] _Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate)._
- [ ] _Update the API documentation if a new function is added, or an existing one is deleted._
- [ ] _Write concise and explanatory changelogs above._
- [ ] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._

---

_discussion related to that PR_


================================================
FILE: .github/workflows/code-check.yml
================================================
name: code-check

on:
  push:
    branches: ["main"]
    tags: ["*"]
  pull_request:
    branches:
      - "*"
      - "!gh-pages"

jobs:
  python-format-black:
    name: Python lint [black]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"

      - name: Install black
        run: |
          pip install black>=23

      - name: Lint
        run: black --check .


================================================
FILE: .github/workflows/doc.yml
================================================
name: doc

on:
  push:
    branches: ["main"]

# Prevent doc action on `main` to conflict with each others.
concurrency:
  group: doc-${{ github.ref }}
  cancel-in-progress: true

jobs:
  doc:
    runs-on: "ubuntu-latest"
    timeout-minutes: 30

    defaults:
      run:
        shell: bash -l {0}

    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Setup mamba
        uses: mamba-org/setup-micromamba@v1
        with:
          environment-file: env.yml
          environment-name: graphium
          cache-environment: true
          cache-downloads: true

      - name: Install library
        run: |
          python -m pip install --no-deps .
          pip install typer-cli

      - name: Configure git
        run: |
          git config --global user.name "${GITHUB_ACTOR}"
          git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"

      - name: Deploy the doc
        run: |

          echo "Auto-generating typer docs"
          typer graphium.cli.__main__ utils docs --name graphium --output docs/cli/graphium.md
          
          echo "Get the gh-pages branch"
          git fetch origin gh-pages

          echo "Build and deploy the doc on main"
          mike deploy --push main


================================================
FILE: .github/workflows/release.yml
================================================
name: release

on:
  workflow_dispatch:
    inputs:
      release-version:
        description: "A valid Semver version string"
        required: true

permissions:
  contents: write
  pull-requests: write

jobs:
  release:
    # Do not release if not triggered from the default branch
    if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch)

    runs-on: ubuntu-latest
    timeout-minutes: 30

    defaults:
      run:
        shell: bash -l {0}

    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Setup mamba
        uses: mamba-org/setup-micromamba@v1
        with:
          environment-file: env.yml
          environment-name: graphium
          cache-environment: true
          cache-downloads: true
          create-args: >-
            pip
            semver
            python-build
            setuptools_scm

      - name: Check the version is valid semver
        run: |
          RELEASE_VERSION="${{ inputs.release-version }}"

          {
            pysemver check $RELEASE_VERSION
          } || {
            echo "The version '$RELEASE_VERSION' is not a valid Semver version string."
            echo "Please use a valid semver version string. More details at https://semver.org/"
            echo "The release process is aborted."
            exit 1
          }

      - name: Check the version is higher than the latest one
        run: |
          # Retrieve the git tags first
          git fetch --prune --unshallow --tags &> /dev/null

          RELEASE_VERSION="${{ inputs.release-version }}"
          LATEST_VERSION=$(git describe --abbrev=0 --tags)

          IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION)

          if [ "$IS_HIGHER_VERSION" != "1" ]; then
            echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'."
            echo "The release process is aborted."
            exit 1
          fi

      - name: Build Changelog
        id: github_release
        uses: mikepenz/release-changelog-builder-action@v4
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          toTag: "main"

      - name: Configure git
        run: |
          git config --global user.name "${GITHUB_ACTOR}"
          git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"

      - name: Create and push git tag
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          # Tag the release
          git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}"

          # Checkout the git tag
          git checkout "${{ inputs.release-version }}"

          # Push the modified changelogs
          git push origin main

          # Push the tags
          git push origin "${{ inputs.release-version }}"

      - name: Install library
        run: python -m pip install --no-deps .

      - name: Build the wheel and sdist
        run: python -m build --no-isolation

      - name: Publish package to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
          packages-dir: dist/

      - name: Deploy the doc
        run: |
          echo "Get the gh-pages branch"
          git fetch origin gh-pages

          echo "Build and deploy the doc on ${{ inputs.release-version }}"
          mike deploy --push stable
          mike deploy --push ${{ inputs.release-version }}

      - name: Create GitHub Release
        uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844
        with:
          tag_name: ${{ inputs.release-version }}
          body: ${{steps.github_release.outputs.changelog}}


================================================
FILE: .github/workflows/test.yml
================================================
name: test

on:
  push:
    branches: ["main"]
    tags: ["*"]
  pull_request:
    branches:
      - "*"
      - "!gh-pages"
  schedule:
    - cron: "0 4 * * *"

jobs:
  test:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
        pytorch-version: ["2.0"]

    runs-on: "ubuntu-latest"
    timeout-minutes: 30

    defaults:
      run:
        shell: bash -l {0}

    name: |
        regular_env -
        python=${{ matrix.python-version }} -
        pytorch=${{ matrix.pytorch-version }}

    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Setup mamba
        uses: mamba-org/setup-micromamba@v1
        with:
          environment-file: env.yml
          environment-name: graphium
          cache-environment: true
          cache-downloads: true
          create-args: >-
            python=${{ matrix.python-version }}
            pytorch=${{ matrix.pytorch-version }}

      - name: Install library
        run: python -m pip install --no-deps -e . # `-e` required for correct `coverage` run.

      - name: Run tests
        run: pytest -m 'not ipu'

      - name: Test CLI
        run: graphium --help

      - name: Test building the doc
        run: mkdocs build

      - name: Codecov Upload
        uses: codecov/codecov-action@v3
        with:
          files: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
          fail_ci_if_error: false
          verbose: false
          env_vars: ${{ matrix.python-version }},${{ matrix.pytorch-version }}


================================================
FILE: .github/workflows/test_ipu.yml
================================================
name: test-ipu

on:
  push:
    branches: ["main"]
    tags: ["*"]
  pull_request:
    branches:
      - "*"
      - "!gh-pages"
  schedule:
    - cron: "0 4 * * *"

jobs:
  test-ipu:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.8"]
        pytorch-version: ["2.0"]

    runs-on: "ubuntu-20.04"
    timeout-minutes: 30

    defaults:
      run:
        shell: bash -l {0}

    name: |
        poptorch_env - 
        python=${{ matrix.python-version }} -
        pytorch=${{ matrix.pytorch-version }}

    steps:
      - name: Checkout the code
        uses: actions/checkout@v3

      - name: Activate SDK + Install Requirements
        run: |
          python3 -m pip install --upgrade pip
          wget -q -O 'poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz' 'https://downloads.graphcore.ai/direct?package=poplar-poplar_sdk_ubuntu_20_04_3.3.0_208993bbb7-3.3.0&file=poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz'
          tar -xzf poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz
          python3 -m pip install poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/poptorch-3.3.0+113432_960e9c294b_ubuntu_20_04-cp38-cp38-linux_x86_64.whl
          # Enable Poplar SDK (including Poplar and PopART)
          source poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/enable 
          
          python -c "import poptorch"

          # Download the datafiles (Total ~ 10Mb - nothing compared to the libraries)
          wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
          wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
          wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9.csv.gz


          # Install the IPU specific and graphium requirements
          pip install -r requirements_ipu.txt
          # Install Graphium in dev mode
          python -m pip install --no-deps -e .
          python3 -m pytest -m 'not skip_ipu'

      - name: Codecov Upload
        uses: codecov/codecov-action@v3
        with:
          files: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
          fail_ci_if_error: false
          verbose: false
          env_vars: ${{ matrix.python-version }},${{ matrix.pytorch-version }}


================================================
FILE: .gitignore
================================================
############ graphium Custom GitIgnore ##############

# Workspace
*.code-workspace
.vscode/
.idea/

# Training logs and models
*.out
*.cache
*.ckpt
*.pickle
*.datacache
*.datacache.gz
lightning_logs/
logs/
multirun/
hparam-search-results/
models_checkpoints/
outputs/
out/
saved_models/
wandb/
out/
datacache/
tests/temp_cache*
predictions/
draft/
scripts-expts/
sweeps/
mup/

# Data and predictions
graphium/data/ZINC_bench_gnn/
graphium/data/BindingDB/
graphium/data/mega-pubchem/
graphium/data/cache/
graphium/data/b3lyp/
graphium/data/PCQM4Mv2/
graphium/data/PCQM4M/
graphium/data/neurips2023/small-dataset/
graphium/data/neurips2023/large-dataset/
graphium/data/neurips2023/dummy-dataset/
graphium/data/make_data_splits/*.csv*
graphium/data/make_data_splits/*.pt*
graphium/data/make_data_splits/*.parquet*
*.csv.gz
*.pt

# Others
expts_untracked/
debug/
change_commits.sh
graphium/features/test_new_pes.ipynb

# IPU related ignores and profiler outputs
*.a
*.cbor
*.capnp
*.pop
*.popart
*.pop_cache
*.popef
*.pvti*

############ END graphium Custom GitIgnore ##############


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.txt
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: LICENSE
================================================
 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2023 Valence Labs
   Copyright 2023 Recursion Pharmaceuticals
   Copyright 2023 Graphcore Limited

   Various Academic groups have also contributed to this software under
   the given license. These include, but are not limited to, the following

   1. University of Montreal
   2. McGill University
   3. Mila - Institut Quebecois d'intelligence artificielle
   4. HEC Montreal
   5. CIFAR AI Chair
   6. New Jersey Institute of Technology
   7. RWTH Aachen University

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">
    <img src="docs/images/banner-tight.png" height="200px">
    <h3>Scaling molecular GNNs to infinity</h3>
</div>

---

[![PyPI](https://img.shields.io/pypi/v/graphium)](https://pypi.org/project/graphium/)
[![Conda](https://img.shields.io/conda/v/conda-forge/graphium?label=conda&color=success)](https://anaconda.org/conda-forge/graphium)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/graphium)](https://pypi.org/project/graphium/)
[![Conda](https://img.shields.io/conda/dn/conda-forge/graphium)](https://anaconda.org/conda-forge/graphium)
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/datamol-io/graphium/blob/main/LICENSE)
[![GitHub Repo stars](https://img.shields.io/github/stars/datamol-io/graphium)](https://github.com/datamol-io/graphium/stargazers)
[![GitHub Repo stars](https://img.shields.io/github/forks/datamol-io/graphium)](https://github.com/datamol-io/graphium/network/members)
[![test](https://github.com/datamol-io/graphium/actions/workflows/test.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/test.yml)
[![test-ipu](https://github.com/datamol-io/graphium/actions/workflows/test_ipu.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/test_ipu.yml)
[![release](https://github.com/datamol-io/graphium/actions/workflows/release.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/release.yml)
[![code-check](https://github.com/datamol-io/graphium/actions/workflows/code-check.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/code-check.yml)
[![doc](https://github.com/datamol-io/graphium/actions/workflows/doc.yml/badge.svg)](https://github.com/datamol-io/graphium/actions/workflows/doc.yml)
[![codecov](https://codecov.io/gh/datamol-io/graphium/branch/main/graph/badge.svg?token=bHOkKY5Fze)](https://codecov.io/gh/datamol-io/graphium)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)

A deep learning library focused on graph representation learning for real-world chemical tasks.

- ✅ State-of-the-art GNN architectures.
- 🐍 Extensible API: build your own GNN model and train it with ease.
- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
- 🧠 Pretrained models: for fast and easy inference or transfer learning.
- ⮔ Ready-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/).
- 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the name of the columns to predict, the atomic featurization, and you’re ready to play!

## Documentation

Visit https://graphium-docs.datamol.io/.

## Installation for developers

### For CPU and GPU developers

Use [`mamba`](https://github.com/mamba-org/mamba), a faster and better alternative to `conda`.

If you are using a GPU, we recommend enforcing the CUDA version that you need with `CONDA_OVERRIDE_CUDA=XX.X`.

```bash
# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium

# To force the CUDA version to 11.2, or any other version you prefer, use the following command:
# CONDA_OVERRIDE_CUDA=11.2 mamba env create -f env.yml -n graphium

# Install Graphium in dev mode
mamba activate graphium
pip install --no-deps -e .
```

### For IPU developers
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```

The above step needs to be done once. After that, enable the SDK and the environment as follows:

```bash
source enable_ipu.sh .graphium_ipu
```

## Training a model

To learn how to train a model, we invite you to look at the documentation, or the jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/main/docs/tutorials/model_training).

If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), we highly recommend going through their tutorial first.

## Running an experiment
We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn
```
To change parameters specific to this experiment like switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32
```
or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu
```
automatically selects the correct configs to run the experiment on GPU.
Finally, you can also run a fine-tuning loop:
```bash
graphium-train +finetuning=admet
```

To use a config file you built from scratch you can run
```bash
graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.

## Preparing the data in advance
The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`.

However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory.

The following command-line will prepare the data and cache it, then use it to train a model.
```bash
# First prepare the data and cache it in `path_to_cached_data`
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]

# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
```

**Note** that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra_configs/`.

**Note** that, every time the configs of `datamodule.args.featurization` changes, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs.

## License

Under the Apache-2.0 license. See [LICENSE](LICENSE).

## Documentation

- Diagram for data processing in Graphium.

<img src="docs/images/datamodule.png" alt="Data Processing Chart" width="60%" height="60%">

- Diagram for Muti-task network in Graphium

<img src="docs/images/full_graph_network.png" alt="Full Graph Multi-task Network" width="80%" height="80%">


================================================
FILE: cleanup_files.sh
================================================
#!/bin/bash
# --------------------------------------------------------------------------------
# Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
#
# Use of this software is subject to the terms and conditions outlined in the LICENSE file.
# Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
# warranties of any kind.
#
# Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
# Refer to the LICENSE file for the full terms and conditions.
# --------------------------------------------------------------------------------
# NOTE: the header above was previously a Python-style """docstring""", which is
# not a comment in bash — it made the shell try to execute the whole paragraph
# as one quoted command and emit a "command not found" error on every run.
# It is now a proper '#' comment block.

# Delete saved files like wandb, checkpoints, profiling etc.
# Usage: bash cleanup_files.sh
rm -rf wandb/*
rm -rf checkpoints/*
rm -rf pro_vision/*
rm -rf outputs/*
rm -rf models_checkpoints/*
rm -rf datacache/*


================================================
FILE: codecov.yml
================================================
coverage:
  range: "50...80"
  status:
    project:
      default:
        threshold: 1%
    patch: false

comment:
  layout: "header, diff, flags, components" # show component info in the PR comment

component_management:
  default_rules: # default rules that will be inherited by all components
    statuses:
      - type: project # in this case every component that doesn't have a status defined will have a project type one
        target: auto
        branches:
          - "!main"
  individual_components:
    - component_id: ipu # this is an identifier that should not be changed
      name: ipu # this is a display name, and can be changed freely
      paths:
        - graphium/ipu/**


================================================
FILE: docs/_assets/css/custom-graphium.css
================================================
/* Graphium brand palette and Material-for-MkDocs theme overrides. */
:root {
  --graphium-primary: #548235;
  --graphium-secondary: #343a40;

  /* Primary color shades */
  --md-primary-fg-color: var(--graphium-primary);
  --md-primary-fg-color--light: var(--graphium-primary);
  --md-primary-fg-color--dark: var(--graphium-primary);
  --md-primary-bg-color: var(--graphium-secondary);
  --md-primary-bg-color--light: var(--graphium-secondary);
  --md-text-link-color: var(--graphium-secondary);

  /* Accent color shades */
  --md-accent-fg-color: var(--graphium-secondary);
  --md-accent-fg-color--transparent: var(--graphium-secondary);
  --md-accent-bg-color: var(--graphium-secondary);
  --md-accent-bg-color--light: var(--graphium-secondary);
}

:root > * {
  /* Code block color shades */
  --md-code-bg-color: hsla(0, 0%, 96%, 1);
  --md-code-fg-color: hsla(200, 18%, 26%, 1);

  /* Footer */
  --md-footer-bg-color: var(--graphium-primary);
  /* --md-footer-bg-color--dark: hsla(0, 0%, 0%, 0.32); */
  --md-footer-fg-color: var(--graphium-secondary);
  --md-footer-fg-color--light: var(--graphium-secondary);
  --md-footer-fg-color--lighter: var(--graphium-secondary);
}

/* Green brand gradient on header and footer bars. */
.md-header {
  background-image: linear-gradient(to right, #548235, #8DB771);
}

.md-footer {
  background-image: linear-gradient(to right, #548235, #8DB771);
}

.md-tabs {
  background-image: linear-gradient(to right, #f4f6f9, #e2cec3);
}

/* Site title in the header.
   NOTE: this was previously declared twice (the later rule added font-size);
   the two rules are merged here — computed style is unchanged. */
.md-header__topic {
  color: #fff;
  font-size: 1.4em;
}

/* Force white text on header/footer widgets that the theme colors darker. */
.md-source__repository,
.md-source__icon,
.md-search__input,
.md-search__input::placeholder,
.md-search__input ~ .md-search__icon,
.md-footer__inner.md-grid,
.md-copyright__highlight,
.md-copyright,
.md-footer-meta.md-typeset a,
.md-version {
  color: #fff !important;
}

.md-search__form {
  background-color: rgba(255, 255, 255, 0.2);
}

/* Typed search text stays dark for readability on the light form. */
.md-search__input {
  color: #222 !important;
}

/* Increase the size of the logo */
.md-header__button.md-logo img,
.md-header__button.md-logo svg {
  height: 2rem !important;
}

/* Reduce the margin around the logo */
.md-header__button.md-logo {
  margin: 0.4em;
  padding: 0.4em;
}

/* Remove the `In` and `Out` block in rendered Jupyter notebooks */
.md-container .jp-Cell-outputWrapper .jp-OutputPrompt.jp-OutputArea-prompt,
.md-container .jp-Cell-inputWrapper .jp-InputPrompt.jp-InputArea-prompt {
  display: none !important;
}


================================================
FILE: docs/_assets/css/custom.css
================================================
/* Indent the body of each rendered API object and separate it from the next. */
div.doc-contents:not(.first) {
  border-left: 4px solid #e6e6e6;
  padding-left: 25px;
  margin-bottom: 80px;
}

/* Keep API object names in their original case. */
h5.doc-heading {
  text-transform: none !important;
}

/* Hidden ToC entries should not reserve vertical space… */
.hidden-toc::before {
  padding-top: 0 !important;
  margin-top: 0 !important;
}

/* …nor render a permalink anchor. */
.hidden-toc a.headerlink {
  display: none;
}

/* Parameter names etc. in table cells must not be broken mid-word. */
td code {
  word-break: normal !important;
}

/* Markdown paragraphs inside table cells: no extra vertical margins. */
td p {
  margin-bottom: 0 !important;
  margin-top: 0 !important;
}


================================================
FILE: docs/_assets/js/google-analytics.js
================================================
// Load the Google Analytics (gtag.js) library and start tracking this site.
var gtag_id = "G-K76VNHBRM4";

// Inject the gtag.js loader <script> tag into the document head.
var loaderTag = document.createElement("script");
loaderTag.src = `https://www.googletagmanager.com/gtag/js?id=${gtag_id}`;
document.head.appendChild(loaderTag);

// Standard gtag bootstrap: queue commands until the library picks them up.
window.dataLayer = window.dataLayer || [];
function gtag() {
  dataLayer.push(arguments);
}
gtag("js", new Date());
gtag("config", gtag_id);


================================================
FILE: docs/api/graphium.config.md
================================================
graphium.config
====================
helper functions to load and convert config files

::: graphium.config._loader
::: graphium.config.config_convert
::: graphium.config._load


================================================
FILE: docs/api/graphium.data.md
================================================
graphium.data
====================
module for loading datasets and collating batches

=== "Contents"

    * [Data Module](#data-module)
    * [Collate Module](#collate-module)
    * [Util Functions](#util-functions)

## Data Module
------------
::: graphium.data.datamodule


## Collate Module
------------
::: graphium.data.collate


## Util Functions
------------
::: graphium.data.utils


================================================
FILE: docs/api/graphium.features.md
================================================
graphium.features
====================
Feature extraction and manipulation

=== "Contents"

    * [Featurizer](#featurizer)
    * [Positional Encoding](#positional-encoding)
    * [Properties](#properties)
    * [Spectral PE](#spectral-pe)
    * [Random Walk PE](#random-walk-pe)
    * [NMP](#nmp)

## Featurizer
------------
::: graphium.features.featurizer


## Positional Encoding
------------
::: graphium.features.positional_encoding


## Properties
------------
::: graphium.features.properties


## Spectral PE
------------
::: graphium.features.spectral


## Random Walk PE
------------
::: graphium.features.rw


## NMP
------------
::: graphium.features.nmp


================================================
FILE: docs/api/graphium.finetuning.md
================================================
graphium.finetuning
====================
Module for finetuning models and doing linear probing (fingerprinting).

::: graphium.finetuning.finetuning.GraphFinetuning

::: graphium.finetuning.finetuning_architecture.FullGraphFinetuningNetwork

::: graphium.finetuning.finetuning_architecture.PretrainedModel

::: graphium.finetuning.finetuning_architecture.FinetuningHead

::: graphium.finetuning.fingerprinting.Fingerprinter


================================================
FILE: docs/api/graphium.ipu.md
================================================
graphium.ipu
====================
Code for adapting to run on IPU

=== "Contents"

    * [IPU Dataloader](#ipu-dataloader)
    * [IPU Losses](#ipu-losses)
    * [IPU Metrics](#ipu-metrics)
    * [IPU Simple Lightning](#ipu-simple-lightning)
    * [IPU Utils](#ipu-utils)
    * [IPU Wrapper](#ipu-wrapper)
    * [To Dense Batch](#to-dense-batch)

## IPU Dataloader
------------
::: graphium.ipu.ipu_dataloader


## IPU Losses
------------
::: graphium.ipu.ipu_losses


## IPU Metrics
------------
::: graphium.ipu.ipu_metrics


## IPU Simple Lightning
------------
::: graphium.ipu.ipu_simple_lightning


## IPU Utils
------------
::: graphium.ipu.ipu_utils


## IPU Wrapper
------------
::: graphium.ipu.ipu_wrapper


## To Dense Batch
------------
::: graphium.ipu.to_dense_batch



================================================
FILE: docs/api/graphium.nn/architectures.md
================================================
graphium.nn.architectures
====================

High level architectures in the library

=== "Contents"

    * [Global Architectures](#global-architectures)
    * [PyG Architectures](#pyg-architectures)
    * [Encoder Manager](#encoder-manager)


## Global Architectures
------------
::: graphium.nn.architectures.global_architectures


## PyG Architectures
------------
::: graphium.nn.architectures.pyg_architectures


## Encoder Manager
------------
::: graphium.nn.architectures.encoder_manager


================================================
FILE: docs/api/graphium.nn/encoders.md
================================================
graphium.nn.encoders
====================

Implementations of positional encoders in the library

=== "Contents"

    * [Base Encoder](#base-encoder)
    * [Gaussian Kernel Positional Encoder](#gaussian-kernel-positional-encoder)
    * [Laplacian Positional Encoder](#laplacian-positional-encoder)
    * [MLP Encoder](#mlp-encoder)
    * [Signnet Positional Encoder](#signnet-positional-encoder)


## Base Encoder
------------
::: graphium.nn.encoders.base_encoder


## Gaussian Kernel Positional Encoder
------------
::: graphium.nn.encoders.gaussian_kernel_pos_encoder


## Laplacian Positional Encoder
------------
::: graphium.nn.encoders.laplace_pos_encoder


## MLP Encoder
------------
::: graphium.nn.encoders.mlp_encoder


## Signnet Positional Encoder
------------
::: graphium.nn.encoders.signnet_pos_encoder


================================================
FILE: docs/api/graphium.nn/graphium.nn.md
================================================
graphium.nn
====================

Base structures for neural networks in the library (to be inherited by other modules)

=== "Contents"

    * [Base Graph Layer](#base-graph-layer)
    * [Base Layers](#base-layers)
    * [Residual Connections](#residual-connections)

## Base Graph Layer
------------
::: graphium.nn.base_graph_layer


## Base Layers
------------
::: graphium.nn.base_layers


## Residual Connections
------------
::: graphium.nn.residual_connections



================================================
FILE: docs/api/graphium.nn/pyg_layers.md
================================================
graphium.nn.pyg_layers
====================
Implementations of GNN layers based on PyG in the library

=== "Contents"

    * [Gated GCN Layer](#gated-gcn-layer)
    * [GPS Layer](#gps-layer)
    * [GIN and GINE Layers](#gin-and-gine-layers)
    * [MPNN Layer](#mpnn-layer)
    * [PNA Layer](#pna-layer)
    * [Pooling Layers](#pooling-layers)
    * [Utils](#utils)

## Gated GCN Layer
------------
::: graphium.nn.pyg_layers.gated_gcn_pyg


## GPS Layer
------------
::: graphium.nn.pyg_layers.gps_pyg


## GIN and GINE Layers
------------
::: graphium.nn.pyg_layers.gin_pyg


## MPNN Layer
------------
::: graphium.nn.pyg_layers.mpnn_pyg


## PNA Layer
------------
::: graphium.nn.pyg_layers.pna_pyg


## Pooling Layers
------------
::: graphium.nn.pyg_layers.pooling_pyg


## Utils
------------
::: graphium.nn.pyg_layers.utils


================================================
FILE: docs/api/graphium.trainer.md
================================================
graphium.trainer
====================

Code for training models

=== "Contents"
    * [Predictor](#predictor)
    * [Metrics](#metrics)
    * [Predictor Summaries](#predictor-summaries)
    * [Predictor Options](#predictor-options)


## Predictor
------------
::: graphium.trainer.predictor


## Metrics
------------
::: graphium.trainer.metrics


## Predictor Summaries
------------
::: graphium.trainer.predictor_summaries


## Predictor Options
------------
::: graphium.trainer.predictor_options




================================================
FILE: docs/api/graphium.utils.md
================================================
graphium.utils
====================
module for utility functions

=== "Contents"
    * [Argument Checker](#argument-checker)
    * [Decorators](#decorators)
    * [Dictionary of Tensors](#dictionary-of-tensors)
    * [File System](#file-system)
    * [Hashing](#hashing)
    * [Moving Average Tracker](#moving-average-tracker)
    * [MUP](#mup)
    * [Read File](#read-file)
    * [Safe Run](#safe-run)
    * [Spaces](#spaces)
    * [Tensor](#tensor)


## Argument Checker
----------------
::: graphium.utils.arg_checker


## Decorators
----------------
::: graphium.utils.decorators


## File System
----------------
::: graphium.utils.fs


## Hashing
----------------
::: graphium.utils.hashing


## Moving Average Tracker
----------------
::: graphium.utils.moving_average_tracker


## MUP
----------------
::: graphium.utils.mup


## Read File
----------------
::: graphium.utils.read_file

## Safe Run
----------------
::: graphium.utils.safe_run


## Spaces
----------------
::: graphium.utils.spaces


## Tensor
----------------
::: graphium.utils.tensor


================================================
FILE: docs/baseline.md
================================================
# ToyMix Baseline - Test set metrics

From the paper to be released soon. Below, you can see the baselines for the `ToyMix` dataset, a multitasking dataset comprising `QM9`, `Zinc12k` and `Tox21`. The datasets and their splits are available on [this link](https://zenodo.org/record/7998401). The following baselines are all for models with ~150k parameters.

One can observe that the smaller datasets (`Zinc12k` and `Tox21`) benefit from adding another unrelated task (`QM9`), where the labels are computed from DFT simulations.

**NEW baselines added 2023/09/18**: Multitask baselines have been added for GatedGCN and MPNN++ (sum aggregator) using 3 random seeds. They achieve the best performance by a significant margin on Zinc12k and Tox21, while sacrificing a little on QM9.

| Dataset   | Model | MAE ↓     | Pearson ↑ | R² ↑     | MAE ↓   | Pearson ↑ | R² ↑   |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
|    | <th colspan="3" style="text-align: center;">Single-Task Model</th>  <th colspan="3" style="text-align: center;">Multi-Task Model</th>   |
|
| **QM9**   | GCN   | 0.102 ± 0.0003 | 0.958 ± 0.0007 | 0.920 ± 0.002 | 0.119 ± 0.01 | 0.955 ± 0.001 | 0.915 ± 0.001 |
|           | GIN   | 0.0976 ± 0.0006 | **0.959 ± 0.0002** | **0.922 ± 0.0004** | 0.117 ± 0.01 | 0.950 ± 0.002 | 0.908 ± 0.003 |
|           | GINE  | **0.0959 ± 0.0002** | 0.955 ± 0.002 | 0.918 ± 0.004 | 0.102 ± 0.01 | 0.956 ± 0.0009 | 0.918 ± 0.002 |
|       |   GatedGCN    |       |       |       | 0.1212 ± 0.0009 | 0.9457 ± 0.0002 | 0.8964 ± 0.0006 |
|       |   MPNN++ (sum)    |       |       |    | 0.1174 ± 0.0012 | 0.9460 ± 0.0005 | 0.8989 ± 0.0008 |
 **Zinc12k** | GCN   | 0.348 ± 0.02 | 0.941 ± 0.002 | 0.863 ± 0.01 | 0.226 ± 0.004 | 0.973 ± 0.0005 | 0.940 ± 0.003 |
|           | GIN   | 0.303 ± 0.007 | 0.950 ± 0.003 | 0.889 ± 0.003 | 0.189 ± 0.004 | 0.978 ± 0.006 | 0.953 ± 0.002 |
|           | GINE  | 0.266 ± 0.02 | 0.961 ± 0.003 | 0.915 ± 0.01 | 0.147 ± 0.009 | 0.987 ± 0.001 | 0.971 ± 0.003 |
|       | GatedGCN       |       |       |       | 0.1282 ± 0.0045 | 0.9850 ± 0.0006 | 0.9639 ± 0.0024 |
|       | MPNN++ (sum)   |       |       |       | **0.1002 ± 0.0025** | **0.9909 ± 0.0004** | **0.9777 ± 0.0014** |

|           |       | BCE ↓     | AUROC ↑ | AP ↑     | BCE ↓   | AUROC ↑ | AP ↑   |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
|    | <th colspan="3" style="text-align: center;">Single-Task Model</th>  <th colspan="3" style="text-align: center;">Multi-Task Model</th>   |
|
| **Tox21**   | GCN   | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | 0.176 ± 0.001 | 0.850 ± 0.006 | 0.446 ± 0.01 |
|           | GIN   | 0.200 ± 0.002 | 0.789 ± 0.009 | 0.350 ± 0.01 | 0.176 ± 0.001 | 0.841 ± 0.005 | 0.454 ± 0.009 |
|           | GINE  | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | 0.455 ± 0.008 |
|       | GatedGCN       |       |       |       | 0.1733 ± 0.0015 | 0.8522 ± 0.0022 | **0.4620 ± 0.0118** |
|       | MPNN++ (sum)   |       |       |       | **0.1725 ± 0.0012** | **0.8569 ± 0.0005** | 0.4598 ± 0.0044 |


# LargeMix Baseline
## LargeMix test set metrics

From the paper to be released soon. Below, you can see the baselines for the `LargeMix` dataset, a multitasking dataset comprising `PCQM4M_N4`, `PCQM4M_G25`, `PCBA_1328`, `L1000_VCAP`, and `L1000_MCF7`. The datasets and their splits are available on [this link](https://zenodo.org/record/7998401). The following baselines are all for models with 4-6M parameters.

One can observe that the smaller datasets (`L1000_VCAP` and `L1000_MCF7`) benefit tremendously from multitasking. Indeed, the lack of molecular samples means that it is very easy for a model to overfit.

While `PCQM4M_G25` shows no noticeable changes, the node predictions of `PCQM4M_N4` and assay predictions of `PCBA_1328` take a hit, but this is most likely due to underfitting since the training loss is also increased. It seems that 4-6M parameters is far from sufficient to capture all of the tasks simultaneously, which motivates the need for a larger model.

| Dataset   | Model | MAE ↓     | Pearson ↑ | R² ↑     | MAE ↓   | Pearson ↑ | R² ↑   |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
|    | <th colspan="3" style="text-align: center;">Single-Task Model</th>  <th colspan="3" style="text-align: center;">Multi-Task Model</th>   |
|
| **Pcqm4m_g25** | GCN | 0.2362 ± 0.0003 | 0.8781 ± 0.0005 | 0.7803 ± 0.0006 | 0.2458 ± 0.0007 | 0.8701 ± 0.0002 | 0.8189 ± 0.0004 |
|               | GIN | 0.2270 ± 0.0003 | 0.8854 ± 0.0004 | 0.7912 ± 0.0006 | 0.2352 ± 0.0006 | 0.8802 ± 0.0007 | 0.7827 ± 0.0005 |
|               | GINE| **0.2223 ± 0.0007** | **0.8874 ± 0.0003** | **0.7949 ± 0.0001** | 0.2315 ± 0.0002 | 0.8823 ± 0.0002 | 0.7864 ± 0.0008 |
| **Pcqm4m_n4** | GCN | 0.2080 ± 0.0003 | 0.5497 ± 0.0010 | 0.2942 ± 0.0007 | 0.2040 ± 0.0001 | 0.4796 ± 0.0006 | 0.2185 ± 0.0002 |
|               | GIN | 0.1912 ± 0.0027 | **0.6138 ± 0.0088** | **0.3688 ± 0.0116** | 0.1966 ± 0.0003 | 0.5198 ± 0.0008 | 0.2602 ± 0.0012 |
|               | GINE| **0.1910 ± 0.0001** | 0.6127 ± 0.0003 | 0.3666 ± 0.0008 | 0.1941 ± 0.0003 | 0.5303 ± 0.0023 | 0.2701 ± 0.0034 |


|           |       | BCE ↓     | AUROC ↑ | AP ↑     | BCE ↓   | AUROC ↑ | AP ↑   |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
|    | <th colspan="3" style="text-align: center;">Single-Task Model</th>  <th colspan="3" style="text-align: center;">Multi-Task Model</th>   |
| <hi> | <hi> | <hi> | <hi> | <hi> | <hi> | <hi> | <hi> |
| **Pcba\_1328**    | GCN      | **0.0316 ± 0.0000** | **0.7960 ± 0.0020** | **0.3368 ± 0.0027** | 0.0349 ± 0.0002 | 0.7661 ± 0.0031 | 0.2527 ± 0.0041 |
|               | GIN      | 0.0324 ± 0.0000 | 0.7941 ± 0.0018 | 0.3328 ± 0.0019 | 0.0342 ± 0.0001 | 0.7747 ± 0.0025 | 0.2650 ± 0.0020 |
|               | GINE      | 0.0320 ± 0.0001 | 0.7944 ± 0.0023 | 0.3337 ± 0.0027 | 0.0341 ± 0.0001 | 0.7737 ± 0.0007 | 0.2611 ± 0.0043 |
| **L1000\_vcap**   | GCN      | 0.1900 ± 0.0002 | 0.5788 ± 0.0034 | 0.3708 ± 0.0007 | 0.1872 ± 0.0020 | 0.6362 ± 0.0012 | 0.4022 ± 0.0008 |
|               | GIN      | 0.1909 ± 0.0005 | 0.5734 ± 0.0029 | 0.3731 ± 0.0014 | 0.1870 ± 0.0010 | 0.6351 ± 0.0014 | 0.4062 ± 0.0001 |
|               | GINE      | 0.1907 ± 0.0006 | 0.5708 ± 0.0079 | 0.3705 ± 0.0015 | **0.1862 ± 0.0007** | **0.6398 ± 0.0043** | **0.4068 ± 0.0023** |
| **L1000\_mcf7**   | GCN      | 0.1869 ± 0.0003 | 0.6123 ± 0.0051 | 0.3866 ± 0.0010 | 0.1863 ± 0.0011 | **0.6401 ± 0.0021** | 0.4194 ± 0.0004 |
|               | GIN      | 0.1862 ± 0.0003 | 0.6202 ± 0.0091 | 0.3876 ± 0.0017 | 0.1874 ± 0.0013 | 0.6367 ± 0.0066 | **0.4198 ± 0.0036** |
|               | GINE      | **0.1856 ± 0.0005** | 0.6166 ± 0.0017 | 0.3892 ± 0.0035 | 0.1873 ± 0.0009 | 0.6347 ± 0.0048 | 0.4177 ± 0.0024 |

## LargeMix training set loss

Below is the loss on the training set. One can observe that the multi-task model always underfits the single-task, except on the two `L1000` datasets.

This is not surprising as they contain two orders of magnitude more datapoints and pose a significant challenge for the relatively small models used in this analysis. This favors the Single dataset setup (which uses a model of the same size) and we conjecture larger models to bridge this gap moving forward.

|            |       | CE or MSE loss in single-task $\downarrow$ | CE or MSE loss in multi-task $\downarrow$ |
|------------|-------|-----------------------------------------|-----------------------------------------|
|
| **Pcqm4m\_g25**    | GCN   | **0.2660 ± 0.0005** | 0.2767 ± 0.0015 |
|             | GIN   | **0.2439 ± 0.0004** | 0.2595 ± 0.0016 |
|             | GINE  | **0.2424 ± 0.0007** | 0.2568 ± 0.0012 |
|
| **Pcqm4m\_n4**    | GCN   | **0.2515 ± 0.0002** | 0.2613 ± 0.0008 |
|             | GIN   | **0.2317 ± 0.0003** | 0.2512 ± 0.0008 |
|             | GINE  | **0.2272 ± 0.0001** | 0.2483 ± 0.0004 |
|
| **Pcba\_1328**    | GCN   | **0.0284 ± 0.0010** | 0.0382 ± 0.0005 |
|             | GIN   | **0.0249 ± 0.0017** | 0.0359 ± 0.0011 |
|             | GINE  | **0.0258 ± 0.0017** | 0.0361 ± 0.0008 |
|
| **L1000\_vcap**   | GCN   | 0.1906 ± 0.0036 | **0.1854 ± 0.0148** |
|             | GIN   | 0.1854 ± 0.0030 | **0.1833 ± 0.0185** |
|             | GINE  | **0.1860 ± 0.0025** | 0.1887 ± 0.0200 |
|
| **L1000\_mcf7**   | GCN   | 0.1902 ± 0.0038 | **0.1829 ± 0.0095** |
|             | GIN   | 0.1873 ± 0.0033 | **0.1701 ± 0.0142** |
|             | GINE  | 0.1883 ± 0.0039 | **0.1771 ± 0.0010** |

## NEW: LargeMix improved sweep - 2023/08/18

Unsatisfied with the prior results, we ran a bayesian search over a broader set of parameters, and including only more expressive models, namely GINE, GatedGCN and MPNN++. We further increase the number of parameters to 10M due to evidence of underfitting. We evaluate only the multitask setting.

We observe a significant improvement over all tasks, with a very notable r2-score increase of +0.53 (0.27 -> 0.80) compared to the best node-level property prediction on PCQM4M_N4.

The results are reported below over 1 seed. We are currently running more seeds of the same models.

| Dataset       | Model          | MAE ↓     | Pearson ↑ | R² ↑     |
|---------------|----------------|--------|---------|--------|
| **PCQM4M_G25**    | GINE           | 0.2250 | 0.8840  | 0.7911 |
|               | GatedGCN       | 0.2457 | 0.8698  | 0.7688 |
|               | MPNN++ (sum)   | 0.2269 | 0.8802  | 0.7855 |
|
| **PCQM4M_N4**     | GINE           | 0.2699 | 0.8475  | 0.7182 |
|               | GatedGCN       | 0.3337 | 0.8102  | 0.6566 |
|               | MPNN++ (sum)   | 0.2114 | 0.8942  | 0.8000 |

| Dataset       | Model          | BCE ↓     | AUROC ↑ | AP ↑     |
|---------------|----------------|--------|---------|--------|
| **PCBA_1328**     | GINE           | 0.0334 | 0.7879  | 0.2808 |
|               | GatedGCN       | 0.0351 | 0.7788  | 0.2611 |
|               | MPNN++ (sum)   | 0.0344 | 0.7815  | 0.2666 |
|
| **L1000_VCAP**    | GINE           | 0.1907 | 0.6416  | 0.4042 |
|               | GatedGCN       | 0.1866 | 0.6395  | 0.4092 |
|               | MPNN++ (sum)   | 0.1867 | 0.6478  | 0.4131 |
|
| **L1000_MCF7**    | GINE           | 0.1931 | 0.6352  | 0.4235 |
|               | GatedGCN       | 0.1859 | 0.6547  | 0.4224 |
|               | MPNN++ (sum)   | 0.1870 | 0.6593  | 0.4254 |



# UltraLarge Baseline

## UltraLarge test set metrics

For `UltraLarge`, we provide results for the same GNN baselines as for
`LargeMix`. Each model is trained for 50 epochs and results are averaged over 3 seeds. The remaining
setup is the same as for TOYMIX (Section E.1), reporting metrics on the Single Dataset and Multi Dataset using the same performance metrics. We further use the same models (in terms of size) as used for `LargeMix`.

For now, we report only the results for a subset representing 5% of the total dataset due to computational constraint, but aim to provide the full results soon.

Results discussion. `UltraLarge` results can be found in Table 6. Interestingly, on both graph- and node-level tasks we observe that there is no advantage of multi-tasking in terms of performance. We
expect that for this ultra-large dataset, significantly larger models are needed to successfully leverage the multi-task setup. This could be attributed to underfitting, as already demonstrated for `LargeMix`. Nonetheless, our baselines set the stage for large-scale pre-training on `UltraLarge`.

The results presented used approximately 500 GPU hours of compute, with
more compute used for development and hyperparameter search.

We further note that the graph-level tasks results are very strong. Regarding the node-level tasks, they are expected to underperform in low-parameters regime, due to clear signs of underfitting, a very large amount of labels to learn, and susceptibility to over-smoothing from traditional GNNs.


| Dataset          | Model | MAE ↓             | Pearson ↑         | R² ↑          | MAE ↓             | Pearson ↑         | R² ↑              |
|------------------|-------|-------------------|-------------------|-------------------|-------------------|-------------------|-------------------|
|    | <th colspan="3" style="text-align: center;">Single-Task Model</th>  <th colspan="3" style="text-align: center;">Multi-Task Model</th>   |
| <hi> | <hi> | <hi> | <hi> | <hi> | <hi> | <hi> | <hi> |
| **Pm6_83m_g62**      | GCN   | .2606 ± .0011     | .9004 ± .0003     | .7997 ± .0009       | .2625 ± .0011     | .8896 ± .0001     | .7982 ± .0001     |
|                  | GIN   | .2546 ± .0021     | .9051 ± .0019     | .8064 ± .0037       | .2562 ± .0000     | .8901 ± .0000     | .806 ± .0000      |
|                  | GINE  | **.2538 ± .0006** | **.9059 ± .0010** | **.8082 ± .0015**   | .258 ± .0011      | .904 ± .0000      | .8048 ± .0001     |
|
| **Pm6_83m_n7**       | GCN   | .5803 ± .0001     | .3372 ± .0004     | .1191 ± .0002      | .5971 ± .0002     | .3164 ± .0001     | .1019 ± .0011     |
|                  | GIN   | .573 ± .0002      | .3478 ± .0001     | **.1269 ± .0002**  | .5831 ± .0001     | .3315 ± .0005     | .1141 ± .0000     |
|                  | GINE  | **.572 ± .0004**  | **.3487 ± .0002** | .1266 ± .0001      | .5839 ± .0004     | .3294 ± .0002     | .1104 ± .0000     |

## UltraLarge training set loss

In the table below, we observe that the multi-task model slightly underfits the single-task model, indicating that parameters can be efficiently shared between the node-level and graph-level tasks. We further note that the training loss and the test MAE are almost equal for all tasks, indicating further benefits as we scale both the model and the data.

|                  |       | **MAE loss in single-task ↓** | **MAE loss in multi-task ↓** |
|------------------|-------|---------------------------|--------------------------|
|
| Pm6_83m_g62      | GCN   | **.2679 ± .0020**        | .2713 ± .0017           |
|                  | GIN   | **.2582 ± .0018**        | .2636 ± .0014           |
|                  | GINE  | **.2567 ± .0036**        | .2603 ± .0021           |
|
| Pm6_83m_n7       | GCN   | **.5818 ± .0021**        | .5955 ± .0023           |
|                  | GIN   | **.5707 ± .0019**        | .5851 ± .0038           |
|                  | GINE  | **.5724 ± .0015**        | .5832 ± .0027           |



================================================
FILE: docs/cli/graphium-train.md
================================================
# `graphium-train`

To support advanced configuration, Graphium uses [`hydra`](https://hydra.cc/) to manage and write config files. A limitation of `hydra`, is that it is designed to function as the main entrypoint for a CLI application and does not easily support subcommands. For that reason, we introduced the `graphium-train` command in addition to the [`graphium`](./graphium.md) command.

!!! info "Curious about the configs?"
    If you would like to learn more about the configs, please visit the docs [here](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs).

This page documents `graphium-train`.

## Running an experiment
We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn
```
To change parameters specific to this experiment like switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32
```
or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu
```
automatically selects the correct configs to run the experiment on GPU.
Finally, you can also run a fine-tuning loop:
```bash
graphium-train +finetuning=admet
```

To use a config file you built from scratch you can run
```bash
graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.

### Preparing the data in advance
The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`.

However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory.

The following command-line will prepare the data and cache it, then use it to train a model.
```bash
# First prepare the data and cache it in `path_to_cached_data`
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]

# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
```

??? note "Config vs. Override"
    As with any configuration, note that `datamodule.args.processed_graph_data_path` can also be specified in the configs at `expts/hydra_configs/`.

??? note "Featurization"
    Every time the configs of `datamodule.args.featurization` change, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs.

================================================
FILE: docs/cli/graphium.md
================================================
# `graphium`

**Usage**:

```console
$ graphium [OPTIONS] COMMAND [ARGS]...
```

**Options**:

* `--help`: Show this message and exit.

**Commands**:

* `data`: Graphium datasets.
* `finetune`: Utility CLI for extra fine-tuning utilities.

## `graphium data`

Graphium datasets.

**Usage**:

```console
$ graphium data [OPTIONS] COMMAND [ARGS]...
```

**Options**:

* `--help`: Show this message and exit.

**Commands**:

* `download`: Download a Graphium dataset.
* `list`: List available Graphium dataset.
* `prepare`: Prepare a Graphium dataset.

### `graphium data download`

Download a Graphium dataset.

**Usage**:

```console
$ graphium data download [OPTIONS] NAME OUTPUT
```

**Arguments**:

* `NAME`: [required]
* `OUTPUT`: [required]

**Options**:

* `--progress / --no-progress`: [default: progress]
* `--help`: Show this message and exit.

### `graphium data list`

List available Graphium dataset.

**Usage**:

```console
$ graphium data list [OPTIONS]
```

**Options**:

* `--help`: Show this message and exit.

### `graphium data prepare`

Prepare a Graphium dataset.

**Usage**:

```console
$ graphium data prepare [OPTIONS] OVERRIDES...
```

**Arguments**:

* `OVERRIDES...`: [required]

**Options**:

* `--help`: Show this message and exit.

## `graphium finetune`

Utility CLI for extra fine-tuning utilities.

**Usage**:

```console
$ graphium finetune [OPTIONS] COMMAND [ARGS]...
```

**Options**:

* `--help`: Show this message and exit.

**Commands**:

* `admet`: Utility CLI to easily fine-tune a model on...
* `fingerprint`: Endpoint for getting fingerprints from a...

### `graphium finetune admet`

Utility CLI to easily fine-tune a model on (a subset of) the benchmarks in the TDC ADMET group.

A major limitation is that we cannot use all features of the Hydra CLI, such as multiruns.

**Usage**:

```console
$ graphium finetune admet [OPTIONS] OVERRIDES...
```

**Arguments**:

* `OVERRIDES...`: [required]

**Options**:

* `--name TEXT`
* `--inclusive-filter / --no-inclusive-filter`: [default: inclusive-filter]
* `--help`: Show this message and exit.

### `graphium finetune fingerprint`

Endpoint for getting fingerprints from a pretrained model.

The pretrained model should be a `.ckpt` path or pre-specified, named model within Graphium.
The fingerprint layer specification should be of the format `module:layer`.
If specified as a list, the fingerprints from all the specified layers will be concatenated.
See the docs of the `graphium.finetuning.fingerprinting.Fingerprinter` class for more info.

**Usage**:

```console
$ graphium finetune fingerprint [OPTIONS] FINGERPRINT_LAYER_SPEC... PRETRAINED_MODEL SAVE_DESTINATION
```

**Arguments**:

* `FINGERPRINT_LAYER_SPEC...`: [required]
* `PRETRAINED_MODEL`: [required]
* `SAVE_DESTINATION`: [required]

**Options**:

* `--output-type TEXT`: Either numpy (.npy) or torch (.pt) output  [default: torch]
* `-o, --override TEXT`: Hydra overrides
* `--help`: Show this message and exit.


================================================
FILE: docs/cli/reference.md
================================================
# CLI Reference

Installing the Graphium library makes two CLI tools available. 

- [`graphium-train`](./graphium-train.md) is the hydra endpoint - specifically meant for training, finetuning and testing. Since this uses `@hydra.main`, it has access to all advanced hydra functionality such as tab completion, multirun, working directory management, logging management. 
- [`graphium`](./graphium.md) is the more general CLI endpoint, organized with various sub commands. 

Ideally, we would've integrated both in a single CLI endpoint, but the hydra CLI cannot be a subcommand of another CLI, nor does it support easily adding subcommands, which is why we provide two separate CLI tools with different purposes.

!!! note "Interactive, embedded CLI docs with `--help`"
    In addition to these pages, you can also use `graphium --help` and `graphium-train --help` to interactively navigate the documentation of these tools directly in the CLI.

================================================
FILE: docs/contribute.md
================================================
# Contribute

We are happy to see that you want to contribute 🤗.
Feel free to open an issue or pull request at any time. But first, follow this page to install Graphium in dev mode.

## Installation for developers

### For CPU and GPU developers

Use [`mamba`](https://github.com/mamba-org/mamba), a preferred alternative to conda, to create your environment:

```bash
# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium

# Install Graphium in dev mode
mamba activate graphium
pip install --no-deps -e .
```

### For IPU developers

Download the SDK and use pypi to create your environment:

```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```

The above step needs to be done once. After that, enable the SDK and the environment as follows:

```bash
source enable_ipu.sh .graphium_ipu
```

## Build the documentation

You can build and serve the documentation locally with:

```bash
# Build and serve the doc
mkdocs serve
```


================================================
FILE: docs/datasets.md
================================================
# Graphium Datasets

Graphium datasets are hosted on Zenodo 
- ***ToyMix*** and  ***LargeMix*** datasets are hosted on [this link](https://doi.org/10.5281/zenodo.7998401)
- ***UltraLarge*** dataset is hosted on [this link](https://doi.org/10.5281/zenodo.8370547)

Instead of providing datasets as a single entity, our aim is to provide dataset mixes containing a variety of datasets that are meant to be predicted simultaneously using multi-tasking.

They are visually described in this image, with detailed description below.
![Visual description of the ToyMix, LargeMix, UltraLarge datasets](dataset_abstract.png)

## ToyMix (QM9 + Tox21 + Zinc12K)

The ***ToyMix*** dataset combines the ***QM9***, ***Tox21***, and ***Zinc12K*** datasets. These datasets are well-known in the literature and used as toy datasets, or very simple datasets, in various contexts to enable fast iterations of models. By regrouping toy datasets from quantum ML, drug discovery, and GNN expressivity, we hope that the learned model will be representative of the model performance we can expect on the larger datasets.

### Train/Validation/Test Splits
for all the datasets in ***ToyMix*** are split randomly with a ratio of 0.8/0.1/0.1. Random splitting is used since it is the simplest and fits the idea of having a toy dataset well.

### QM9
is a well-known dataset in the field of 3D GNNs. It consists of 19 graph-level quantum properties associated to an energy-minimized 3D conformation of the molecules [1]. It is considered a simple dataset since all the molecules have at most 9 heavy atoms. We chose QM9 in our ***ToyMix*** since it is very similar to the larger proposed quantum datasets, PCQM4M\_multitask and PM6\_83M, but with smaller molecules.


### Tox21
is a well-known dataset for researchers in machine learning for drug discovery [2]. It consists of a multi-label classification task with 12 labels, with most labels missing and a strong imbalance towards the negative class. We chose ***Tox21*** in our ***ToyMix*** since it is very similar to the larger proposed bioassay dataset, ***PCBA\_1328\_1564k*** both in terms of sparsity and imbalance and to the ***L1000*** datasets in terms of imbalance.

### ZINC12k
is a well-known dataset for researchers in GNN expressivity [3]. We include it in our ***ToyMix*** since GNN expressivity is very important for performance on large-scale data. Hence, we hope that the performance on this task will correlate well with the performance when scaling.

## LargeMix (PCQM4M + PCBA1328 + L1000)
In this section, we present the ***LargeMix*** dataset, comprised of four different datasets with tasks taken from quantum chemistry (***PCQM4M***), bio-assays (***PCBA***) and transcriptomics.

### Train/validation/test/test\_seen
Splits For the ***PCQM4M\_G25\_N4***, we create a 0.92/0.04/0.04 split. Then, for all the other datasets in ***LargeMix***, we first create a "test\_seen" split by taking the set of molecules from ***L1000*** and ***PCBA1328*** that are also present in the training set of ***PCQM4M\_G25\_N4***, such that we can evaluate whether having the quantum properties of a molecule helps generalize for biological properties. For the remaining parts, we split randomly with a ratio of 0.92/0.04/0.04.



### L1000 VCAP and MCF7
The ***LINCS L1000*** is a database of high-throughput transcriptomics that screened more than 30,000 perturbations on a set of 978 landmark genes [4] from multiple cell lines. ***VCAP*** and ***MCF7*** are, respectively, prostate cancer and human breast cancer cell lines. In ***L1000***, most of the perturbagens are chemical, meaning that small drug-like molecules are added to the cell lines to observe how the gene expressions change. This allows to generate biological signatures of the molecules, which are known to correlate with drug activity and side effects.


To process the data into our two datasets comprising the ***VCAP*** and ***MCF7*** cell lines, we used their "level 5" data composed of the cleanup data converted to z-scores, and filtered to keep only chemical perturbagens. However, we were left with multiple data points per molecule since some variables could change (e.g., incubation time) and generate a new measure. Given our objective of generating a single signature per molecule, we decided to take the measurements with the strongest global activity such that the variance over the 978 genes is maximal. Then, since these signatures are generally noisy, we binned them into five classes corresponding to z-scores based on the thresholds $\{-4, -2, 2, 4\}$.

The cell lines ***VCAP*** and ***MCF7*** were selected since they have a higher number of unique molecule perturbagens than other cell lines. They also have a relatively lower data imbalance, with ~92% falling in the "neutral class" when the z-score was between -2 and 2.

### PCBA1328
This dataset is very similar to the ***OGBG-PCBA*** dataset [5], but instead of being limited to 128 assays and 437k molecules, it comprises 1,328 assays and 1.56M molecules. This dataset is very interesting for pre-training molecular models since it contains information about a molecule's behavior in various settings relevant to biochemists, with evidence that it improves binding predictions. Analogous to the gene expression, we obtain a bio-assay-expression of each molecule.

To gather the data, we have looped over the PubChem index of bioassays [6] and collected every dataset such that it contains more than 6,000 molecules annotated with either ``Active'' or ``Inactive'' and at least 10 of each. Then, we converted all the molecular IDs to canonical SMILES and used it to merge all of the bioassays into a single dataset.

### PCQM4M\_G25\_N4
This dataset comes from the same data source as the ***OGBG-PCQM4M*** dataset, famously known for being part of the OGB large-scale challenge [7] and being one of the only graph datasets where pure Transformers have proven successful. The data source is the PubChemQC project [8] that computed DFT properties on the energy-minimized conformation of 3.8M small molecules from PubChem.

Contrarily to the OGB challenge, we aim to provide enough data for pre-training GNNs, so we do not limit ourselves to the HOMO-LUMO gap prediction [7]. Instead, we gather properties directly given by the DFT (e.g., energies) and compute other 3D descriptors from the conformation (e.g., inertia, the plane of best fit). We also gather node-level properties, the Mulliken and Lowdin charges at each atom. Furthermore, about half of the molecules have time-dependent DFT to help inform about the molecule's excited state. Looking forward, we plan on adding edge-level tasks to enable the prediction of bond properties, such as their lengths and the gradient of the charges.


## UltraLarge Dataset
### PM6\_83M
This dataset is similar to the ***PCQM4M*** and comes from the same PubChemQC project. However, it uses the PM6 semi-empirical computation of the quantum properties, which is orders of magnitude faster than DFT computation at the expense of less accuracy [8, 9].

This dataset covers 83M unique molecules, 62 graph-level tasks, and 7 node-level tasks. To our knowledge, this is the largest dataset available for training 2D-GNNs regarding the number of unique molecules. The various tasks come from four different molecular states, namely ``S0'' for the ground state, ``T0'' for the lowest energy triplet excited state, ``cation'' for the positively charged state, and ``anion'' for the negatively charged state. In total, there are 221M PM6 computations.

## References
[1] https://www.nature.com/articles/sdata201422/

[2] https://europepmc.org/article/MED/23603828

[3] https://arxiv.org/abs/2003.00982v3

[4] https://pubmed.ncbi.nlm.nih.gov/29195078/

[5] https://arxiv.org/abs/2005.00687

[6] https://pubmed.ncbi.nlm.nih.gov/26400175/

[7] https://arxiv.org/abs/2103.09430

[8] https://pubs.acs.org/doi/10.1021/acs.jcim.7b00083

[9] https://arxiv.org/abs/1904.06046



================================================
FILE: docs/design.md
================================================
# Graphium Library Design

---


The library is designed with 3 things in mind:

- High modularity and configurability with *YAML* files
- Contain the state-of-the-art GNNs, including positional encodings and graph Transformers
- Massively multitasking across diverse and sparse datasets

The current page will walk you through the different aspects of the design that enable that.

### Diagram for data processing in Graphium.

First, when working with molecules, there are tons of options regarding atomic and bond featurisation that can be extracted from the periodic table, from empirical results, or from simulated 3D structures.

Second, when working with graph Transformers, there are plenty of options regarding the positional and structural encodings (PSE) that are fundamental in driving the accuracy and the generalization of the models.

With this in mind, we propose a very versatile chemical and PSE encoding, alongside an encoder manager, that can be fully configured from the yaml files. The idea is to assign matching *input keys* to both the features and the encoders, then pool the outputs according to the *output keys*. It is better summarized in the image below.

<img src="images/datamodule.png" alt= "Data Processing Chart" width="100%" height="100%">



### Diagram for Multi-task network in Graphium

As mentioned, we want to be able to perform massive multi-tasking to enable us to work across a huge diversity of datasets. The idea is to use a combination of multiple task-heads, where a different MLP is applied to each task's predictions. However, it is also designed such that each task can have as many labels as desired, thus enabling labels to be grouped together according to whether they should share weights/losses.

The design is better explained in the image below. Notice how the *keys* output by GraphDict are used differently across the different GNN layers.

<img src="images/full_graph_network.png" alt= "Full Graph Multi-task Network" width="100%" height="100%">

## Structure of the code

The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with Pytorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs.

Below is a list of directories and their respective documentation:

- cli
- [config](https://github.com/datamol-io/graphium/blob/main/graphium/config/README.md)
- [data](https://github.com/datamol-io/graphium/blob/main/graphium/data/README.md)
- [features](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md)
- finetuning
- [ipu](https://github.com/datamol-io/graphium/tree/main/graphium/ipu/README.md)
- [nn](https://github.com/datamol-io/graphium/tree/main/graphium/nn/README.md)
- [trainer](https://github.com/datamol-io/graphium/tree/main/graphium/trainer/README.md)
- [utils](https://github.com/datamol-io/graphium/tree/main/graphium/utils/README.md)


## Structure of the configs

Making the library very modular requires configuration files with >200 lines, which become intractable, especially when we only want minor changes between configurations.

Hence, we use [hydra](https://hydra.cc/docs/intro/) to enable splitting the configurations into smaller and composable configuration files.

Examples of possibilities include:

- Switching between accelerators (CPU, GPU and IPU)
- Benchmarking different models on the same dataset
- Fine-tuning a pre-trained model on a new dataset

[In this document](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs#readme), we describe in detail how each of the above functionalities is achieved and how users can benefit from this design to achieve the most with Graphium with as little configuration as possible.



================================================
FILE: docs/index.md
================================================
# Overview

A deep learning library focused on graph representation learning for real-world chemical tasks.

- ✅ State-of-the-art GNN architectures.
- 🐍 Extensible API: build your own GNN model and train it with ease.
- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
- 🧠 Pretrained models: for fast and easy inference or transfer learning.
- ⮔ Ready-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/).
- 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the name of the columns to predict, the atomic featurization, and you’re ready to play!

## Installation

### For CPU or GPU
Use [`mamba`](https://github.com/mamba-org/mamba):

```bash
# Install Graphium
mamba install -c conda-forge graphium
```

or pip:

```bash
pip install graphium
```

### For IPU
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```

The above step needs to be done once. After that, enable the SDK and the environment as follows:

```bash
source enable_ipu.sh .graphium_ipu
```

Finally, you will need to install graphium with pip
```bash
pip install graphium
```


================================================
FILE: docs/license.md
================================================
```
{!LICENSE!}
```


================================================
FILE: docs/pretrained_models.md
================================================
# Graphium pretrained models

Graphium aims to provide a set of pretrained models that you can use for inference or transfer learning. The models will be made available once trained and validated.

## Listing all available models

To know which models are available, you can run the following command

```python
import graphium

print(graphium.trainer.PredictorModule.list_pretrained_models())
```

## Dummy pre-trained model

At the moment, only `tests/dummy-pretrained-model.ckpt` is provided, which is mostly useful for development and debugging of the checkpointing and finetuning pipelines.

You can load a pretrained models using the Graphium API:

```python
import graphium

predictor = graphium.trainer.PredictorModule.load_pretrained_models("dummy-pretrained-model")
```


================================================
FILE: docs/tutorials/feature_processing/add_new_positional_encoding.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add new positional encoding\n",
    "\n",
    "One of the main advantages of this library is the ability to easily incorporate novel positional encodings on the node, edge and graph level. The positional encodings are computed and fed into their respective encoders, and then the hidden embeddings from all pe encoders are pooled (according to whether they are node, edge, or graph level) and fed into the GNN layers as features. The design allows any combination of positional encodings to be used by modifying the configuration file. For more details on the data processing part, please visit the [design page of the doc](https://graphium-docs.datamol.io/stable/design.html).\n",
    "\n",
    "Here is the workflow for computing and processing positional encoding in the library:\n",
    "1. edit related parts in the yaml configuration file\n",
    "\n",
    "2. compute the raw positional encoding from the graph in [`graphium/features/positional_encoding.py`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding) (from the [`graph positional encoder`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding.graph_positional_encoder))\n",
    "\n",
    "3. feed the raw positional encoding into the respective (specialized) encoders in [`graphium/nn/encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html). For example, a simple [`MLP positional encoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder) can be found. \n",
    "\n",
    "4. Output the hidden embeddings of pe from the encoders in their respective output keys: `feat`(node feature), `edge_feat`(edge feature), `graph_feat`(graph feature) and potentially other keys if needed such as `nodepair_feat` \n",
    "\n",
    "5. pool the hidden embeddings with same keys together: for example, all output with `feat` key will be pooled together\n",
    "\n",
    "6. Construct the [`PyG Batch`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Batch.html#torch_geometric.data.Batch), batch of graphs, each contain the output keys seen above, ready for use in the [GNN layers](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html) \n",
    "\n",
    "Since this library is built using PyG, we recommend looking at their [Docs](https://pytorch-geometric.readthedocs.io/en/latest/) and [Tutorials](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html) for more info. \n",
    "\n",
    "We start by editing the configuration file first.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of content\n",
    "1. [edit the config file](#Edit-the-yaml-Configuration-File)\n",
    "2. [compute the pe from mol](#Compute-the-Positional-Encoding)\n",
    "3. [add existing encoder](#Add-Existing-Encoder)\n",
    "4. [add specialized encoder](#Add-Specialized-Encoder)\n",
    "5. [add the keys to spaces](#Add-the-Keys-to-Spaces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Edit the yaml Configuration File\n",
    "\n",
    "### Computing Raw PE\n",
    "We will use the degree of each node as a positional encoding in this tutorial. \n",
    "First start with an existing yaml configuration file, you can find them in `expts/configs`\n",
    "\n",
    "We first look at where in the yaml file is the raw positional encodings computed. `deg_pos` is added as an example below. You can add relevant arguments for computing the positional encoding here as well such as `normalize` in the example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "pos_encoding_as_features:\n",
    "    pos_types:\n",
    "      deg_pos: #example, degree centrality\n",
    "        pos_type: degree\n",
    "        normalize: False\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Specifying Encoders for the PE\n",
    "Now we want to specify arguments for the encoders associated with the pe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "pe_encoders:\n",
    "    out_dim: 64\n",
    "    pool: \"sum\" #choice of pooling across multiple pe encoders\n",
    "    last_norm: None #\"batch_norm\", \"layer_norm\"\n",
    "    encoders: \n",
    "      deg_pos: #same name from the previous cell\n",
    "        encoder_type: \"mlp\" #or you can specify your own specialized encoder\n",
    "        input_keys: [\"degree\"] #same as the pos_type configured before\n",
    "        output_keys: [\"feat\"] #node feature\n",
    "        hidden_dim: 64\n",
    "        num_layers: 1\n",
    "        dropout: 0.1\n",
    "        normalization: \"none\"   #\"batch_norm\" or \"layer_norm\"\n",
    "        first_normalization: \"layer_norm\"   #\"batch_norm\" or \"layer_norm\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute the Positional Encoding\n",
    "Next, we want to compute the raw degree of each node from the molecule graph.\n",
    "\n",
    "### add function to compute the pe\n",
    "Go to [graphium/features](https://graphium-docs.datamol.io/stable/api/graphium.features.html) and add a new file `deg.py` to add the function to compute the pe. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple, Union, Optional\n",
    "\n",
    "from scipy import sparse\n",
    "from scipy.sparse import spmatrix\n",
    "import numpy as np\n",
    "\n",
    "def compute_deg(adj: Union[np.ndarray, spmatrix], normalize: bool) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Compute the node degree positional encoding \n",
    "\n",
    "    Parameters:\n",
    "        adj: Adjacency matrix\n",
    "        normalize: indicate if the degree across all nodes are normalized to [0,1] or not\n",
    "    Returns:\n",
    "        2D array with shape (1, num_nodes) specifying the degree of each node\n",
    "    \"\"\"\n",
    "    \n",
    "    #first adj convert to scipy sparse matrix if not already\n",
    "    if type(adj) is np.ndarray:\n",
    "        adj = sparse.csr_matrix(adj)\n",
    "    \n",
    "    #https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.sum.html\n",
    "    degs = adj.sum(axis=0) #sum along axis 0, giving one degree per column/node\n",
    "    \n",
    "    if (normalize): #normalize the degree sequence to [0,1]\n",
    "        degs = degs / np.max(degs)\n",
    "    return degs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test with toy matrix\n",
    "\n",
    "here we will test if our code compute the degrees of each node correctly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[1., 1., 1., 1., 1.]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adj = np.identity(5) #make an identity matrix\n",
    "normalize = True\n",
    "\n",
    "degs = compute_deg(adj, normalize=normalize)\n",
    "\n",
    "degs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### add to positional_encoding.py\n",
    "\n",
    "To compute the new pe along with all existing pe, we need to add the function we wrote to [`graphium/features/positional_encoding.py`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding). Modify the [`graph_positional_encoder`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding.graph_positional_encoder) function by adding `pos_type == \"degree\"` logic "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Existing Encoder\n",
    "\n",
    "In order to pool over all the positional encodings, we need to add an encoder to process the raw computed positional encoding and ensure the output dimensions from all pe encoders are the same. When designing the encoder, you can either use an existing encoder or write your own specialized encoder\n",
    "\n",
    "here we can simply specify [`MLPEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder.MLPEncoder) in the yaml file and the library will automatically feed the raw positional encoding to a mlp encoder based on the input arguments. Note that in this example, the encoder takes in the pe stored at the input key `degree` and then outputs to the output key `feat`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "encoders: \n",
    "  deg_pos: \n",
    "    encoder_type: \"mlp\" \n",
    "    input_keys: [\"degree\"] \n",
    "    output_keys: [\"feat\"] # node feature\n",
    "    hidden_dim: 64\n",
    "    num_layers: 1\n",
    "    dropout: 0.1\n",
    "    normalization: \"none\"   #\"batch_norm\" or \"layer_norm\"\n",
    "    first_normalization: \"layer_norm\"   #\"batch_norm\" or \"layer_norm\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Specialized Encoder\n",
    "\n",
    "You can also add specialized encoder, such as `laplacian_pe` for the laplacian eigenvectors and eigenvalues. Here, we can add a new `deg_pos_encoder.py` in [`graphium/nn/encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html). As an example and template, please see the [`MLPEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder.MLPEncoder)\n",
    "\n",
    "Note that all new encoders must inherit from [`BaseEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder) class and implement the following abstract methods\n",
    "\n",
    "- [`forward`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder.forward): the forward function of the encoder, how to process the input\n",
    "\n",
    "- [`parse_input_keys`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder.parse_input_keys): how to parse the input keys\n",
    "\n",
    "- [`parse_output_keys`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.laplace_pos_encoder.LapPENodeEncoder.parse_output_keys): how to parse the output keys\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add the Keys to Spaces\n",
    "\n",
    "In order to directly find the correct encoders from the yaml file, we need to specify which key corresponds to which class. \n",
    "\n",
    "- add our new `deg_pos_encoder` to `graphium/utils/spaces.py` in the `PE_ENCODERS_DICT`\n",
    "- add our new `deg_pos_encoder` to [`graphium/nn/architectures/encoder_manager.py`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.encoder_manager.EncoderManager) in the `PE_ENCODERS_DICT`\n",
    "- add the import of our encoder to  `graphium/nn/encoders/__init__.py`\n",
    "\n",
    "Now we can modify the yaml file to use our new encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "encoders: \n",
    "  deg_pos: \n",
    "    encoder_type: \"deg_pos_encoder\" \n",
    "    input_keys: [\"degree\"] \n",
    "    output_keys: [\"feat\"] # node feature\n",
    "    hidden_dim: 64\n",
    "    #any other keys that might be used for initialization\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: docs/tutorials/feature_processing/choosing_parallelization.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0693c5c9",
   "metadata": {},
   "source": [
    "# Choosing the right parallelization\n",
    "\n",
    "This tutorial is meant to help you benchmark different parallel processing methods for the processing of molecules into graphs. This will allow you to choose the one most suitable for your machine, since the benchmarks vary per machine.\n",
    "\n",
    "In general, we find that using `joblib` with the `loky` parallel processing and a batch size of `1000` is most beneficial. The logic is abstracted into `datamol.parallelized_with_batches`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b5df2ac6-2ded-4597-a445-f2b5fb106330",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import joblib\n",
    "\n",
    "import numpy as np\n",
    "import datamol as dm\n",
    "import pandas as pd\n",
    "\n",
    "# from pandarallel import pandarallel\n",
    "\n",
    "# pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81fe0f5-055f-436e-9ef4-e58a89fd50ec",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0f31e18d-bdd9-4d9b-8ba5-81e5887b857e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# download from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv\n",
    "# data = pd.read_csv(\"/home/hadim/250k_rndm_zinc_drugs_clean_3.csv\", usecols=[\"smiles\"])\n",
    "\n",
    "# download from https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\n",
    "data = pd.read_csv(\"https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\", usecols=[\"smiles\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1197c31-7dbc-4fd7-a69a-5215e1a96b8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rows_number_list = [250_000]\n",
    "batch_size_list = [10, 100, 1_000, 10_000]\n",
    "\n",
    "\n",
    "def smiles_to_unique_mol_id(smiles):\n",
    "    try:\n",
    "        mol = dm.to_mol(mol=smiles)\n",
    "        mol_id = dm.unique_id(mol)\n",
    "    except:\n",
    "        mol_id = \"\"\n",
    "    if mol_id is None:\n",
    "        mol_id = \"\"\n",
    "    return mol_id\n",
    "\n",
    "\n",
    "def smiles_to_unique_mol_id_batch(smiles_list):\n",
    "    mol_id_list = []\n",
    "    for smiles in smiles_list:\n",
    "        mol_id_list.append(smiles_to_unique_mol_id(smiles))\n",
    "    return mol_id_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d8eea8f-b983-46e7-bfb4-b8012b3ede1a",
   "metadata": {},
   "source": [
    "## Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2f8ce5c3-4232-4279-8ea3-7a74832303be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "benchmark = []"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691f8963-6cac-4d58-b8a0-8049f1185486",
   "metadata": {},
   "source": [
    "### No batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a246cdcf-b5ea-4c9e-9ccc-dd3c544587bb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc396220c7144c8d8b195fb87694bbfe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/133885 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for n in rows_number_list:\n",
    "    df = data.iloc[:n]\n",
    "\n",
    "    with dm.utils.perf.watch_duration(log=False) as d:\n",
    "        out = dm.parallelized(\n",
    "            smiles_to_unique_mol_id,\n",
    "            df[\"smiles\"].values,\n",
    "            progress=True,\n",
    "            n_jobs=-1,\n",
    "            scheduler=\"processes\",\n",
    "        )\n",
    "\n",
    "    datum = {\n",
    "        \"batch\": False,\n",
    "        \"batch_size\": None,\n",
    "        \"scheduler\": \"loky_processes\",\n",
    "        \"duration_minutes\": d.duration_minutes,\n",
    "        \"duration_seconds\": d.duration,\n",
    "        \"n_rows\": len(df),\n",
    "    }\n",
    "    benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41139854-892f-4481-8dd6-a3972bb6ece0",
   "metadata": {},
   "source": [
    "### Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "845a72d5-0b3f-4a21-8d7d-8312671e9924",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0187ede3dbd8429995511883319e246f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13388 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fdd5ac6375b444359f6e8cef7b755b9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1338 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c11c5894843a46848725d0c03fef9e02",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/133 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd34e1a806604192bbd6f95ba6435b44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for batch_size in batch_size_list:\n",
    "    for n in rows_number_list:\n",
    "        df = data.iloc[:n]\n",
    "\n",
    "        with dm.utils.perf.watch_duration(log=False) as d:\n",
    "            out = dm.parallelized_with_batches(\n",
    "                smiles_to_unique_mol_id_batch,\n",
    "                df[\"smiles\"].values,\n",
    "                batch_size=batch_size,\n",
    "                progress=True,\n",
    "                n_jobs=-1,\n",
    "                scheduler=\"processes\",\n",
    "            )\n",
    "        assert len(out) == len(df), f\"{len(out)} != {len(df)}\"\n",
    "\n",
    "        datum = {\n",
    "            \"batch\": True,\n",
    "            \"batch_size\": batch_size,\n",
    "            \"scheduler\": \"loky_processes\",\n",
    "            \"duration_minutes\": d.duration_minutes,\n",
    "            \"duration_seconds\": d.duration,\n",
    "            \"n_rows\": len(df),\n",
    "        }\n",
    "        benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec7529d1-2ca3-42ec-a811-3dc010a03350",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "928c2236948d4a109c79a6bc79bd4649",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=558), Label(value='0 / 558'))), HB…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for n in rows_number_list:\n",
    "    df = data.iloc[:n]\n",
    "\n",
    "    with dm.utils.perf.watch_duration(log=False) as d:\n",
    "        _ = df[\"smiles\"].parallel_apply(smiles_to_unique_mol_id)\n",
    "\n",
    "    datum = {\n",
    "        \"batch\": False,\n",
    "        \"batch_size\": None,\n",
    "        \"scheduler\": \"pandarallel\",\n",
    "        \"duration_minutes\": d.duration_minutes,\n",
    "        \"duration_seconds\": d.duration,\n",
    "        \"n_rows\": len(df),\n",
    "    }\n",
    "    benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ddf56ef-6349-4d43-9720-9e53699e9a8e",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "989ae5dd-9826-4adb-af0a-64d9e2c19e2c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>scheduler</th>\n",
       "      <th>duration_minutes</th>\n",
       "      <th>duration_seconds</th>\n",
       "      <th>n_rows</th>\n",
       "      <th>duration_seconds_per_mol</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>True</td>\n",
       "      <td>1000.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.015375</td>\n",
       "      <td>0.922496</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>True</td>\n",
       "      <td>100.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.037034</td>\n",
       "      <td>2.222021</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>True</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.055083</td>\n",
       "      <td>3.304987</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>NaN</td>\n",
       "      <td>pandarallel</td>\n",
       "      <td>0.121338</td>\n",
       "      <td>7.280271</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>True</td>\n",
       "      <td>10.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.165935</td>\n",
       "      <td>9.956113</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>NaN</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>3.639053</td>\n",
       "      <td>218.343192</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.001631</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   batch  batch_size       scheduler  duration_minutes  duration_seconds  \\\n",
       "3   True      1000.0  loky_processes          0.015375          0.922496   \n",
       "2   True       100.0  loky_processes          0.037034          2.222021   \n",
       "4   True     10000.0  loky_processes          0.055083          3.304987   \n",
       "5  False         NaN     pandarallel          0.121338          7.280271   \n",
       "1   True        10.0  loky_processes          0.165935          9.956113   \n",
       "0  False         NaN  loky_processes          3.639053        218.343192   \n",
       "\n",
       "   n_rows  duration_seconds_per_mol  \n",
       "3  133885                  0.000007  \n",
       "2  133885                  0.000017  \n",
       "4  133885                  0.000025  \n",
       "5  133885                  0.000054  \n",
       "1  133885                  0.000074  \n",
       "0  133885                  0.001631  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = pd.DataFrame(benchmark)\n",
    "b[\"duration_seconds_per_mol\"] = b[\"duration_seconds\"] / b[\"n_rows\"]\n",
    "\n",
    "b.sort_values(\"duration_seconds_per_mol\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c0b10c-bf99-43bf-979a-9d1f0c1ff1fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b3d77e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: docs/tutorials/feature_processing/csv_to_parquet.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import graphium\n",
    "from os.path import dirname, abspath\n",
    "\n",
    "MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))\n",
    "\n",
     "# TODO create function to read parquet, test from GCP storage (put it in path, should support gs and path, explore function from Pandas instead of parquet \"pq\")\n",
    "def _csv_to_parquet(csv_path, parquet_path):\n",
    "    df = pd.read_csv(csv_path)\n",
    "    df.to_parquet(parquet_path)\n",
    "\n",
    "_csv_to_parquet(MAIN_DIR + '/graphium/data/QM9/micro_qm9.csv', MAIN_DIR + '/graphium/data/QM9/micro_qm9.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
     "    # TODO create function to specify if you read parquet or csv\n",
    "    # TODO replace all location with call for _read_csv and make sure to read all files if path ends with \"*\"\n",
    "    # def read_table:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/nethome/andyh/graphium/docs/tutorials/feature_processing\r\n"
     ]
    }
   ],
   "source": [
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphium\n",
    "import os\n",
    "from os.path import dirname, abspath\n",
    "MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))\n",
    "os.chdir(MAIN_DIR) # No need for this file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "b813b16c721ca8d9e65e6e6945a38a6ae48015ae799c455bab639da90d3bb278"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: docs/tutorials/feature_processing/timing_parallel.ipynb
================================================
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0693c5c9",
   "metadata": {},
   "source": [
    "# Timing parallel processing\n",
    "\n",
     "This tutorial is meant to help you benchmark different parallel processing methods for the processing of molecules into graphs. This will allow you to choose the one most suitable for your machine, since the benchmarks vary per machine.\n",
    "\n",
    "In general, we find that using `joblib` with the `loky` parallel processing and a batch size of `1000` is most beneficial. The logic is abstracted into `datamol.parallelized_with_batches`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b5df2ac6-2ded-4597-a445-f2b5fb106330",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: Pandarallel will run on 240 workers.\n",
      "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import joblib\n",
    "\n",
    "import numpy as np\n",
    "import datamol as dm\n",
    "import pandas as pd\n",
    "\n",
    "from pandarallel import pandarallel\n",
    "\n",
    "pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f81fe0f5-055f-436e-9ef4-e58a89fd50ec",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0f31e18d-bdd9-4d9b-8ba5-81e5887b857e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# download from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv\n",
    "# data = pd.read_csv(\"/home/hadim/250k_rndm_zinc_drugs_clean_3.csv\", usecols=[\"smiles\"])\n",
    "\n",
    "# download from https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\n",
    "data = pd.read_csv(\"https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\", usecols=[\"smiles\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1197c31-7dbc-4fd7-a69a-5215e1a96b8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rows_number_list = [250_000]\n",
    "batch_size_list = [10, 100, 1_000, 10_000]\n",
    "\n",
    "\n",
    "def smiles_to_unique_mol_id(smiles):\n",
    "    try:\n",
    "        mol = dm.to_mol(mol=smiles)\n",
    "        mol_id = dm.unique_id(mol)\n",
    "    except:\n",
    "        mol_id = \"\"\n",
    "    if mol_id is None:\n",
    "        mol_id = \"\"\n",
    "    return mol_id\n",
    "\n",
    "\n",
    "def smiles_to_unique_mol_id_batch(smiles_list):\n",
    "    mol_id_list = []\n",
    "    for smiles in smiles_list:\n",
    "        mol_id_list.append(smiles_to_unique_mol_id(smiles))\n",
    "    return mol_id_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d8eea8f-b983-46e7-bfb4-b8012b3ede1a",
   "metadata": {},
   "source": [
    "## Benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2f8ce5c3-4232-4279-8ea3-7a74832303be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "benchmark = []"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691f8963-6cac-4d58-b8a0-8049f1185486",
   "metadata": {},
   "source": [
    "### No batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a246cdcf-b5ea-4c9e-9ccc-dd3c544587bb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2e528726a384b258d6ed3576bc0db1c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/133885 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for n in rows_number_list:\n",
    "    df = data.iloc[:n]\n",
    "\n",
    "    with dm.utils.perf.watch_duration(log=False) as d:\n",
    "        out = dm.parallelized(\n",
    "            smiles_to_unique_mol_id,\n",
    "            df[\"smiles\"].values,\n",
    "            progress=True,\n",
    "            n_jobs=-1,\n",
    "            scheduler=\"processes\",\n",
    "        )\n",
    "\n",
    "    datum = {\n",
    "        \"batch\": False,\n",
    "        \"batch_size\": None,\n",
    "        \"scheduler\": \"loky_processes\",\n",
    "        \"duration_minutes\": d.duration_minutes,\n",
    "        \"duration_seconds\": d.duration,\n",
    "        \"n_rows\": len(df),\n",
    "    }\n",
    "    benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41139854-892f-4481-8dd6-a3972bb6ece0",
   "metadata": {},
   "source": [
    "### Batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "845a72d5-0b3f-4a21-8d7d-8312671e9924",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "85e79375f3e94d9abc435bdc8d28d2df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13388 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f5df65b3d2474b02aff1058e044f7dcf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1338 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a5170b0055e4971bd0ea5e7b6cdcbd8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/133 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa091db65b454db4bc824439645290ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for batch_size in batch_size_list:\n",
    "    for n in rows_number_list:\n",
    "        df = data.iloc[:n]\n",
    "\n",
    "        with dm.utils.perf.watch_duration(log=False) as d:\n",
    "            out = dm.parallelized_with_batches(\n",
    "                smiles_to_unique_mol_id_batch,\n",
    "                df[\"smiles\"].values,\n",
    "                batch_size=batch_size,\n",
    "                progress=True,\n",
    "                n_jobs=-1,\n",
    "                scheduler=\"processes\",\n",
    "            )\n",
    "        assert len(out) == len(df), f\"{len(out)} != {len(df)}\"\n",
    "\n",
    "        datum = {\n",
    "            \"batch\": True,\n",
    "            \"batch_size\": batch_size,\n",
    "            \"scheduler\": \"loky_processes\",\n",
    "            \"duration_minutes\": d.duration_minutes,\n",
    "            \"duration_seconds\": d.duration,\n",
    "            \"n_rows\": len(df),\n",
    "        }\n",
    "        benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec7529d1-2ca3-42ec-a811-3dc010a03350",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9de4fa3630e4c9aa4d81cac4f10b4c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=558), Label(value='0 / 558'))), HB…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for n in rows_number_list:\n",
    "    df = data.iloc[:n]\n",
    "\n",
    "    with dm.utils.perf.watch_duration(log=False) as d:\n",
    "        _ = df[\"smiles\"].parallel_apply(smiles_to_unique_mol_id)\n",
    "\n",
    "    datum = {\n",
    "        \"batch\": False,\n",
    "        \"batch_size\": None,\n",
    "        \"scheduler\": \"pandarallel\",\n",
    "        \"duration_minutes\": d.duration_minutes,\n",
    "        \"duration_seconds\": d.duration,\n",
    "        \"n_rows\": len(df),\n",
    "    }\n",
    "    benchmark.append(datum)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ddf56ef-6349-4d43-9720-9e53699e9a8e",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "989ae5dd-9826-4adb-af0a-64d9e2c19e2c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch</th>\n",
       "      <th>batch_size</th>\n",
       "      <th>scheduler</th>\n",
       "      <th>duration_minutes</th>\n",
       "      <th>duration_seconds</th>\n",
       "      <th>n_rows</th>\n",
       "      <th>duration_seconds_per_mol</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>True</td>\n",
       "      <td>1000.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.014199</td>\n",
       "      <td>0.851930</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000006</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>True</td>\n",
       "      <td>100.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.037132</td>\n",
       "      <td>2.227947</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>True</td>\n",
       "      <td>10000.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.047438</td>\n",
       "      <td>2.846266</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>NaN</td>\n",
       "      <td>pandarallel</td>\n",
       "      <td>0.118230</td>\n",
       "      <td>7.093791</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>True</td>\n",
       "      <td>10.0</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>0.222177</td>\n",
       "      <td>13.330603</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>False</td>\n",
       "      <td>NaN</td>\n",
       "      <td>loky_processes</td>\n",
       "      <td>4.002346</td>\n",
       "      <td>240.140754</td>\n",
       "      <td>133885</td>\n",
       "      <td>0.001794</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   batch  batch_size       scheduler  duration_minutes  duration_seconds   \n",
       "3   True      1000.0  loky_processes          0.014199          0.851930  \\\n",
       "2   True       100.0  loky_processes          0.037132          2.227947   \n",
       "4   True     10000.0  loky_processes          0.047438          2.846266   \n",
       "5  False         NaN     pandarallel          0.118230          7.093791   \n",
       "1   True        10.0  loky_processes          0.222177         13.330603   \n",
       "0  False         NaN  loky_processes          4.002346        240.140754   \n",
       "\n",
       "   n_rows  duration_seconds_per_mol  \n",
       "3  133885                  0.000006  \n",
       "2  133885                  0.000017  \n",
       "4  133885                  0.000021  \n",
       "5  133885                  0.000053  \n",
       "1  133885                  0.000100  \n",
       "0  133885                  0.001794  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = pd.DataFrame(benchmark)\n",
    "b[\"duration_seconds_per_mol\"] = b[\"duration_seconds\"] / b[\"n_rows\"]\n",
    "\n",
    "b.sort_values(\"duration_seconds_per_mol\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c0b10c-bf99-43bf-979a-9d1f0c1ff1fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b3d77e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphium_ipu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}


================================================
FILE: docs/tutorials/gnn/add_new_gnn_layers.ipynb
================================================
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating GNN layers\n",
    "\n",
     "One of the primary advantages of this library is the fact that GNN layers are independent from model architecture, thus allowing more flexibility with the code by easily swapping different layer types as hyper-parameters. However, this requires that the layers be implemented using the PyG library, and must inherit from the class [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure), which standardizes the inputs, outputs and properties of the layers. Thus, the architecture can be handled independently using the class [`FeedForwardPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.pyg_architectures.FeedForwardPyg) or other custom classes.\n",
    "\n",
    "We will first start by a simple layer that does not use edges, to a more complex layer that uses edges.\n",
    "\n",
    "Since these examples are built on top of PyG, we recommend looking at their [Docs](https://pytorch-geometric.readthedocs.io/en/latest/) and [Tutorials](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html) for more info. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of content\n",
    "1. [Define test graph](#Define-synthetic-test-graph)\n",
    "2. [Create simple layer](#Creating-a-simple-layer)\n",
    "3. [Test simple layer](#Test-our-simple-layer)\n",
    "4. [Create complex layer](#Creating-a-complex-layer-with-edges)\n",
    "5. [Test complex layer](#Test-our-complex-layer) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "from torch_geometric.data import Data, Batch\n",
    "from torch_geometric.nn.conv import MessagePassing\n",
    "from torch_scatter import scatter\n",
    "from torch_geometric.typing import (\n",
    "    OptPairTensor,\n",
    "    OptTensor\n",
    ")\n",
    "\n",
    "from copy import deepcopy\n",
    "from typing import Callable, Union, Optional, List\n",
    "from graphium.nn.base_graph_layer import BaseGraphStructure\n",
    "from graphium.nn.base_layers import FCLayer\n",
    "from graphium.utils.decorators import classproperty\n",
    "\n",
    "_ = torch.manual_seed(42)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define synthetic test graph\n",
    "\n",
    "We define below a small batched graph on which we can test the created layers. You can also use synthetic graphs to define unit tests "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "in_dim = 5          # Input node-feature dimensions\n",
    "out_dim = 11        # Desired output node-feature dimensions\n",
    "in_dim_edges = 13   # Input edge-feature dimensions\n",
    "\n",
    "\n",
    "# Let's create 2 simple pyg graphs. \n",
    "# start by specifying the edges with edge index\n",
    "edge_idx1 = torch.tensor([[0, 1, 2],\n",
    "                          [1, 2, 3]])\n",
    "edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
    "                          [0, 1, 2, 0]])\n",
    "\n",
    "# specify the node features, convention with variable x\n",
    "x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
    "x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
    "\n",
    "# specify the edge features in e\n",
    "e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "\n",
    "# make the pyg graph objects with our constructed features\n",
    "g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
    "g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
    "\n",
    "# put the two graphs into a Batch graph\n",
    "bg = Batch.from_data_list([g1, g2])\n",
    "\n",
    "# The batched graph will show as a single graph with 7 nodes\n",
    "print(bg)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a simple layer\n",
    "\n",
     "Here, we will show how to create a GNN layer that simply does a mean aggregation on the neighbouring features.\n",
    "\n",
    "First, for the layer to be fully compatible with the flexible architecture provided by [`FeedForwardPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.pyg_architectures.FeedForwardPyg), it needs to inherit from the class [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure). This base-layer has multiple virtual methods that must be implemented in any class that inherits from it.\n",
    "\n",
    "The virtual methods are below\n",
    "\n",
    "- [`layer_supports_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure): We want to return `False` since our layer doesn't support edges\n",
    "- [`layer_inputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_inputs_edges): We want to return `False` since our layer doesn't input edges\n",
    "- [`layer_outputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_outputs_edges): We want to return `False` since our layer doesn't output edges\n",
    "- [`out_dim_factor`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.out_dim_factor): We want to return `1` since the output dimension does not depend on internal parameters.\n",
    "\n",
    "The example is given below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inherit from message passing class from pyg \n",
    "# inherit from BaseGraphStructure\n",
    "# this example is also based of : https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/simple_conv.html#SimpleConv.forward\n",
    "\n",
    "class SimpleMeanLayer(MessagePassing, BaseGraphStructure):\n",
    "    def __init__(self, \n",
    "                 in_dim: int, \n",
    "                 out_dim: int, \n",
    "                 activation: Union[Callable, str] = \"relu\", \n",
    "                 dropout: float = 0.0, \n",
    "                 normalization: Union[str, Callable] = \"none\",\n",
    "                 aggr: str = \"mean\",\n",
    "                ):\n",
    "        r\"\"\"\n",
    "        add documentation in the format shown here for this layer to be automatically added to the graphium api reference\n",
    "        the type information will also be automatically shown in the doc\n",
    "        \n",
    "        Parameters:\n",
    "\n",
    "            in_dim:\n",
    "                Input feature dimensions of the layer\n",
    "\n",
    "            out_dim:\n",
    "                Output feature dimensions of the layer\n",
    "\n",
    "            activation:\n",
    "                activation function to use in the layer\n",
    "\n",
    "            dropout:\n",
    "                The ratio of units to dropout. Must be between 0 and 1\n",
    "\n",
    "            normalization:\n",
    "                Normalization to use. Choices:\n",
    "\n",
    "                - \"none\" or `None`: No normalization\n",
    "                - \"batch_norm\": Batch normalization\n",
    "                - \"layer_norm\": Layer normalization\n",
    "                - `Callable`: Any callable function\n",
    "            aggr:\n",
    "                what aggregation to use (\"add\", \"mean\" or \"max\")\n",
    "        \"\"\"\n",
    "        # Initialize the parent class\n",
    "        MessagePassing.__init__(self, node_dim=0, aggr=aggr)\n",
    "        BaseGraphStructure.__init__(self,\n",
    "                                    in_dim=in_dim, \n",
    "                                    out_dim=out_dim, \n",
    "                                    activation=activation,\n",
    "                                    dropout=dropout, \n",
    "                                    normalization=normalization)\n",
    "        \n",
    "        self.aggr = aggr\n",
    "        self._initialize_activation_dropout_norm()\n",
    "        # Create the mlp layer \n",
    "        # https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.MLP\n",
    "        self.transform = FCLayer(in_dim=in_dim, out_dim=out_dim)\n",
    "            \n",
    "    # define the forward function \n",
    "    def forward(self, \n",
    "                batch: Union[Data, Batch],\n",
    "               ) -> Union[Data, Batch]:\n",
    "        r\"\"\"\n",
    "        similarly add documentation to the functions in the class\n",
    "        \n",
    "        Parameters:\n",
    "            batch: pyg Batch graphs to pass through the layer\n",
    "        Returns:\n",
    "            batch: pyg Batch graphs\n",
    "        \"\"\"\n",
    "        \n",
    "        x = batch.feat\n",
    "        edge_index = batch.edge_index\n",
    "        \n",
    "        if isinstance(x, Tensor):\n",
    "            x: OptPairTensor = (x, x)\n",
    "\n",
    "        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n",
    "        out = self.propagate(edge_index, x=x)\n",
    "        out = self.transform(out)\n",
    "        out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)\n",
    "        batch.feat = out\n",
    "        return batch\n",
    "    \n",
    "    def message(self, x_j):\n",
    "\n",
    "        return x_j\n",
    "\n",
    "    # Finally, we define all the virtual properties according to how the class works with proper documentation of course\n",
    "    @classproperty\n",
    "    def layer_supports_edges(cls) -> bool:\n",
    "        r\"\"\"\n",
    "        Return a boolean specifying if the layer type supports edges or not.\n",
    "\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                False for the current class\n",
    "        \"\"\"\n",
    "        return False\n",
    "\n",
    "    @property\n",
    "    def layer_inputs_edges(self) -> bool:\n",
    "        r\"\"\"\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                Returns False\n",
    "        \"\"\"\n",
    "        return False\n",
    "            \n",
    "    @property\n",
    "    def layer_outputs_edges(self) -> bool:\n",
    "        r\"\"\"\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                Always ``False`` for the current class\n",
    "        \"\"\"\n",
    "        return False\n",
    "\n",
    "    @property\n",
    "    def out_dim_factor(self) -> int:\n",
    "        r\"\"\"\n",
    "        Get the factor by which the output dimension is multiplied for\n",
    "        the next layer.\n",
    "\n",
    "        For standard layers, this will return ``1``.\n",
    "        Returns:\n",
    "\n",
    "            int:\n",
    "                Always ``1`` for the current class\n",
    "        \"\"\"\n",
    "        return 1"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test our simple layer\n",
    "\n",
    "Now, we are ready to test the `SimpleMeanLayer` on our constructed graphs. Note that in this example, we **ignore** the edge features since they are not supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "\n",
    "layer = SimpleMeanLayer(\n",
    "            in_dim=in_dim, \n",
    "            out_dim=out_dim, \n",
    "            activation=\"relu\", \n",
    "            dropout=.3, \n",
    "            normalization=\"batch_norm\")\n",
    "\n",
    "graph = layer(graph)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a complex layer with edges\n",
    "\n",
    "Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features, concatenated to the edge features with their neighbours. In that case, only the node features will change, and the network will not update the edge features.\n",
    "\n",
    "The virtual methods will have different outputs\n",
    "\n",
    "- [`layer_supports_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_supports_edges): We want to return `True` since our layer does support edges\n",
    "- [`layer_inputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_inputs_edges): We want to return `True` since our layer does input edges\n",
    "- [`layer_outputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_outputs_edges): We want to return `False` since our layer will not output new edges\n",
    "- [`out_dim_factor`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.out_dim_factor): We want to return `1` since the output dimension does not depend on internal parameters.\n",
    "\n",
    "The example is given below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inherit from the MessagePassing class from pyg \n",
    "# inherit from BaseGraphStructure\n",
    "# adapting example from graphium/nn/pyg_layers/pna_pyg.py\n",
    "\n",
    "class ComplexMeanLayer(MessagePassing, BaseGraphStructure):\n",
    "    def __init__(self, \n",
    "                 in_dim: int, \n",
    "                 out_dim: int, \n",
    "                 in_dim_edges: int, \n",
    "                 activation: Union[Callable, str] = \"relu\", \n",
    "                 dropout: float = 0.0, \n",
    "                 normalization: Union[str, Callable] = \"none\",\n",
    "                 aggr: str = \"mean\",\n",
    "                ):\n",
    "        r\"\"\"\n",
    "        add documentation in the format shown here for this layer to be automatically added to the graphium api reference\n",
    "        the type information will also be automatically shown in the doc\n",
    "        \n",
    "        Parameters:\n",
    "\n",
    "            in_dim:\n",
    "                Input feature dimensions of the layer\n",
    "\n",
    "            out_dim:\n",
    "                Output feature dimensions of the layer\n",
    "            \n",
    "            in_dim_edges:\n",
    "                Input edge feature dimension of the layer\n",
    "\n",
    "            activation:\n",
    "                activation function to use in the layer\n",
    "\n",
    "            dropout:\n",
    "                The ratio of units to dropout. Must be between 0 and 1\n",
    "\n",
    "            normalization:\n",
    "                Normalization to use. Choices:\n",
    "\n",
    "                - \"none\" or `None`: No normalization\n",
    "                - \"batch_norm\": Batch normalization\n",
    "                - \"layer_norm\": Layer normalization\n",
    "                - `Callable`: Any callable function\n",
    "            aggr:\n",
    "                what aggregation to use (\"add\", \"mean\" or \"max\")\n",
    "        \"\"\"\n",
    "        # Initialize the parent class\n",
    "        MessagePassing.__init__(self, node_dim=0, aggr=aggr)\n",
    "        BaseGraphStructure.__init__(self,\n",
    "                                    in_dim=in_dim, \n",
    "                                    out_dim=out_dim, \n",
    "                                    activation=activation,\n",
    "                                    dropout=dropout, \n",
    "                                    normalization=normalization)\n",
    "        \n",
    "        self.aggr = aggr\n",
    "        self._initialize_activation_dropout_norm()\n",
    "        # Create the mlp layer \n",
    "        # https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.MLP\n",
    "        self.transform = FCLayer(in_dim=(in_dim + in_dim_edges), out_dim=out_dim)\n",
    "            \n",
    "    # define the forward function \n",
    "    def forward(self, \n",
    "                batch: Union[Data, Batch],\n",
    "               ) -> Union[Data, Batch]:\n",
    "        r\"\"\"\n",
    "        similarly add documentation to the functions in the class\n",
    "        \n",
    "        Parameters:\n",
    "            batch: pyg Batch graphs to pass through the layer\n",
    "        Returns:\n",
    "            batch: pyg Batch graphs\n",
    "        \"\"\"        \n",
    "        x = batch.feat\n",
    "        edge_index = batch.edge_index\n",
    "        edge_feat = batch.edge_feat\n",
    "        \n",
    "        if isinstance(x, Tensor):\n",
    "            x: OptPairTensor = (x, x)\n",
    "\n",
    "        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n",
    "        out = self.propagate(edge_index, x=x, edge_feat=edge_feat, size=None)\n",
    "        out = self.transform(out)\n",
    "        out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)\n",
    "        batch.feat = out\n",
    "        return batch\n",
    "    \n",
    "    def message(self, x_j: Tensor, edge_feat: OptTensor) -> Tensor:\n",
    "        r\"\"\"\n",
    "        message function\n",
    "\n",
    "        Parameters:\n",
    "            x_j: neighbour node features\n",
    "            edge_feat: edge features\n",
    "        Returns:\n",
    "            feat: the message\n",
    "        \"\"\"\n",
    "        feat = torch.cat([x_j, edge_feat], dim=-1)\n",
    "        return feat\n",
    "    \n",
    "    def aggregate(\n",
    "        self,\n",
    "        inputs: Tensor,\n",
    "        index: Tensor,\n",
    "        edge_index: Tensor,\n",
    "        dim_size: Optional[int] = None,\n",
    "    ) -> Tensor:\n",
    "        r\"\"\"\n",
    "        aggregate function\n",
    "        Parameters:\n",
    "            inputs: input features\n",
    "            index: index of the nodes\n",
    "            edge_index: edge index\n",
    "            dim_size: dimension size\n",
    "        Returns:\n",
    "            out: aggregated features\n",
    "        \"\"\"\n",
    "        out = scatter(inputs, index, 0, None, dim_size, reduce=self.aggr)\n",
    "        return out\n",
    "    \n",
    "\n",
    "    # Finally, we define all the virtual properties according to how the class works with proper documentation of course\n",
    "    @classproperty\n",
    "    def layer_supports_edges(cls) -> bool:\n",
    "        r\"\"\"\n",
    "        Return a boolean specifying if the layer type supports edges or not.\n",
    "\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                ``True`` for the current class\n",
    "        \"\"\"\n",
    "        return True\n",
    "\n",
    "    @property\n",
    "    def layer_inputs_edges(self) -> bool:\n",
    "        r\"\"\"\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                Returns True\n",
    "        \"\"\"\n",
    "        return True\n",
    "            \n",
    "    @property\n",
    "    def layer_outputs_edges(self) -> bool:\n",
    "        r\"\"\"\n",
    "        Returns:\n",
    "\n",
    "            bool:\n",
    "                Always ``False`` for the current class, since this layer does not update the edge features\n",
    "        \"\"\"\n",
    "        return False\n",
    "\n",
    "    @property\n",
    "    def out_dim_factor(self) -> int:\n",
    "        r\"\"\"\n",
    "        Get the factor by which the output dimension is multiplied for\n",
    "        the next layer.\n",
    "\n",
    "        For standard layers, this will return ``1``.\n",
    "        Returns:\n",
    "\n",
    "            int:\n",
    "                Always ``1`` for the current class\n",
    "        \"\"\"\n",
    "        return 1"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test our complex layer \n",
    "\n",
    "Now, we are ready to test the `ComplexMeanLayer` on our constructed graphs. Note that in this example, we utilize the edge features and node features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "\n",
    "layer = ComplexMeanLayer(\n",
    "            in_dim=in_dim, \n",
    "            out_dim=out_dim, \n",
    "            in_dim_edges=in_dim_edges,\n",
    "            activation=\"relu\", \n",
    "            dropout=.3, \n",
    "            normalization=\"batch_norm\")\n",
    "\n",
    "graph = layer(graph)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: docs/tutorials/gnn/making_gnn_networks.ipynb
================================================
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Making GNN Networks\n",
    "\n",
    "In this example, you will learn how to easily build a full multi-task GNN using any kind of GNN layer. This tutorial uses the architecture defined by the class `FullGraphMultiTaskNetwork`.\n",
    "\n",
    "`FullGraphMultiTaskNetwork` is an architecture that takes as input node features and (optionally) edge features. It applies a pre-MLP on both sets of features, then passes them into a main GNN network, and finally applies graph output NNs to produce the final outputs on possibly several task levels (graph, node, or edge-level).\n",
    "\n",
    "The network is very easy to build via a dictionary of parameters that allows customizing each part of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "from copy import deepcopy\n",
    "\n",
    "from torch_geometric.data import Data, Batch\n",
    "\n",
    "from graphium.nn.architectures import FullGraphMultiTaskNetwork\n",
    "\n",
    "_ = torch.manual_seed(42)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will first create some simple batched graphs that will be used across the examples. Here, `bg` is a batch containing 2 graphs with random node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "in_dim = 5          # Input node-feature dimensions\n",
    "in_dim_edges = 13   # Input edge-feature dimensions\n",
    "out_dim = 11        # Desired output node-feature dimensions\n",
    "\n",
    "\n",
    "# Let's create 2 simple pyg graphs. \n",
    "# start by specifying the edges with edge index\n",
    "edge_idx1 = torch.tensor([[0, 1, 2],\n",
    "                          [1, 2, 3]])\n",
    "edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
    "                          [0, 1, 2, 0]])\n",
    "\n",
    "# specify the node features, convention with variable x\n",
    "x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
    "x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
    "\n",
    "# specify the edge features in e\n",
    "e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "\n",
    "# make the pyg graph objects with our constructed features\n",
    "g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
    "g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
    "\n",
    "# put the two graphs into a Batch graph\n",
    "bg = Batch.from_data_list([g1, g2])\n",
    "\n",
    "# The batched graph will show as a single graph with 7 nodes\n",
    "print(bg)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a network\n",
    "\n",
    "To build the network, we must define the arguments to pass at the different steps:\n",
    "\n",
    "- `pre_nn_kwargs`: The parameters used by a feed-forward neural network on the input node-features, before passing to the convolutional layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`.\n",
    "\n",
    "- `gnn_kwargs`: The parameters used by a feed-forward **graph** neural network on the features after it has passed through the pre-processing NN. See class `FeedForwardGraph` for details on the required parameters.\n",
    "\n",
    "- `task_heads_kwargs`: The parameters used to specify the different task heads for possibly multiple tasks, each with a specified `task_level` (graph, node or edge).\n",
    "\n",
    "- `graph_output_nn_kwargs`: The parameters used by the graph output NNs to process the features after the GNN layers. We need to specify a NN for each `task_level` occurring in the `task_heads_kwargs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_dim_1 = 23\n",
    "temp_dim_2 = 17\n",
    "\n",
    "pre_nn_kwargs = {\n",
    "    \"in_dim\": in_dim,\n",
    "    \"out_dim\": temp_dim_1,\n",
    "    \"hidden_dims\": 4,\n",
    "    \"depth\": 2,\n",
    "    \"activation\": 'relu',\n",
    "    \"last_activation\": \"none\",\n",
    "    \"dropout\": 0.2\n",
    "}\n",
    "\n",
    "gnn_kwargs = {\n",
    "    \"in_dim\": temp_dim_1,\n",
    "    \"out_dim\": temp_dim_2,\n",
    "    \"hidden_dims\": 5,\n",
    "    \"depth\": 1,\n",
    "    \"activation\": 'gelu',\n",
    "    \"last_activation\": None,\n",
    "    \"dropout\": 0.1,\n",
    "    \"normalization\": 'layer_norm',\n",
    "    \"last_normalization\": 'layer_norm',\n",
    "    \"residual_type\": 'simple',\n",
    "    \"virtual_node\": None,\n",
    "    \"layer_type\": 'pyg:gcn',\n",
    "    \"layer_kwargs\": None\n",
    "}\n",
    "\n",
    "task_heads_kwargs = {\n",
    "    \"graph-task-1\": {\n",
    "        \"task_level\": 'graph',\n",
    "        \"out_dim\": 3,\n",
    "        \"hidden_dims\": 32,\n",
    "        \"depth\": 2,\n",
    "        \"activation\": 'relu',\n",
    "        \"last_activation\": None,\n",
    "        \"dropout\": 0.1,\n",
    "        \"normalization\": None,\n",
    "        \"last_normalization\": None,\n",
    "        \"residual_type\": \"none\"\n",
    "    },\n",
    "    \"graph-task-2\": {\n",
    "        \"task_level\": 'graph',\n",
    "        \"out_dim\": 4,\n",
    "        \"hidden_dims\": 32,\n",
    "        \"depth\": 2,\n",
    "        \"activation\": 'relu',\n",
    "        \"last_activation\": None,\n",
    "        \"dropout\": 0.1,\n",
    "        \"normalization\": None,\n",
    "        \"last_normalization\": None,\n",
    "        \"residual_type\": \"none\"\n",
    "    },\n",
    "    \"node-task-1\": {\n",
    "        \"task_level\": 'node',\n",
    "        \"out_dim\": 2,\n",
    "        \"hidden_dims\": 32,\n",
    "        \"depth\": 2,\n",
    "        \"activation\": 'relu',\n",
    "        \"last_activation\": None,\n",
    "        \"dropout\": 0.1,\n",
    "        \"normalization\": None,\n",
    "        \"last_normalization\": None,\n",
    "        \"residual_type\": \"none\"\n",
    "    }\n",
    "}\n",
    "\n",
    "graph_output_nn_kwargs = {\n",
    "    \"graph\": {\n",
    "        \"pooling\": ['sum'],\n",
    "        \"out_dim\": temp_dim_2,\n",
    "        \"hidden_dims\": temp_dim_2,\n",
    "        \"depth\": 1,\n",
    "        \"activation\": 'relu',\n",
    "        \"last_activation\": None,\n",
    "        \"dropout\": 0.1,\n",
    "        \"normalization\": None,\n",
    "        \"last_normalization\": None,\n",
    "        \"residual_type\": \"none\"\n",
    "    },\n",
    "    \"node\": {\n",
    "        \"pooling\": None,\n",
    "        \"out_dim\": temp_dim_2,\n",
    "        \"hidden_dims\": temp_dim_2,\n",
    "        \"depth\": 1,\n",
    "        \"activation\": 'relu',\n",
    "        \"last_activation\": None,\n",
    "        \"dropout\": 0.1,\n",
    "        \"normalization\": None,\n",
    "        \"last_normalization\": None,\n",
    "        \"residual_type\": \"none\"\n",
    "    }\n",
    "}\n",
    "    \n",
    "\n",
    "gnn_net = FullGraphMultiTaskNetwork(\n",
    "    gnn_kwargs=gnn_kwargs,\n",
    "    pre_nn_kwargs=pre_nn_kwargs, \n",
    "    task_heads_kwargs=task_heads_kwargs,\n",
    "    graph_output_nn_kwargs = graph_output_nn_kwargs\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Applying the network\n",
    "\n",
    "Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.\n",
    "\n",
    "The model outputs a dictionary of outputs, one for each task specified in `task_heads_kwargs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "\n",
      "\n",
      "FullGNN\n",
      "---------\n",
      "    pre-NN(depth=2, ResidualConnectionNone)\n",
      "        [FCLayer[5 -> 4 -> 23]\n",
      "    \n",
      "    GNN(depth=1, ResidualConnectionSimple(skip_steps=1))\n",
      "        GCNConvPyg[23 -> 17]\n",
      "        \n",
      "    \n",
      "        Task heads:\n",
      "        graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 3]\n",
      "        graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 4]\n",
      "        node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 2]\n",
      "\n",
      "\n",
      "graph-task-1 torch.Size([2, 3])\n",
      "graph-task-2 torch.Size([2, 4])\n",
      "node-task-1 torch.Size([7, 2])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "print(gnn_net)\n",
    "print(\"\\n\")\n",
    "\n",
    "out = gnn_net(graph)\n",
    "for task in out.keys():\n",
    "    print(task, out[task].shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a network (using additional edge features)\n",
    "\n",
    "We can further use a GNN that uses additional edge features as inputs. For that, we only need to slightly modify the `gnn_kwargs` from above. Below, we define the new `gnn_edge_kwargs`, where we can set the `layer_type` to one of the GNN layers that support edge features (e.g., `pyg:gine` or `pyg:gps`) and add the additional parameter `in_dim_edges`. Optionally, we can configure a preprocessing NN for the edge features via a dictionary `pre_nn_edges_kwargs` analogous to `pre_nn_kwargs` above, or skip this step by setting it to `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_dim_edges = 8\n",
    "\n",
    "gnn_edge_kwargs = {\n",
    "    \"in_dim\": temp_dim_1,\n",
    "    \"in_dim_edges\": temp_dim_edges,\n",
    "    \"out_dim\": temp_dim_2,\n",
    "    \"hidden_dims\": 5,\n",
    "    \"depth\": 1,\n",
    "    \"activation\": 'gelu',\n",
    "    \"last_activation\": None,\n",
    "    \"dropout\": 0.1,\n",
    "    \"normalization\": 'layer_norm',\n",
    "    \"last_normalization\": 'layer_norm',\n",
    "    \"residual_type\": 'simple',\n",
    "    \"virtual_node\": None,\n",
    "    \"layer_type\": 'pyg:gine',\n",
    "    \"layer_kwargs\": None\n",
    "}\n",
    "\n",
    "pre_nn_edges_kwargs = {\n",
    "    \"in_dim\": in_dim_edges,\n",
    "    \"out_dim\": temp_dim_edges,\n",
    "    \"hidden_dims\": 4,\n",
    "    \"depth\": 2,\n",
    "    \"activation\": 'relu',\n",
    "    \"last_activation\": \"none\",\n",
    "    \"dropout\": 0.2\n",
    "}\n",
    "\n",
    "\n",
    "gnn_net_edges = FullGraphMultiTaskNetwork(\n",
    "    gnn_kwargs=gnn_edge_kwargs,\n",
    "    pre_nn_kwargs=pre_nn_kwargs,\n",
    "    pre_nn_edges_kwargs=pre_nn_edges_kwargs,\n",
    "    task_heads_kwargs=task_heads_kwargs,\n",
    "    graph_output_nn_kwargs = graph_output_nn_kwargs\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Applying the network (using additional edge features)\n",
    "\n",
    "Again, we only need to run the forward pass on the input graphs to get a prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "\n",
      "\n",
      "FullGNN\n",
      "---------\n",
      "    pre-NN(depth=2, ResidualConnectionNone)\n",
      "        [FCLayer[5 -> 4 -> 23]\n",
      "    \n",
      "    GNN(depth=1, ResidualConnectionSimple(skip_steps=1))\n",
      "        GCNConvPyg[23 -> 17]\n",
      "        \n",
      "    \n",
      "        Task heads:\n",
      "        graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 3]\n",
      "        graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 4]\n",
      "        node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)\n",
      "            [FCLayer[17 -> 32 -> 2]\n",
      "\n",
      "\n",
      "graph-task-1 torch.Size([2, 3])\n",
      "graph-task-2 torch.Size([2, 4])\n",
      "node-task-1 torch.Size([7, 2])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "print(gnn_net_edges)\n",
    "print(\"\\n\")\n",
    "\n",
    "out = gnn_net_edges(graph)\n",
    "for task in out.keys():\n",
    "    print(task, out[task].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
  },
  "kernelspec": {
   "display_name": "Python 3.8.8 64-bit ('goli': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


================================================
FILE: docs/tutorials/gnn/using_gnn_layers.ipynb
================================================
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using GNN layers\n",
    "\n",
    "The current library implements multiple state-of-the-art graph neural networks. In this tutorial, you will learn how to use the **GCN**, **GIN**, **GINE**, **GPS**, **Gated-GCN** and **PNA** layers in a simple `forward` context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "\n",
    "import torch_geometric as pyg\n",
    "from torch_geometric.data import Data, Batch\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "from graphium.nn.pyg_layers import (\n",
    "    GCNConvPyg,\n",
    "    GINConvPyg,\n",
    "    GatedGCNPyg,\n",
    "    GINEConvPyg,\n",
    "    GPSLayerPyg,\n",
    "    PNAMessagePassingPyg\n",
    ")\n",
    "\n",
    "_ = torch.manual_seed(42)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will first create some simple batched graphs that will be used across the examples. Here, `bg` is a batch containing 2 graphs with random node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
     ]
    }
   ],
   "source": [
    "in_dim = 5          # Input node-feature dimensions\n",
    "in_dim_edges = 13   # Input edge-feature dimensions\n",
    "out_dim = 11        # Desired output node-feature dimensions\n",
    "out_dim_edges = 15  # Desired output edge-feature dimensions\n",
    "\n",
    "\n",
    "# Let's create 2 simple pyg graphs. \n",
    "# start by specifying the edges with edge index\n",
    "edge_idx1 = torch.tensor([[0, 1, 2],\n",
    "                          [1, 2, 3]])\n",
    "edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
    "                          [0, 1, 2, 0]])\n",
    "\n",
    "# specify the node features, convention with variable x\n",
    "x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
    "x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
    "\n",
    "# specify the edge features in e\n",
    "e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
    "\n",
    "# make the pyg graph objects with our constructed features\n",
    "g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
    "g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
    "\n",
    "# put the two graphs into a Batch graph\n",
    "bg = Batch.from_data_list([g1, g2])\n",
    "\n",
    "# The batched graph will show as a single graph with 7 nodes\n",
    "print(bg)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GCN Layer\n",
    "\n",
    "To use the GCN layer from the *Kipf et al.* paper, the steps are very simple. We create the layer with the desired attributes, and apply it to the graph.\n",
    "\n",
    "<sub>Kipf, Thomas N., and Max Welling. \"Semi-supervised classification with graph convolutional networks.\" arXiv preprint arXiv:1609.02907 (2016).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "GCNConvPyg(5 -> 11, activation=relu)\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "# The GCN method doesn't support edge features, so we ignore them\n",
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "\n",
    "# We create the layer\n",
    "layer = GCNConvPyg(\n",
    "            in_dim=in_dim, out_dim=out_dim, \n",
    "            activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "# We apply the forward loop on the node features\n",
    "graph = layer(graph)\n",
    "\n",
    "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
    "print(layer)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GIN Layer\n",
    "\n",
    "To use the GIN layer from the *Xu et al.* paper, the steps are identical to GCN.\n",
    "\n",
    "<sub>Xu, Keyulu, et al. \"How powerful are graph neural networks?.\" arXiv preprint arXiv:1810.00826 (2018).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "GINConvPyg(5 -> 11, activation=relu)\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "\n",
    "# We create the layer\n",
    "layer = GINConvPyg(\n",
    "            in_dim=in_dim, out_dim=out_dim, \n",
    "            activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "# We apply the forward loop on the node features\n",
    "graph = layer(graph)\n",
    "\n",
    "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
    "print(layer)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GINE Layer\n",
    "\n",
    "To use the GINE layer from the *Hu et al.* paper, we also need to provide additional edge features as inputs.\n",
    "\n",
    "<sub>Hu, Weihua, et al. \"Strategies for Pre-training Graph Neural Networks.\" arXiv preprint arXiv:1905.12265 (2019).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "GINEConvPyg(5 -> 11, activation=relu)\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "# The GINE method uses edge features, so we have to pass the input dimension\n",
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "\n",
    "# We create the layer\n",
    "layer = GINEConvPyg(\n",
    "            in_dim=in_dim, out_dim=out_dim,\n",
    "            in_dim_edges=in_dim_edges,\n",
    "            activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "# We apply the forward loop on the node features\n",
    "graph = layer(graph)\n",
    "\n",
    "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
    "# 7 is the number of edges, 13 number of input edge features\n",
    "print(layer)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPS Layer\n",
    "\n",
    "To use the GPS layer from the *Rampášek et al.* paper, we also need to provide additional edge features as inputs. It is a hybrid approach using both a GNN and transformer in conjunction. Therefore, we further need to specify the GNN type and attention type used in the layer.\n",
    "\n",
    "<sub>Rampášek, Ladislav, et al. \"Recipe for a General, Powerful, Scalable Graph Transformer.\" arXiv preprint arXiv:2205.12454 (2022).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "GPSLayerPyg(5 -> 11, activation=relu)\n",
      "torch.Size([7, 11])\n",
      "torch.Size([7, 13])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "\n",
    "# We create the layer\n",
    "layer = GPSLayerPyg(\n",
    "            in_dim=in_dim, out_dim=out_dim,\n",
    "            in_dim_edges=in_dim_edges,\n",
    "            mpnn_type = \"pyg:gine\", attn_type = \"full-attention\",\n",
    "            activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "# We apply the forward loop on the node features\n",
    "graph = layer(graph)\n",
    "\n",
    "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
    "# 7 is the number of edges, 13 number of input edge features\n",
    "print(layer)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gated-GCN Layer\n",
    "\n",
    "To use the Gated-GCN layer from the *Bresson et al.* paper, the steps are different since the layer not only requires edge features as inputs, but also outputs new edge features. Therefore, we have to further specify the number of output edge features\n",
    "\n",
    "<sub>Bresson, Xavier, and Thomas Laurent. \"Residual gated graph convnets.\" arXiv preprint arXiv:1711.07553 (2017).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "GatedGCNPyg()\n",
      "torch.Size([7, 11])\n",
      "torch.Size([7, 15])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "\n",
    "# We create the layer\n",
    "layer = GatedGCNPyg(\n",
    "            in_dim=in_dim, out_dim=out_dim,\n",
    "            in_dim_edges=in_dim_edges, out_dim_edges=out_dim_edges,\n",
    "            activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "# We apply the forward loop on the node features\n",
    "graph = layer(graph)\n",
    "\n",
    "# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
    "# 7 is the number of edges, 13 number of input edge features and 15 the number of output edge features\n",
    "print(layer)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PNA\n",
    "\n",
    "PNA is a multi-aggregator method proposed by *Corso et al.*. It supports 2 types of aggregations, convolutional *PNA-conv* or message passing *PNA-msgpass*. Here, we provide the typically more powerful *PNA-msgpass*. It supports edges as inputs, but doesn't output edges. Here, we need to further specify the aggregators and scalers specific to this layer.\n",
    "\n",
    "<sub>Corso, Gabriele, et al. \"Principal Neighbourhood Aggregation for Graph Nets.\"\n",
    "arXiv preprint arXiv:2004.05718 (2020).</sub>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7, 5])\n",
      "torch.Size([7, 13])\n",
      "PNAMessagePassingPyg()\n",
      "torch.Size([7, 11])\n"
     ]
    }
   ],
   "source": [
    "graph = deepcopy(bg)\n",
    "print(graph.feat.shape)\n",
    "print(graph.edge_feat.shape)\n",
    "\n",
    "# We create the layer, and need to specify the aggregators and scalers\n",
    "layer = PNAMessagePassingPyg(\n",
    "    in_dim=in_dim, out_dim=out_dim,\n",
    "    in_dim_edges=in_dim_edges,\n",
    "    aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n",
    "    scalers=[\"identity\", \"amplification\", \"attenuation\"],\n",
    "    activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
    "\n",
    "graph = layer(graph)\n",
    "\n",
    "print(layer)\n",
    "print(graph.feat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
  },
  "kernelspec": {
   "display_name": "Python 3.8.8 64-bit ('goli': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: docs/tutorials/model_training/simple-molecular-model.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building and training a simple model from configurations\n",
    "\n",
    "This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.\n",
    "\n",
    "The workflow of testing your code on the entire pipeline is as follows:\n",
    "\n",
    "1. Select a subset of the [available configs](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs) as a starting point.\n",
    "2. Create additional configs or modify the existing configs to suit your needs.\n",
    "3. Train or fine-tune a model with the `graphium-train` CLI.\n",
    "\n",
    "## Creating the yaml file\n",
    "\n",
    "The first step is to create a YAML file containing all the required configurations, with an example given at `graphium/expts/hydra-configs/main.yaml`. We will go through each part of the configurations. See also the README [here](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "import omegaconf\n",
    "\n",
    "from hydra import compose, initialize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_config_with_key(config, key):\n",
    "    new_config = {key: config[key]}\n",
    "    print(omegaconf.OmegaConf.to_yaml(new_config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Yaml file loaded\n"
     ]
    }
   ],
   "source": [
    "# First, let's read the yaml configuration file\n",
    "with initialize(version_base=None, config_path=\"../../../expts/hydra-configs\"):\n",
    "    yaml_config = compose(config_name=\"main\")\n",
    "\n",
    "print(\"Yaml file loaded\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants\n",
    "\n",
    "First, we define the constants such as the random seed and whether the model should raise or ignore an error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "constants:\n",
      "  name: neurips2023_small_data_gcn\n",
      "  seed: 42\n",
      "  max_epochs: 100\n",
      "  data_dir: expts/data/neurips2023/small-dataset\n",
      "  raise_train_error: true\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print_config_with_key(yaml_config, \"constants\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Datamodule\n",
    "\n",
    "Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.\n",
    "\n",
    "For more details, see class [`MultitaskFromSmilesDataModule`](https://graphium-docs.datamol.io/stable/api/graphium.data.html#graphium.data.datamodule.MultitaskFromSmilesDataModule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "datamodule:\n",
      "  module_type: MultitaskFromSmilesDataModule\n",
      "  args:\n",
      "    prepare_dict_or_graph: pyg:graph\n",
      "    featurization_n_jobs: 4\n",
      "    featurization_progress: true\n",
      "    featurization_backend: loky\n",
      "    processed_graph_data_path: ../datacache/neurips2023-small/\n",
      "    num_workers: 4\n",
      "    persistent_workers: false\n",
      "    featurization:\n",
      "      atom_property_list_onehot:\n",
      "      - atomic-number\n",
      "      - group\n",
      "      - period\n",
      "      - total-valence\n",
      "      atom_property_list_float:\n",
      "      - degree\n",
      "      - formal-charge\n",
      "      - radical-electron\n",
      "      - aromatic\n",
      "      - in-ring\n",
      "      edge_property_list:\n",
      "      - bond-type-onehot\n",
      "      - stereo\n",
      "      - in-ring\n",
      "      add_self_loop: false\n",
      "      explicit_H: false\n",
      "      use_bonds_weights: false\n",
      "      pos_encoding_as_features:\n",
      "        pos_types:\n",
      "          lap_eigvec:\n",
      "            pos_level: node\n",
      "            pos_type: laplacian_eigvec\n",
      "            num_pos: 8\n",
      "            normalization: none\n",
      "            disconnected_comp: true\n",
      "          lap_eigval:\n",
      "            pos_level: node\n",
      "            pos_type: laplacian_eigval\n",
      "            num_pos: 8\n",
      "            normalization: none\n",
      "            disconnected_comp: true\n",
      "          rw_pos:\n",
      "            pos_level: node\n",
      "            pos_type: rw_return_probs\n",
      "            ksteps: 16\n",
      "    task_specific_args:\n",
      "      qm9:\n",
      "        df: null\n",
      "        df_path: ${constants.data_dir}/qm9.csv.gz\n",
      "        smiles_col: smiles\n",
      "        label_cols:\n",
      "        - A\n",
      "        - B\n",
      "        - C\n",
      "        - mu\n",
      "        - alpha\n",
      "        - homo\n",
      "        - lumo\n",
      "        - gap\n",
      "        - r2\n",
      "        - zpve\n",
      "        - u0\n",
      "        - u298\n",
      "        - h298\n",
      "        - g298\n",
      "        - cv\n",
      "        - u0_atom\n",
      "        - u298_atom\n",
      "        - h298_atom\n",
      "        - g298_atom\n",
      "        splits_path: ${constants.data_dir}/qm9_random_splits.pt\n",
      "        seed: ${constants.seed}\n",
      "        task_level: graph\n",
      "        label_normalization:\n",
      "          normalize_val_test: true\n",
      "          method: normal\n",
      "      tox21:\n",
      "        df: null\n",
      "        df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz\n",
      "        smiles_col: smiles\n",
      "        label_cols:\n",
      "        - NR-AR\n",
      "        - NR-AR-LBD\n",
      "        - NR-AhR\n",
      "        - NR-Aromatase\n",
      "        - NR-ER\n",
      "        - NR-ER-LBD\n",
      "        - NR-PPAR-gamma\n",
      "        - SR-ARE\n",
      "        - SR-ATAD5\n",
      "        - SR-HSE\n",
      "        - SR-MMP\n",
      "        - SR-p53\n",
      "        splits_path: ${constants.data_dir}/Tox21_random_splits.pt\n",
      "        seed: ${constants.seed}\n",
      "        task_level: graph\n",
      "      zinc:\n",
      "        df: null\n",
      "        df_path: ${constants.data_dir}/ZINC12k.csv.gz\n",
      "        smiles_col: smiles\n",
      "        label_cols:\n",
      "        - SA\n",
      "        - logp\n",
      "        - score\n",
      "        splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt\n",
      "        seed: ${constants.seed}\n",
      "        task_level: graph\n",
      "        label_normalization:\n",
      "          normalize_val_test: true\n",
      "          method: normal\n",
      "    batch_size_training: 200\n",
      "    batch_size_inference: 200\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print_config_with_key(yaml_config, \"datamodule\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Architecture\n",
    "\n",
    "The architecture is based on [`FullGraphMultiTaskNetwork`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.global_architectures.FullGraphMultiTaskNetwork).\n",
    "Here, we define all the layers for the model, including the layers for the pre-processing MLP (input layers `pre-nn` and `pre_nn_edges`), the positional encoder (`pe_encoders`), the post-processing MLP (output layers `post-nn`), and the main GNN (graph neural network `gnn`).\n",
    "\n",
    "You can find details in the following: \n",
    "- info about the positional encoder in [`graphium.nn.encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html)\n",
    "- info about the gnn layers in [`graphium.nn.pyg_layers`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html)\n",
    "- info about the architecture [`FullGraphMultiTaskNetwork`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.global_architectures.FullGraphMultiTaskNetwork)\n",
    "- Main class for the GNN layers in [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure)\n",
    "\n",
    "The parameters allow choosing the feature size, the depth, the skip connections, the pooling and the virtual node. It also supports different GNN layers such as [`GatedGCNPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gated_gcn_pyg), [`GINConvPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gin_pyg), [`GINEConvPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gin_pyg.GINEConvPyg), [`GPSLayerPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gps_pyg.GPSLayerPyg), [`MPNNPlusPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.mpnn_pyg.MPNNPlusPyg).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "architecture:\n",
      "  model_type: FullGraphMultiTaskNetwork\n",
      "  mup_base_path: null\n",
      "  pre_nn:\n",
      "    out_dim: 64\n",
      "    hidden_dims: 256\n",
      "    depth: 2\n",
      "    activation: relu\n",
      "    last_activation: none\n",
      "    dropout: 0.18\n",
      "    normalization: layer_norm\n",
      "    last_normalization: ${architecture.pre_nn.normalization}\n",
      "    residual_type: none\n",
      "  pre_nn_edges: null\n",
      "  pe_encoders:\n",
      "    out_dim: 32\n",
      "    pool: sum\n",
      "    last_norm: None\n",
      "    encoders:\n",
      "      la_pos:\n",
      "        encoder_type: laplacian_pe\n",
      "        input_keys:\n",
      "        - laplacian_eigvec\n",
      "        - laplacian_eigval\n",
      "        output_keys:\n",
      "        - feat\n",
      "        hidden_dim: 64\n",
      "        out_dim: 32\n",
      "        model_type: DeepSet\n",
      "        num_layers: 2\n",
      "        num_layers_post: 1\n",
      "        dropout: 0.1\n",
      "        first_normalization: none\n",
      "      rw_pos:\n",
      "        encoder_type: mlp\n",
      "        input_keys:\n",
      "        - rw_return_probs\n",
      "        output_keys:\n",
      "        - feat\n",
      "        hidden_dim: 64\n",
      "        out_dim: 32\n",
      "        num_layers: 2\n",
      "        dropout: 0.1\n",
      "        normalization: layer_norm\n",
      "        first_normalization: layer_norm\n",
      "  gnn:\n",
      "    in_dim: 64\n",
      "    out_dim: 96\n",
      "    hidden_dims: 96\n",
      "    depth: 4\n",
      "    activation: gelu\n",
      "    last_activation: none\n",
      "    dropout: 0.1\n",
      "    normalization: layer_norm\n",
      "    last_normalization: ${architecture.pre_nn.normalization}\n",
      "    residual_type: simple\n",
      "    virtual_node: none\n",
      "    layer_type: pyg:gcn\n",
      "    layer_kwargs: null\n",
      "  graph_output_nn:\n",
      "    graph:\n",
      "      pooling:\n",
      "      - sum\n",
      "      out_dim: 96\n",
      "      hidden_dims: 96\n",
      "      depth: 1\n",
      "      activation: relu\n",
      "      last_activation: none\n",
      "      dropout: ${architecture.pre_nn.dropout}\n",
      "      normalization: ${architecture.pre_nn.normalization}\n",
      "      last_normalization: none\n",
      "      residual_type: none\n",
      "  task_heads:\n",
      "    qm9:\n",
      "      task_level: graph\n",
      "      out_dim: 19\n",
      "      hidden_dims: 128\n",
      "      depth: 2\n",
      "      activation: relu\n",
      "      last_activation: none\n",
      "      dropout: ${architecture.pre_nn.dropout}\n",
      "      normalization: ${architecture.pre_nn.normalization}\n",
      "      last_normalization: none\n",
      "      residual_type: none\n",
      "    tox21:\n",
      "      task_level: graph\n",
      "      out_dim: 12\n",
      "      hidden_dims: 64\n",
      "      depth: 2\n",
      "      activation: relu\n",
      "      last_activation: none\n",
      "      dropout: ${architecture.pre_nn.dropout}\n",
      "      normalization: ${architecture.pre_nn.normalization}\n",
      "      last_normalization: none\n",
      "      residual_type: none\n",
      "    zinc:\n",
      "      task_level: graph\n",
      "      out_dim: 3\n",
      "      hidden_dims: 32\n",
      "      depth: 2\n",
      "      activation: relu\n",
      "      last_activation: none\n",
      "      dropout: ${architecture.pre_nn.dropout}\n",
      "      normalization: ${architecture.pre_nn.normalization}\n",
      "      last_normalization: none\n",
      "      residual_type: none\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print_config_with_key(yaml_config, \"architecture\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predictor\n",
    "\n",
    "In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predictor:\n",
      "  metrics_on_progress_bar:\n",
      
Download .txt
gitextract_455isr46/

├── .github/
│   ├── CODEOWNERS
│   ├── CODE_OF_CONDUCT.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows/
│       ├── code-check.yml
│       ├── doc.yml
│       ├── release.yml
│       ├── test.yml
│       └── test_ipu.yml
├── .gitignore
├── LICENSE
├── README.md
├── cleanup_files.sh
├── codecov.yml
├── docs/
│   ├── _assets/
│   │   ├── css/
│   │   │   ├── custom-graphium.css
│   │   │   └── custom.css
│   │   └── js/
│   │       └── google-analytics.js
│   ├── api/
│   │   ├── graphium.config.md
│   │   ├── graphium.data.md
│   │   ├── graphium.features.md
│   │   ├── graphium.finetuning.md
│   │   ├── graphium.ipu.md
│   │   ├── graphium.nn/
│   │   │   ├── architectures.md
│   │   │   ├── encoders.md
│   │   │   ├── graphium.nn.md
│   │   │   └── pyg_layers.md
│   │   ├── graphium.trainer.md
│   │   └── graphium.utils.md
│   ├── baseline.md
│   ├── cli/
│   │   ├── graphium-train.md
│   │   ├── graphium.md
│   │   └── reference.md
│   ├── contribute.md
│   ├── datasets.md
│   ├── design.md
│   ├── index.md
│   ├── license.md
│   ├── pretrained_models.md
│   └── tutorials/
│       ├── feature_processing/
│       │   ├── add_new_positional_encoding.ipynb
│       │   ├── choosing_parallelization.ipynb
│       │   ├── csv_to_parquet.ipynb
│       │   └── timing_parallel.ipynb
│       ├── gnn/
│       │   ├── add_new_gnn_layers.ipynb
│       │   ├── making_gnn_networks.ipynb
│       │   └── using_gnn_layers.ipynb
│       └── model_training/
│           └── simple-molecular-model.ipynb
├── enable_ipu.sh
├── env.yml
├── expts/
│   ├── __init__.py
│   ├── configs/
│   │   ├── config_gps_10M_pcqm4m.yaml
│   │   ├── config_gps_10M_pcqm4m_mod.yaml
│   │   ├── config_mpnn_10M_b3lyp.yaml
│   │   └── config_mpnn_pcqm4m.yaml
│   ├── data/
│   │   ├── micro_zinc_splits.csv
│   │   └── tiny_zinc_splits.csv
│   ├── dataset_benchmark.py
│   ├── debug_yaml.py
│   ├── hydra-configs/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── accelerator/
│   │   │   ├── cpu.yaml
│   │   │   ├── gpu.yaml
│   │   │   ├── ipu.yaml
│   │   │   └── ipu_pipeline.yaml
│   │   ├── architecture/
│   │   │   ├── largemix.yaml
│   │   │   ├── pcqm4m.yaml
│   │   │   └── toymix.yaml
│   │   ├── experiment/
│   │   │   └── toymix_mpnn.yaml
│   │   ├── finetuning/
│   │   │   ├── admet.yaml
│   │   │   └── admet_baseline.yaml
│   │   ├── hparam_search/
│   │   │   └── optuna.yaml
│   │   ├── main.yaml
│   │   ├── model/
│   │   │   ├── gated_gcn.yaml
│   │   │   ├── gcn.yaml
│   │   │   ├── gin.yaml
│   │   │   ├── gine.yaml
│   │   │   ├── gpspp.yaml
│   │   │   └── mpnn.yaml
│   │   ├── tasks/
│   │   │   ├── admet.yaml
│   │   │   ├── l1000_mcf7.yaml
│   │   │   ├── l1000_vcap.yaml
│   │   │   ├── largemix.yaml
│   │   │   ├── loss_metrics_datamodule/
│   │   │   │   ├── admet.yaml
│   │   │   │   ├── l1000_mcf7.yaml
│   │   │   │   ├── l1000_vcap.yaml
│   │   │   │   ├── largemix.yaml
│   │   │   │   ├── pcba_1328.yaml
│   │   │   │   ├── pcqm4m.yaml
│   │   │   │   ├── pcqm4m_g25.yaml
│   │   │   │   ├── pcqm4m_n4.yaml
│   │   │   │   └── toymix.yaml
│   │   │   ├── pcba_1328.yaml
│   │   │   ├── pcqm4m.yaml
│   │   │   ├── pcqm4m_g25.yaml
│   │   │   ├── pcqm4m_n4.yaml
│   │   │   ├── task_heads/
│   │   │   │   ├── admet.yaml
│   │   │   │   ├── l1000_mcf7.yaml
│   │   │   │   ├── l1000_vcap.yaml
│   │   │   │   ├── largemix.yaml
│   │   │   │   ├── pcba_1328.yaml
│   │   │   │   ├── pcqm4m.yaml
│   │   │   │   ├── pcqm4m_g25.yaml
│   │   │   │   ├── pcqm4m_n4.yaml
│   │   │   │   └── toymix.yaml
│   │   │   └── toymix.yaml
│   │   └── training/
│   │       ├── accelerator/
│   │       │   ├── largemix_cpu.yaml
│   │       │   ├── largemix_gpu.yaml
│   │       │   ├── largemix_ipu.yaml
│   │       │   ├── pcqm4m_ipu.yaml
│   │       │   ├── toymix_cpu.yaml
│   │       │   ├── toymix_gpu.yaml
│   │       │   └── toymix_ipu.yaml
│   │       ├── largemix.yaml
│   │       ├── model/
│   │       │   ├── largemix_gated_gcn.yaml
│   │       │   ├── largemix_gcn.yaml
│   │       │   ├── largemix_gin.yaml
│   │       │   ├── largemix_gine.yaml
│   │       │   ├── largemix_mpnn.yaml
│   │       │   ├── pcqm4m_gpspp.yaml
│   │       │   ├── pcqm4m_mpnn.yaml
│   │       │   ├── toymix_gcn.yaml
│   │       │   └── toymix_gin.yaml
│   │       ├── pcqm4m.yaml
│   │       └── toymix.yaml
│   ├── main_run_get_fingerprints.py
│   ├── main_run_multitask.py
│   ├── main_run_predict.py
│   ├── main_run_test.py
│   ├── neurips2023_configs/
│   │   ├── base_config/
│   │   │   ├── large.yaml
│   │   │   ├── large_pcba.yaml
│   │   │   ├── large_pcqm_g25.yaml
│   │   │   ├── large_pcqm_n4.yaml
│   │   │   └── small.yaml
│   │   ├── baseline/
│   │   │   ├── config_small_gcn_baseline.yaml
│   │   │   ├── config_small_gin_baseline.yaml
│   │   │   └── config_small_gine_baseline.yaml
│   │   ├── config_classifigression_l1000.yaml
│   │   ├── config_large_gcn.yaml
│   │   ├── config_large_gcn_g25.yaml
│   │   ├── config_large_gcn_gpu.yaml
│   │   ├── config_large_gcn_n4.yaml
│   │   ├── config_large_gcn_pcba.yaml
│   │   ├── config_large_gin.yaml
│   │   ├── config_large_gin_g25.yaml
│   │   ├── config_large_gin_n4.yaml
│   │   ├── config_large_gin_pcba.yaml
│   │   ├── config_large_gine.yaml
│   │   ├── config_large_gine_g25.yaml
│   │   ├── config_large_gine_n4.yaml
│   │   ├── config_large_gine_pcba.yaml
│   │   ├── config_large_mpnn.yaml
│   │   ├── config_luis_jama.yaml
│   │   ├── config_small_gated_gcn.yaml
│   │   ├── config_small_gcn.yaml
│   │   ├── config_small_gcn_gpu.yaml
│   │   ├── config_small_gin.yaml
│   │   ├── config_small_gine.yaml
│   │   ├── config_small_mpnn.yaml
│   │   ├── single_task_gcn/
│   │   │   ├── config_large_gcn_mcf7.yaml
│   │   │   ├── config_large_gcn_pcba.yaml
│   │   │   └── config_large_gcn_vcap.yaml
│   │   ├── single_task_gin/
│   │   │   ├── config_large_gin_g25.yaml
│   │   │   ├── config_large_gin_mcf7.yaml
│   │   │   ├── config_large_gin_n4.yaml
│   │   │   ├── config_large_gin_pcba.yaml
│   │   │   ├── config_large_gin_pcq.yaml
│   │   │   └── config_large_gin_vcap.yaml
│   │   └── single_task_gine/
│   │       ├── config_large_gine_g25.yaml
│   │       ├── config_large_gine_mcf7.yaml
│   │       ├── config_large_gine_n4.yaml
│   │       ├── config_large_gine_pcba.yaml
│   │       ├── config_large_gine_pcq.yaml
│   │       └── config_large_gine_vcap.yaml
│   └── run_validation_test.py
├── graphium/
│   ├── __init__.py
│   ├── _version.py
│   ├── cli/
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── data.py
│   │   ├── finetune_utils.py
│   │   ├── fingerprints.py
│   │   ├── main.py
│   │   ├── parameters.py
│   │   └── train_finetune_test.py
│   ├── config/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _load.py
│   │   ├── _loader.py
│   │   ├── config_convert.py
│   │   ├── dummy_finetuning_from_gnn.yaml
│   │   ├── dummy_finetuning_from_task_head.yaml
│   │   ├── fake_and_missing_multilevel_multitask_pyg.yaml
│   │   ├── fake_multilevel_multitask_pyg.yaml
│   │   └── zinc_default_multitask_pyg.yaml
│   ├── data/
│   │   ├── L1000/
│   │   │   └── datasets_goli-L1000_parse_gctx_to_csv.ipynb
│   │   ├── QM9/
│   │   │   ├── micro_qm9.csv
│   │   │   ├── micro_qm9.parquet
│   │   │   └── norm_micro_qm9.csv
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── collate.py
│   │   ├── datamodule.py
│   │   ├── dataset.py
│   │   ├── make_data_splits/
│   │   │   ├── make_train_val_test_splits_large.ipynb
│   │   │   └── make_train_val_test_splits_small.ipynb
│   │   ├── micro_ZINC/
│   │   │   ├── __init__.py
│   │   │   └── micro_ZINC.csv
│   │   ├── multilevel_utils.py
│   │   ├── multitask/
│   │   │   ├── __init__.py
│   │   │   ├── tiny_ZINC_SA.csv
│   │   │   ├── tiny_ZINC_logp.csv
│   │   │   └── tiny_ZINC_score.csv
│   │   ├── normalization.py
│   │   ├── sampler.py
│   │   ├── sdf2csv.py
│   │   ├── single_atom_dataset/
│   │   │   └── single_atom_dataset.csv
│   │   ├── smiles_transform.py
│   │   └── utils.py
│   ├── expts/
│   │   └── pyg_batching_sparse.ipynb
│   ├── features/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── commute.py
│   │   ├── electrostatic.py
│   │   ├── featurizer.py
│   │   ├── graphormer.py
│   │   ├── nmp.py
│   │   ├── periodic_table.csv
│   │   ├── positional_encoding.py
│   │   ├── properties.py
│   │   ├── rw.py
│   │   ├── spectral.py
│   │   └── transfer_pos_level.py
│   ├── finetuning/
│   │   ├── __init__.py
│   │   ├── finetuning.py
│   │   ├── finetuning_architecture.py
│   │   ├── fingerprinting.py
│   │   └── utils.py
│   ├── hyper_param_search/
│   │   ├── __init__.py
│   │   └── results.py
│   ├── ipu/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── ipu_dataloader.py
│   │   ├── ipu_losses.py
│   │   ├── ipu_metrics.py
│   │   ├── ipu_simple_lightning.py
│   │   ├── ipu_utils.py
│   │   ├── ipu_wrapper.py
│   │   └── to_dense_batch.py
│   ├── nn/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── architectures/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── encoder_manager.py
│   │   │   ├── global_architectures.py
│   │   │   └── pyg_architectures.py
│   │   ├── base_graph_layer.py
│   │   ├── base_layers.py
│   │   ├── encoders/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── base_encoder.py
│   │   │   ├── bessel_pos_encoder.py
│   │   │   ├── gaussian_kernel_pos_encoder.py
│   │   │   ├── laplace_pos_encoder.py
│   │   │   ├── mlp_encoder.py
│   │   │   └── signnet_pos_encoder.py
│   │   ├── ensemble_layers.py
│   │   ├── pyg_layers/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── dimenet_pyg.py
│   │   │   ├── gated_gcn_pyg.py
│   │   │   ├── gcn_pyg.py
│   │   │   ├── gin_pyg.py
│   │   │   ├── gps_pyg.py
│   │   │   ├── mpnn_pyg.py
│   │   │   ├── pna_pyg.py
│   │   │   ├── pooling_pyg.py
│   │   │   └── utils.py
│   │   ├── residual_connections.py
│   │   └── utils.py
│   ├── trainer/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── predictor.py
│   │   ├── predictor_options.py
│   │   └── predictor_summaries.py
│   └── utils/
│       ├── README.md
│       ├── __init__.py
│       ├── arg_checker.py
│       ├── command_line_utils.py
│       ├── custom_lr.py
│       ├── decorators.py
│       ├── fs.py
│       ├── hashing.py
│       ├── moving_average_tracker.py
│       ├── mup.py
│       ├── packing.py
│       ├── safe_run.py
│       ├── spaces.py
│       └── tensor.py
├── install_ipu.sh
├── mkdocs.yml
├── notebooks/
│   ├── compare-pretraining-finetuning-performance.ipynb
│   ├── dev-datamodule-invalidate-cache.ipynb
│   ├── dev-datamodule-ogb.ipynb
│   ├── dev-datamodule.ipynb
│   ├── dev-pretrained.ipynb
│   ├── dev-training-loop.ipynb
│   ├── dev.ipynb
│   ├── finetuning-on-tdc-admet-benchmark.ipynb
│   ├── running-fingerprints-from-pretrained-model.ipynb
│   └── running-model-from-config.ipynb
├── profiling/
│   ├── configs_profiling.yaml
│   ├── profile_mol_to_graph.py
│   ├── profile_one_of_k_encoding.py
│   └── profile_predictor.py
├── pyproject.toml
├── scripts/
│   ├── balance_params_and_train.sh
│   ├── convert_yml.py
│   ├── ipu_start.sh
│   ├── ipu_venv.sh
│   └── scale_mpnn.sh
└── tests/
    ├── .gitignore
    ├── __init__.py
    ├── config_test_ipu_dataloader.yaml
    ├── config_test_ipu_dataloader_multitask.yaml
    ├── conftest.py
    ├── converted_fake_multilevel_data.parquet
    ├── data/
    │   ├── config_micro_ZINC.yaml
    │   ├── micro_ZINC.csv
    │   ├── micro_ZINC_corrupt.csv
    │   ├── micro_ZINC_shard_1.csv
    │   ├── micro_ZINC_shard_1.parquet
    │   ├── micro_ZINC_shard_2.csv
    │   ├── micro_ZINC_shard_2.parquet
    │   └── pcqm4mv2-2k.csv
    ├── fake_and_missing_multilevel_data.parquet
    ├── test_architectures.py
    ├── test_attention.py
    ├── test_base_layers.py
    ├── test_collate.py
    ├── test_data_utils.py
    ├── test_datamodule.py
    ├── test_dataset.py
    ├── test_ensemble_layers.py
    ├── test_featurizer.py
    ├── test_finetuning.py
    ├── test_ipu_dataloader.py
    ├── test_ipu_losses.py
    ├── test_ipu_metrics.py
    ├── test_ipu_options.py
    ├── test_ipu_poptorch.py
    ├── test_ipu_to_dense_batch.py
    ├── test_loaders.py
    ├── test_losses.py
    ├── test_metrics.py
    ├── test_mtl_architecture.py
    ├── test_multitask_datamodule.py
    ├── test_mup.py
    ├── test_packing.py
    ├── test_pe_nodepair.py
    ├── test_pe_rw.py
    ├── test_pe_spectral.py
    ├── test_pos_transfer_funcs.py
    ├── test_positional_encoders.py
    ├── test_positional_encodings.py
    ├── test_predictor.py
    ├── test_pyg_layers.py
    ├── test_residual_connections.py
    ├── test_training.py
    └── test_utils.py
Download .txt
SYMBOL INDEX (1124 symbols across 124 files)

FILE: docs/_assets/js/google-analytics.js
  function gtag (line 8) | function gtag() {

FILE: expts/dataset_benchmark.py
  function benchmark (line 27) | def benchmark(fn, *args, message="", log2wandb=False, **kwargs):
  function benchmark_dataloader (line 37) | def benchmark_dataloader(dataloader, name, n_epochs=5, log2wandb=False):
  function main (line 81) | def main(

FILE: expts/debug_yaml.py
  function get_anchors_and_aliases (line 41) | def get_anchors_and_aliases(filepath):

FILE: expts/main_run_get_fingerprints.py
  function main (line 27) | def main() -> None:

FILE: expts/main_run_multitask.py
  function main (line 6) | def main(cfg: DictConfig) -> None:

FILE: expts/main_run_predict.py
  function main (line 31) | def main(cfg: DictConfig) -> None:

FILE: expts/main_run_test.py
  function main (line 25) | def main(cfg: DictConfig) -> None:

FILE: expts/run_validation_test.py
  function main (line 37) | def main(cfg: DictConfig) -> None:

FILE: graphium/cli/data.py
  function download (line 19) | def download(name: str, output: str, progress: bool = True):
  function list (line 34) | def list():
  function prepare_data (line 40) | def prepare_data(overrides: List[str]) -> None:

FILE: graphium/cli/finetune_utils.py
  function benchmark_tdc_admet_cli (line 28) | def benchmark_tdc_admet_cli(
  function get_fingerprints_from_model (line 80) | def get_fingerprints_from_model(
  function get_tdc_task_specific (line 139) | def get_tdc_task_specific(task: str, output: Literal["name", "mode", "la...

FILE: graphium/cli/fingerprints.py
  function get_fingerprints_from_model (line 5) | def get_fingerprints_from_model(): ...

FILE: graphium/cli/parameters.py
  function infer_parameter_count (line 26) | def infer_parameter_count(overrides: List[str] = []) -> int:

FILE: graphium/cli/train_finetune_test.py
  function cli (line 48) | def cli(cfg: DictConfig) -> None:
  function get_replication_factor (line 55) | def get_replication_factor(cfg):
  function get_gradient_accumulation_factor (line 72) | def get_gradient_accumulation_factor(cfg):
  function get_training_batch_size (line 90) | def get_training_batch_size(cfg):
  function get_training_device_iterations (line 108) | def get_training_device_iterations(cfg):
  function run_training_finetuning_testing (line 125) | def run_training_finetuning_testing(cfg: DictConfig) -> None:

FILE: graphium/config/_load.py
  function load_config (line 18) | def load_config(name: str):

FILE: graphium/config/_loader.py
  function get_accelerator (line 49) | def get_accelerator(
  function _get_ipu_opts (line 82) | def _get_ipu_opts(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -...
  function load_datamodule (line 98) | def load_datamodule(
  function load_metrics (line 167) | def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->...
  function load_architecture (line 192) | def load_architecture(
  function load_predictor (line 297) | def load_predictor(
  function load_mup (line 366) | def load_mup(mup_base_path: str, predictor: PredictorModule) -> Predicto...
  function load_trainer (line 388) | def load_trainer(
  function save_params_to_wandb (line 469) | def save_params_to_wandb(
  function load_accelerator (line 515) | def load_accelerator(config: Union[omegaconf.DictConfig, Dict[str, Any]]...
  function load_config_override (line 532) | def load_config_override(
  function load_yaml_config (line 546) | def load_yaml_config(
  function merge_dicts (line 579) | def merge_dicts(
  function get_checkpoint_path (line 626) | def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, An...

FILE: graphium/config/config_convert.py
  function recursive_config_reformating (line 17) | def recursive_config_reformating(configs):

FILE: graphium/data/collate.py
  function graphium_collate_fn (line 30) | def graphium_collate_fn(
  function collage_pyg_graph (line 127) | def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_siz...
  function pad_to_expected_label_size (line 184) | def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[in...
  function collate_pyg_graph_labels (line 205) | def collate_pyg_graph_labels(pyg_labels: List[Data]):
  function get_expected_label_size (line 228) | def get_expected_label_size(label_data: Data, task: str, label_size: Lis...
  function collate_labels (line 243) | def collate_labels(
  function pad_nodepairs (line 305) | def pad_nodepairs(pe: torch.Tensor, num_nodes: int, max_num_nodes_per_gr...

FILE: graphium/data/datamodule.py
  class BaseDataModule (line 104) | class BaseDataModule(lightning.LightningDataModule):
    method __init__ (line 105) | def __init__(
    method prepare_data (line 157) | def prepare_data(self):
    method setup (line 160) | def setup(self):
    method train_dataloader (line 163) | def train_dataloader(self, **kwargs):
    method val_dataloader (line 174) | def val_dataloader(self, **kwargs):
    method test_dataloader (line 185) | def test_dataloader(self, **kwargs):
    method predict_dataloader (line 196) | def predict_dataloader(self, **kwargs):
    method get_collate_fn (line 207) | def get_collate_fn(self, collate_fn):
    method is_prepared (line 218) | def is_prepared(self):
    method is_setup (line 222) | def is_setup(self):
    method num_node_feats (line 226) | def num_node_feats(self):
    method num_edge_feats (line 230) | def num_edge_feats(self):
    method predict_ds (line 234) | def predict_ds(self):
    method get_num_workers (line 242) | def get_num_workers(self):
    method predict_ds (line 254) | def predict_ds(self, value):
    method get_fake_graph (line 258) | def get_fake_graph(self):
    method _read_csv (line 264) | def _read_csv(
    method _get_data_file_type (line 294) | def _get_data_file_type(path):
    method _get_table_columns (line 319) | def _get_table_columns(path: str) -> List[str]:
    method _read_parquet (line 348) | def _read_parquet(path, **kwargs):
    method _read_sdf (line 398) | def _read_sdf(path: str, mol_col_name: str = "_rdkit_molecule_obj", **...
    method _glob (line 445) | def _glob(path: str) -> List[str]:
    method _read_table (line 457) | def _read_table(self, path: str, **kwargs) -> pd.DataFrame:
    method get_dataloader_kwargs (line 487) | def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **...
    method get_dataloader (line 524) | def get_dataloader(self, dataset: Dataset, shuffle: bool, stage: Runni...
    method _dataloader (line 539) | def _dataloader(self, dataset: Dataset, **kwargs) -> DataLoader:
    method get_max_num_nodes_datamodule (line 556) | def get_max_num_nodes_datamodule(self, stages: Optional[List[str]] = N...
    method get_max_num_edges_datamodule (line 610) | def get_max_num_edges_datamodule(self, stages: Optional[List[str]] = N...
  class DatasetProcessingParams (line 667) | class DatasetProcessingParams:
    method __init__ (line 668) | def __init__(
  class IPUDataModuleModifier (line 731) | class IPUDataModuleModifier:
    method __init__ (line 732) | def __init__(
    method _dataloader (line 763) | def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoa...
  class MultitaskFromSmilesDataModule (line 788) | class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier):
    method __init__ (line 789) | def __init__(
    method _parse_caching_args (line 924) | def _parse_caching_args(self, processed_graph_data_path, dataloading_f...
    method _get_task_key (line 945) | def _get_task_key(self, task_level: str, task: str):
    method get_task_levels (line 951) | def get_task_levels(self):
    method prepare_data (line 961) | def prepare_data(self, save_smiles_and_ids: bool = False):
    method setup (line 1163) | def setup(
    method _make_multitask_dataset (line 1219) | def _make_multitask_dataset(
    method _ready_to_load_all_from_file (line 1276) | def _ready_to_load_all_from_file(self) -> bool:
    method _path_to_load_from_file (line 1286) | def _path_to_load_from_file(self, stage: Literal["train", "val", "test...
    method _data_ready_at_path (line 1294) | def _data_ready_at_path(self, path: str) -> bool:
    method _save_data_to_files (line 1302) | def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None:
    method get_folder_size (line 1329) | def get_folder_size(self, path):
    method calculate_statistics (line 1333) | def calculate_statistics(self, dataset: Datasets.MultitaskDataset, tra...
    method get_label_statistics (line 1360) | def get_label_statistics(
    method normalize_label (line 1391) | def normalize_label(self, dataset: Datasets.MultitaskDataset, stage) -...
    method save_featurized_data (line 1411) | def save_featurized_data(self, dataset: Datasets.MultitaskDataset, pro...
    method process_func (line 1425) | def process_func(self, param):
    method get_dataloader_kwargs (line 1435) | def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **...
    method get_dataloader (line 1471) | def get_dataloader(
    method get_collate_fn (line 1504) | def get_collate_fn(self, collate_fn):
    method _featurize_molecules (line 1517) | def _featurize_molecules(self, smiles: Iterable[str]) -> Tuple[List, L...
    method _filter_none_molecules (line 1571) | def _filter_none_molecules(
    method _parse_label_cols (line 1620) | def _parse_label_cols(
    method is_prepared (line 1670) | def is_prepared(self):
    method is_setup (line 1676) | def is_setup(self):
    method num_node_feats (line 1682) | def num_node_feats(self):
    method in_dims (line 1689) | def in_dims(self):
    method num_edge_feats (line 1711) | def num_edge_feats(self):
    method get_fake_graph (line 1720) | def get_fake_graph(self):
    method _save_to_cache (line 1737) | def _save_to_cache(self):
    method _load_from_cache (line 1740) | def _load_from_cache(self):
    method _extract_smiles_labels (line 1743) | def _extract_smiles_labels(
    method _get_split_indices (line 1841) | def _get_split_indices(
    method _sub_sample_df (line 1928) | def _sub_sample_df(
    method get_data_hash (line 1954) | def get_data_hash(self):
    method get_data_cache_fullname (line 1983) | def get_data_cache_fullname(self, compress: bool = False) -> str:
    method load_data_from_cache (line 2000) | def load_data_from_cache(self, verbose: bool = True, compress: bool = ...
    method get_subsets_of_datasets (line 2056) | def get_subsets_of_datasets(
    method __len__ (line 2085) | def __len__(self) -> int:
    method to_dict (line 2100) | def to_dict(self) -> Dict[str, Any]:
    method __repr__ (line 2125) | def __repr__(self) -> str:
  class GraphOGBDataModule (line 2135) | class GraphOGBDataModule(MultitaskFromSmilesDataModule):
    method __init__ (line 2136) | def __init__(
    method to_dict (line 2234) | def to_dict(self) -> Dict[str, Any]:
    method _load_dataset (line 2248) | def _load_dataset(
    method _get_dataset_metadata (line 2340) | def _get_dataset_metadata(self, dataset_name: str) -> Dict[str, Any]:
    method _get_ogb_metadata (line 2347) | def _get_ogb_metadata(self):
  class ADMETBenchmarkDataModule (line 2367) | class ADMETBenchmarkDataModule(MultitaskFromSmilesDataModule):
    method __init__ (line 2395) | def __init__(
    method _get_task_specific_arguments (line 2478) | def _get_task_specific_arguments(self, name: str, seed: int, cache_dir...
  class FakeDataModule (line 2536) | class FakeDataModule(MultitaskFromSmilesDataModule):
    method __init__ (line 2544) | def __init__(
    method generate_data (line 2576) | def generate_data(self, label_cols: List[str], smiles_col: str):
    method prepare_data (line 2598) | def prepare_data(self):
    method setup (line 2700) | def setup(self, stage=None):
    method get_fake_graph (line 2730) | def get_fake_graph(self):

FILE: graphium/data/dataset.py
  class SingleTaskDataset (line 32) | class SingleTaskDataset(Dataset):
    method __init__ (line 33) | def __init__(
    method __len__ (line 87) | def __len__(self):
    method __getitem__ (line 95) | def __getitem__(self, idx):
    method __getstate__ (line 128) | def __getstate__(self):
    method __setstate__ (line 140) | def __setstate__(self, state: dict):
  class MultitaskDataset (line 149) | class MultitaskDataset(Dataset):
    method __init__ (line 152) | def __init__(
    method transfer_from_disk_to_ram (line 231) | def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False):
    method save_metadata (line 271) | def save_metadata(self, directory: str):
    method _load_metadata (line 290) | def _load_metadata(self):
    method __len__ (line 327) | def __len__(self):
    method num_nodes_list (line 334) | def num_nodes_list(self):
    method num_edges_list (line 346) | def num_edges_list(self):
    method num_graphs_total (line 358) | def num_graphs_total(self):
    method num_nodes_total (line 365) | def num_nodes_total(self):
    method max_num_nodes_per_graph (line 372) | def max_num_nodes_per_graph(self):
    method std_num_nodes_per_graph (line 379) | def std_num_nodes_per_graph(self):
    method min_num_nodes_per_graph (line 386) | def min_num_nodes_per_graph(self):
    method mean_num_nodes_per_graph (line 393) | def mean_num_nodes_per_graph(self):
    method num_edges_total (line 400) | def num_edges_total(self):
    method max_num_edges_per_graph (line 407) | def max_num_edges_per_graph(self):
    method min_num_edges_per_graph (line 414) | def min_num_edges_per_graph(self):
    method std_num_edges_per_graph (line 421) | def std_num_edges_per_graph(self):
    method mean_num_edges_per_graph (line 428) | def mean_num_edges_per_graph(self):
    method __getitem__ (line 434) | def __getitem__(self, idx):
    method load_graph_from_index (line 464) | def load_graph_from_index(self, data_idx):
    method merge (line 479) | def merge(
    method _get_all_lists_ids (line 531) | def _get_all_lists_ids(self, datasets: Dict[str, SingleTaskDataset]) -...
    method _get_inv_of_mol_ids (line 578) | def _get_inv_of_mol_ids(self, all_mol_ids):
    method _find_valid_label (line 582) | def _find_valid_label(self, task, ds):
    method set_label_size_dict (line 597) | def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]):
    method set_label_dtype_dict (line 615) | def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]):
    method __repr__ (line 630) | def __repr__(self) -> str:
  class FakeDataset (line 673) | class FakeDataset(MultitaskDataset):
    method __init__ (line 678) | def __init__(
    method _get_inv_of_mol_ids (line 715) | def _get_inv_of_mol_ids(self, all_mol_ids):
    method deepcopy_mol (line 722) | def deepcopy_mol(self, mol_ids, labels, smiles, features=None):
    method __len__ (line 747) | def __len__(self):
    method __getitem__ (line 753) | def __getitem__(self, idx):
  function get_num_nodes_per_graph (line 774) | def get_num_nodes_per_graph(graphs):
  function get_num_edges_per_graph (line 784) | def get_num_edges_per_graph(graphs):

FILE: graphium/data/multilevel_utils.py
  function extract_labels (line 22) | def extract_labels(df: pd.DataFrame, task_level: str, label_cols: List[s...

FILE: graphium/data/normalization.py
  class LabelNormalization (line 21) | class LabelNormalization:
    method __init__ (line 22) | def __init__(
    method calculate_statistics (line 59) | def calculate_statistics(self, array):
    method normalize (line 74) | def normalize(self, input):
    method denormalize (line 100) | def denormalize(self, input):

FILE: graphium/data/sampler.py
  class DatasetSubSampler (line 24) | class DatasetSubSampler(data_utils.Sampler):
    method __init__ (line 25) | def __init__(
    method __iter__ (line 60) | def __iter__(self):
    method __len__ (line 69) | def __len__(self):
    method check_sampling_required (line 73) | def check_sampling_required(cls, sampler_task_dict):

FILE: graphium/data/sdf2csv.py
  function extract_zip (line 24) | def extract_zip(fname):
  function extract_mols_from_sdf (line 32) | def extract_mols_from_sdf(fname):
  function mols2cxs (line 41) | def mols2cxs(mols):
  function write_csv (line 51) | def write_csv(cxs: any, homos: any, fname: str):
  function sdf2csv (line 66) | def sdf2csv(sdf_name: str = "pcqm4m-v2-train", outname: str = "pcqm4m-v2...

FILE: graphium/data/smiles_transform.py
  function smiles_to_unique_mol_id (line 20) | def smiles_to_unique_mol_id(smiles: str) -> Optional[str]:
  function did_featurization_fail (line 40) | def did_featurization_fail(features: Any) -> bool:
  class BatchingSmilesTransform (line 47) | class BatchingSmilesTransform:
    method __init__ (line 52) | def __init__(self, transform: Callable):
    method __call__ (line 59) | def __call__(self, smiles_list: Iterable[str]) -> Any:
    method parse_batch_size (line 69) | def parse_batch_size(numel: int, desired_batch_size: int, n_jobs: int)...
  function smiles_to_unique_mol_ids (line 91) | def smiles_to_unique_mol_ids(

FILE: graphium/data/utils.py
  function load_micro_zinc (line 36) | def load_micro_zinc() -> pd.DataFrame:
  function load_tiny_zinc (line 49) | def load_tiny_zinc() -> pd.DataFrame:
  function graphium_package_path (line 62) | def graphium_package_path(graphium_path: str) -> str:
  function list_graphium_datasets (line 76) | def list_graphium_datasets() -> set:
  function download_graphium_dataset (line 85) | def download_graphium_dataset(
  function get_keys (line 124) | def get_keys(pyg_data):
  function found_size_mismatch (line 131) | def found_size_mismatch(task: str, features: Union[Data, GraphDict], lab...

FILE: graphium/features/commute.py
  function compute_commute_distances (line 22) | def compute_commute_distances(

FILE: graphium/features/electrostatic.py
  function compute_electrostatic_interactions (line 22) | def compute_electrostatic_interactions(

FILE: graphium/features/featurizer.py
  function to_dense_array (line 33) | def to_dense_array(array: np.ndarray, dtype: str = None) -> np.ndarray:
  function to_dense_tensor (line 53) | def to_dense_tensor(tensor: Tensor, dtype: str = None) -> Tensor:
  function _mask_nans_inf (line 70) | def _mask_nans_inf(mask_nan: Optional[str], array: np.ndarray, array_nam...
  function get_mol_atomic_features_onehot (line 103) | def get_mol_atomic_features_onehot(mol: dm.Mol, property_list: List[str]...
  function get_mol_conformer_features (line 194) | def get_mol_conformer_features(
  function get_mol_atomic_features_float (line 247) | def get_mol_atomic_features_float(
  function get_simple_mol_conformer (line 444) | def get_simple_mol_conformer(mol: dm.Mol) -> Union[Chem.rdchem.Conformer...
  function get_estimated_bond_length (line 481) | def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: dm.Mol) -> fl...
  function get_mol_edge_features (line 544) | def get_mol_edge_features(
  function mol_to_adj_and_features (line 634) | def mol_to_adj_and_features(
  function mol_to_adjacency_matrix (line 791) | def mol_to_adjacency_matrix(
  class GraphDict (line 856) | class GraphDict(dict):
    method __init__ (line 857) | def __init__(
    method keys (line 903) | def keys(self):
    method values (line 907) | def values(self):
    method make_pyg_graph (line 910) | def make_pyg_graph(self, **kwargs) -> Data:
    method adj (line 947) | def adj(self):
    method dtype (line 951) | def dtype(self):
    method mask_nan (line 955) | def mask_nan(self):
    method num_nodes (line 959) | def num_nodes(self) -> int:
    method num_edges (line 963) | def num_edges(self) -> int:
  function mol_to_graph_dict (line 970) | def mol_to_graph_dict(
  function mol_to_pyggraph (line 1138) | def mol_to_pyggraph(
  function mol_to_graph_signature (line 1252) | def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Di...

FILE: graphium/features/graphormer.py
  function compute_graphormer_distances (line 22) | def compute_graphormer_distances(

FILE: graphium/features/nmp.py
  function float_or_none (line 29) | def float_or_none(string: str) -> Union[float, None]:

FILE: graphium/features/positional_encoding.py
  function get_all_positional_encodings (line 29) | def get_all_positional_encodings(
  function graph_positional_encoder (line 76) | def graph_positional_encoder(

FILE: graphium/features/properties.py
  function get_prop_or_none (line 23) | def get_prop_or_none(
  function get_props_from_mol (line 43) | def get_props_from_mol(

FILE: graphium/features/rw.py
  function compute_rwse (line 25) | def compute_rwse(
  function get_Pks (line 122) | def get_Pks(

FILE: graphium/features/spectral.py
  function compute_laplacian_pe (line 24) | def compute_laplacian_pe(
  function _get_positional_eigvecs (line 122) | def _get_positional_eigvecs(
  function normalize_matrix (line 158) | def normalize_matrix(

FILE: graphium/features/transfer_pos_level.py
  function transfer_pos_level (line 23) | def transfer_pos_level(
  function node_to_edge (line 110) | def node_to_edge(
  function node_to_nodepair (line 150) | def node_to_nodepair(pe: np.ndarray, num_nodes: int) -> np.ndarray:
  function node_to_graph (line 174) | def node_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray:
  function edge_to_node (line 190) | def edge_to_node(pe: np.ndarray, adj: Union[np.ndarray, spmatrix]) -> np...
  function edge_to_nodepair (line 206) | def edge_to_nodepair(
  function edge_to_graph (line 246) | def edge_to_graph(pe: np.ndarray) -> np.ndarray:
  function nodepair_to_node (line 260) | def nodepair_to_node(pe: np.ndarray, stats_list: List = [np.min, np.mean...
  function nodepair_to_edge (line 286) | def nodepair_to_edge(
  function nodepair_to_graph (line 325) | def nodepair_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray:
  function graph_to_node (line 341) | def graph_to_node(

FILE: graphium/finetuning/finetuning.py
  class GraphFinetuning (line 25) | class GraphFinetuning(BaseFinetuning):
    method __init__ (line 26) | def __init__(
    method freeze_before_training (line 55) | def freeze_before_training(self, pl_module: pl.LightningModule):
    method freeze_module (line 73) | def freeze_module(self, pl_module, module_name: str, module_map: Dict[...
    method finetune_function (line 96) | def finetune_function(self, pl_module: pl.LightningModule, epoch: int,...

FILE: graphium/finetuning/finetuning_architecture.py
  class FullGraphFinetuningNetwork (line 26) | class FullGraphFinetuningNetwork(nn.Module, MupMixin):
    method __init__ (line 27) | def __init__(
    method forward (line 98) | def forward(self, g: Batch) -> Tensor:
    method make_mup_base_kwargs (line 139) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str...
    method set_max_num_nodes_edges_per_graph (line 173) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ...
  class PretrainedModel (line 188) | class PretrainedModel(nn.Module, MupMixin):
    method __init__ (line 189) | def __init__(
    method forward (line 226) | def forward(self, g: Union[torch.Tensor, Batch]):
    method overwrite_with_pretrained (line 231) | def overwrite_with_pretrained(
    method make_mup_base_kwargs (line 281) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str...
  class FinetuningHead (line 299) | class FinetuningHead(nn.Module, MupMixin):
    method __init__ (line 300) | def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
    method forward (line 320) | def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torc...
    method make_mup_base_kwargs (line 332) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...

FILE: graphium/finetuning/fingerprinting.py
  class Fingerprinter (line 25) | class Fingerprinter:
    method __init__ (line 89) | def __init__(
    method setup (line 133) | def setup(self):
    method get_fingerprints_for_batch (line 140) | def get_fingerprints_for_batch(self, batch):
    method get_fingerprints_for_dataset (line 165) | def get_fingerprints_for_dataset(self, dataloader):
    method teardown (line 181) | def teardown(self):
    method __enter__ (line 190) | def __enter__(self):
    method __exit__ (line 194) | def __exit__(self, exc_type, exc_val, exc_tb):
    method _convert_output_type (line 200) | def _convert_output_type(self, feats: torch.Tensor):

FILE: graphium/finetuning/utils.py
  function filter_cfg_based_on_admet_benchmark_name (line 24) | def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], nam...
  function modify_cfg_for_finetuning (line 59) | def modify_cfg_for_finetuning(cfg: Dict[str, Any]):
  function update_cfg_arch_for_module (line 179) | def update_cfg_arch_for_module(

FILE: graphium/hyper_param_search/results.py
  function extract_main_metric_for_hparam_search (line 17) | def extract_main_metric_for_hparam_search(results: dict, cfg: dict):

FILE: graphium/ipu/ipu_dataloader.py
  class IPUDataloaderOptions (line 37) | class IPUDataloaderOptions:
    method set_kwargs (line 56) | def set_kwargs(self):
  class CombinedBatchingCollator (line 96) | class CombinedBatchingCollator:
    method __init__ (line 106) | def __init__(
    method __call__ (line 132) | def __call__(
  function create_ipu_dataloader (line 195) | def create_ipu_dataloader(
  class Pad (line 296) | class Pad(BaseTransform):
    method __init__ (line 301) | def __init__(
    method validate (line 333) | def validate(self, data):
    method __call__ (line 357) | def __call__(self, batch: Batch) -> Batch:
    method forward (line 360) | def forward(self, batch: Batch) -> Batch:
    method _call (line 363) | def _call(self, batch: Batch) -> Batch:
    method __repr__ (line 427) | def __repr__(self) -> str:

FILE: graphium/ipu/ipu_losses.py
  class BCEWithLogitsLossIPU (line 22) | class BCEWithLogitsLossIPU(BCEWithLogitsLoss):
    method forward (line 29) | def forward(self, input: Tensor, target: Tensor) -> Tensor:
  class BCELossIPU (line 67) | class BCELossIPU(BCELoss):
    method forward (line 74) | def forward(self, input: Tensor, target: Tensor) -> Tensor:
  class MSELossIPU (line 112) | class MSELossIPU(MSELoss):
    method forward (line 120) | def forward(self, input: Tensor, target: Tensor) -> Tensor:
  class L1LossIPU (line 140) | class L1LossIPU(L1Loss):
    method forward (line 148) | def forward(self, input: Tensor, target: Tensor) -> Tensor:
  class HybridCELossIPU (line 167) | class HybridCELossIPU(HybridCELoss):
    method __init__ (line 168) | def __init__(
    method forward (line 180) | def forward(self, input: Tensor, target: Tensor) -> Tensor:

FILE: graphium/ipu/ipu_metrics.py
  function auroc_ipu (line 36) | def auroc_ipu(
  function average_precision_ipu (line 81) | def average_precision_ipu(
  function precision_ipu (line 128) | def precision_ipu(
  function recall_ipu (line 162) | def recall_ipu(
  function accuracy_ipu (line 195) | def accuracy_ipu(
  function get_confusion_matrix (line 321) | def get_confusion_matrix(
  class NaNTensor (line 451) | class NaNTensor(Tensor):
    method get_nans (line 461) | def get_nans(self) -> BoolTensor:
    method sum (line 474) | def sum(self, *args, **kwargs) -> Tensor:
    method mean (line 486) | def mean(self, *args, **kwargs) -> Tensor:
    method numel (line 494) | def numel(self) -> int:
    method min (line 500) | def min(self, *args, **kwargs) -> Tensor:
    method max (line 508) | def max(self, *args, **kwargs) -> Tensor:
    method argsort (line 516) | def argsort(self, dim=-1, descending=False) -> IntTensor:
    method size (line 527) | def size(self, dim) -> Tensor:
    method __lt__ (line 534) | def __lt__(self, other) -> Tensor:
    method __torch_function__ (line 546) | def __torch_function__(cls, func, types, args=(), kwargs=None):
  function pearson_ipu (line 569) | def pearson_ipu(preds, target):
  function spearman_ipu (line 585) | def spearman_ipu(preds, target):
  function _rank_data (line 605) | def _rank_data(data: Tensor) -> Tensor:
  function r2_score_ipu (line 626) | def r2_score_ipu(preds, target, *args, **kwargs) -> Tensor:
  function fbeta_score_ipu (line 659) | def fbeta_score_ipu(
  function f1_score_ipu (line 801) | def f1_score_ipu(
  function mean_squared_error_ipu (line 848) | def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool)...
  function mean_absolute_error_ipu (line 882) | def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor:

FILE: graphium/ipu/ipu_simple_lightning.py
  class SimpleTorchModel (line 35) | class SimpleTorchModel(torch.nn.Module):
    method __init__ (line 36) | def __init__(self, in_dim, hidden_dim, kernel_size, num_classes):
    method make_mup_base_kwargs (line 60) | def make_mup_base_kwargs(self, divide_factor: float = 2.0):
    method forward (line 68) | def forward(self, x):
  class SimpleLightning (line 75) | class SimpleLightning(lightning.LightningModule):
    method __init__ (line 76) | def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu):
    method training_step (line 83) | def training_step(self, batch, _):
    method validation_step (line 89) | def validation_step(self, batch, _):
    method on_train_batch_end (line 99) | def on_train_batch_end(self, outputs, batch, batch_idx):
    method validation_epoch_end (line 102) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 109) | def configure_optimizers(self):

FILE: graphium/ipu/ipu_utils.py
  function import_poptorch (line 23) | def import_poptorch(raise_error=True) -> Optional[ModuleType]:
  function is_running_on_ipu (line 47) | def is_running_on_ipu() -> bool:
  function load_ipu_options (line 57) | def load_ipu_options(
  function ipu_options_list_to_file (line 145) | def ipu_options_list_to_file(ipu_opts: Optional[List[str]]) -> tempfile....

FILE: graphium/ipu/ipu_wrapper.py
  class PyGArgsParser (line 36) | class PyGArgsParser(poptorch.ICustomArgParser):
    method sortedTensorKeys (line 45) | def sortedTensorKeys(struct: BaseData) -> Iterable[str]:
    method yieldTensors (line 57) | def yieldTensors(self, struct: BaseData):
    method reconstruct (line 64) | def reconstruct(self, original_structure: BaseData, tensor_iterator: I...
  class PredictorModuleIPU (line 92) | class PredictorModuleIPU(PredictorModule):
    method __init__ (line 97) | def __init__(self, *args, **kwargs):
    method compute_loss (line 103) | def compute_loss(
    method on_train_batch_end (line 115) | def on_train_batch_end(self, outputs, batch, batch_idx):
    method training_step (line 120) | def training_step(self, batch, batch_idx) -> Dict[str, Any]:
    method validation_step (line 130) | def validation_step(self, batch, batch_idx) -> Dict[str, Any]:
    method test_step (line 138) | def test_step(self, batch, batch_idx) -> Dict[str, Any]:
    method predict_step (line 147) | def predict_step(self, **inputs) -> Dict[str, Any]:
    method on_validation_batch_end (line 154) | def on_validation_batch_end(
    method evaluation_epoch_end (line 161) | def evaluation_epoch_end(self, outputs: Any):
    method on_test_batch_end (line 165) | def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, ...
    method configure_optimizers (line 169) | def configure_optimizers(self, impl=None):
    method squeeze_input_dims (line 180) | def squeeze_input_dims(self, features, labels):
    method convert_from_fp16 (line 190) | def convert_from_fp16(self, data: Any) -> Any:
    method _convert_features_dtype (line 204) | def _convert_features_dtype(self, feats):
    method precision_to_dtype (line 222) | def precision_to_dtype(self, precision):
    method get_num_graphs (line 225) | def get_num_graphs(self, data: Batch):

FILE: graphium/ipu/to_dense_batch.py
  function to_sparse_batch (line 21) | def to_sparse_batch(x: Tensor, mask_idx: Tensor):
  function to_sparse_batch_from_packed (line 28) | def to_sparse_batch_from_packed(x: Tensor, pack_from_node_idx: Tensor):
  function to_dense_batch (line 35) | def to_dense_batch(
  function to_packed_dense_batch (line 143) | def to_packed_dense_batch(

FILE: graphium/nn/architectures/encoder_manager.py
  class EncoderManager (line 42) | class EncoderManager(nn.Module):
    method __init__ (line 43) | def __init__(
    method _initialize_positional_encoders (line 77) | def _initialize_positional_encoders(self, pe_encoders_kwargs: Dict[str...
    method forward (line 156) | def forward(self, g: Batch) -> Batch:
    method forward_positional_encoding (line 193) | def forward_positional_encoding(self, g: Batch) -> Dict[str, Tensor]:
    method forward_simple_pooling (line 231) | def forward_simple_pooling(self, h: Tensor, pooling: str, dim: int) ->...
    method make_mup_base_kwargs (line 253) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str...
    method input_keys (line 281) | def input_keys(self) -> Iterable[str]:
    method in_dims (line 294) | def in_dims(self) -> Iterable[int]:
    method out_dim (line 307) | def out_dim(self) -> int:

FILE: graphium/nn/architectures/global_architectures.py
  class FeedForwardNN (line 49) | class FeedForwardNN(nn.Module, MupMixin):
    method __init__ (line 50) | def __init__(
    method _parse_layers (line 187) | def _parse_layers(self, layer_type, residual_type):
    method _check_bad_arguments (line 194) | def _check_bad_arguments(self):
    method _parse_class_from_dict (line 204) | def _parse_class_from_dict(
    method _create_residual_connection (line 221) | def _create_residual_connection(self, out_dims: List[int]) -> Tuple[Re...
    method _create_layers (line 248) | def _create_layers(self):
    method cache_readouts (line 295) | def cache_readouts(self) -> bool:
    method _enable_readout_cache (line 299) | def _enable_readout_cache(self):
    method _disable_readout_cache (line 307) | def _disable_readout_cache(self):
    method drop_layers (line 311) | def drop_layers(self, depth: int) -> None:
    method add_layers (line 322) | def add_layers(self, layers: int) -> None:
    method forward (line 333) | def forward(self, h: torch.Tensor) -> torch.Tensor:
    method get_init_kwargs (line 367) | def get_init_kwargs(self) -> Dict[str, Any]:
    method make_mup_base_kwargs (line 393) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...
    method __repr__ (line 411) | def __repr__(self):
  class EnsembleFeedForwardNN (line 421) | class EnsembleFeedForwardNN(FeedForwardNN):
    method __init__ (line 422) | def __init__(
    method _create_layers (line 583) | def _create_layers(self):
    method _parse_num_ensemble (line 587) | def _parse_num_ensemble(self, num_ensemble: int, layer_kwargs) -> int:
    method _parse_reduction (line 616) | def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) ...
    method _parse_subset_in_dim (line 652) | def _parse_subset_in_dim(
    method _parse_layers (line 703) | def _parse_layers(self, layer_type, residual_type):
    method forward (line 710) | def forward(self, h: torch.Tensor) -> torch.Tensor:
    method get_init_kwargs (line 748) | def get_init_kwargs(self) -> Dict[str, Any]:
    method __repr__ (line 757) | def __repr__(self):
  class FeedForwardGraph (line 767) | class FeedForwardGraph(FeedForwardNN):
    method __init__ (line 768) | def __init__(
    method _check_bad_arguments (line 960) | def _check_bad_arguments(self):
    method get_nested_key (line 970) | def get_nested_key(self, d, target_key):
    method _create_layers (line 990) | def _create_layers(self):
    method _graph_layer_forward (line 1089) | def _graph_layer_forward(
    method _parse_virtual_node_class (line 1177) | def _parse_virtual_node_class(self) -> type:
    method _virtual_node_forward (line 1180) | def _virtual_node_forward(
    method forward (line 1228) | def forward(self, g: Batch) -> torch.Tensor:
    method get_init_kwargs (line 1290) | def get_init_kwargs(self) -> Dict[str, Any]:
    method make_mup_base_kwargs (line 1305) | def make_mup_base_kwargs(
    method __repr__ (line 1345) | def __repr__(self):
  class FullGraphMultiTaskNetwork (line 1355) | class FullGraphMultiTaskNetwork(nn.Module, MupMixin):
    method __init__ (line 1356) | def __init__(
    method _parse_feed_forward_gnn (line 1482) | def _parse_feed_forward_gnn(
    method _check_bad_arguments (line 1496) | def _check_bad_arguments(self):
    method _apply_ipu_options (line 1529) | def _apply_ipu_options(self, ipu_kwargs):
    method _apply_ipu_pipeline_split (line 1533) | def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu):
    method _enable_readout_cache (line 1569) | def _enable_readout_cache(self, module_filter: Optional[Union[str, Lis...
    method _disable_readout_cache (line 1596) | def _disable_readout_cache(self):
    method create_module_map (line 1604) | def create_module_map(self, level: Union[Literal["layers"], Literal["m...
    method forward (line 1650) | def forward(self, g: Batch) -> Tensor:
    method make_mup_base_kwargs (line 1713) | def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str...
    method set_max_num_nodes_edges_per_graph (line 1781) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ...
    method __repr__ (line 1806) | def __repr__(self) -> str:
    method in_dim (line 1828) | def in_dim(self) -> int:
    method out_dim (line 1838) | def out_dim(self) -> int:
    method out_dim_edges (line 1845) | def out_dim_edges(self) -> int:
    method in_dim_edges (line 1855) | def in_dim_edges(self) -> int:
  class GraphOutputNN (line 1862) | class GraphOutputNN(nn.Module, MupMixin):
    method __init__ (line 1863) | def __init__(
    method forward (line 1921) | def forward(self, g: Batch):
    method _parse_pooling_layer (line 1953) | def _parse_pooling_layer(
    method _pool_layer_forward (line 1986) | def _pool_layer_forward(
    method compute_nodepairs (line 2017) | def compute_nodepairs(
    method make_mup_base_kwargs (line 2059) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...
    method drop_graph_output_nn_layers (line 2082) | def drop_graph_output_nn_layers(self, num_layers_to_drop: int) -> None:
    method extend_graph_output_nn_layers (line 2095) | def extend_graph_output_nn_layers(self, layers: nn.ModuleList):
    method set_max_num_nodes_edges_per_graph (line 2108) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ...
    method concat_last_layers (line 2123) | def concat_last_layers(self) -> Optional[Iterable[int]]:
    method concat_last_layers (line 2136) | def concat_last_layers(self, value: Union[Type[None], int, Iterable[in...
    method out_dim (line 2155) | def out_dim(self) -> int:
  class TaskHeads (line 2162) | class TaskHeads(nn.Module, MupMixin):
    method __init__ (line 2163) | def __init__(
    method forward (line 2215) | def forward(self, g: Batch) -> Dict[str, torch.Tensor]:
    method make_mup_base_kwargs (line 2234) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...
    method set_max_num_nodes_edges_per_graph (line 2269) | def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], ...
    method out_dim (line 2292) | def out_dim(self) -> Dict[str, int]:
    method __repr__ (line 2298) | def __repr__(self):
    method _check_bad_arguments (line 2307) | def _check_bad_arguments(self):

FILE: graphium/nn/architectures/pyg_architectures.py
  class FeedForwardPyg (line 25) | class FeedForwardPyg(FeedForwardGraph):
    method _graph_layer_forward (line 26) | def _graph_layer_forward(
    method _parse_virtual_node_class (line 114) | def _parse_virtual_node_class(self) -> type:
    method _parse_pooling_layer (line 117) | def _parse_pooling_layer(

FILE: graphium/nn/base_graph_layer.py
  class BaseGraphStructure (line 28) | class BaseGraphStructure:
    method __init__ (line 29) | def __init__(
    method _initialize_activation_dropout_norm (line 92) | def _initialize_activation_dropout_norm(self):
    method _parse_dropout (line 105) | def _parse_dropout(self, dropout):
    method _parse_droppath (line 112) | def _parse_droppath(self, droppath_rate):
    method _parse_norm (line 119) | def _parse_norm(self, normalization, dim=None):
    method apply_norm_activation_dropout (line 136) | def apply_norm_activation_dropout(
    method layer_supports_edges (line 191) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 205) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 221) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 237) | def out_dim_factor(self) -> int:
    method max_num_nodes_per_graph (line 258) | def max_num_nodes_per_graph(self) -> Optional[int]:
    method max_num_nodes_per_graph (line 265) | def max_num_nodes_per_graph(self, value: Optional[int]):
    method max_num_edges_per_graph (line 276) | def max_num_edges_per_graph(self) -> Optional[int]:
    method max_num_edges_per_graph (line 283) | def max_num_edges_per_graph(self, value: Optional[int]):
    method __repr__ (line 293) | def __repr__(self):
  class BaseGraphModule (line 302) | class BaseGraphModule(BaseGraphStructure, nn.Module):
    method __init__ (line 303) | def __init__(
  function check_intpus_allow_int (line 366) | def check_intpus_allow_int(obj, edge_index, size):

FILE: graphium/nn/base_layers.py
  function get_activation (line 44) | def get_activation(activation: Union[type(None), str, Callable]) -> Opti...
  function get_activation_str (line 72) | def get_activation_str(activation: Union[type(None), str, Callable]) -> ...
  function get_norm (line 96) | def get_norm(normalization: Union[Type[None], str, Callable], dim: Optio...
  class MultiheadAttentionMup (line 124) | class MultiheadAttentionMup(nn.MultiheadAttention):
    method __init__ (line 132) | def __init__(self, biased_attention, **kwargs):
    method _reset_parameters (line 136) | def _reset_parameters(self):
    method forward (line 153) | def forward(
  class TransformerEncoderLayerMup (line 208) | class TransformerEncoderLayerMup(nn.TransformerEncoderLayer):
    method __init__ (line 215) | def __init__(self, biased_attention, *args, **kwargs) -> None:
  class MuReadoutGraphium (line 238) | class MuReadoutGraphium(MuReadout):
    method __init__ (line 250) | def __init__(self, in_features, *args, **kwargs):
    method absolute_width (line 255) | def absolute_width(self):
    method base_width (line 259) | def base_width(self):
    method base_width (line 263) | def base_width(self, val):
    method width_mult (line 271) | def width_mult(self):
  class FCLayer (line 275) | class FCLayer(nn.Module):
    method __init__ (line 276) | def __init__(
    method reset_parameters (line 380) | def reset_parameters(self, init_fn=None):
    method forward (line 391) | def forward(self, h: torch.Tensor) -> torch.Tensor:
    method in_channels (line 435) | def in_channels(self) -> int:
    method out_channels (line 442) | def out_channels(self) -> int:
    method __repr__ (line 448) | def __repr__(self):
  class MLP (line 452) | class MLP(nn.Module):
    method __init__ (line 453) | def __init__(
    method forward (line 585) | def forward(self, h: torch.Tensor) -> torch.Tensor:
    method in_features (line 609) | def in_features(self):
    method __getitem__ (line 612) | def __getitem__(self, idx: int) -> nn.Module:
    method __repr__ (line 615) | def __repr__(self):
  class GRU (line 622) | class GRU(nn.Module):
    method __init__ (line 623) | def __init__(self, in_dim: int, hidden_dim: int):
    method forward (line 639) | def forward(self, x, y):
  class DropPath (line 673) | class DropPath(nn.Module):
    method __init__ (line 674) | def __init__(self, drop_rate: float):
    method get_stochastic_drop_rate (line 690) | def get_stochastic_drop_rate(
    method forward (line 717) | def forward(

FILE: graphium/nn/encoders/base_encoder.py
  class BaseEncoder (line 11) | class BaseEncoder(torch.nn.Module, MupMixin):
    method __init__ (line 12) | def __init__(
    method parse_input_keys_with_prefix (line 52) | def parse_input_keys_with_prefix(self, key_prefix):
    method forward (line 66) | def forward(self, graph: Batch, key_prefix=None) -> Dict[str, torch.Te...
    method parse_input_keys (line 76) | def parse_input_keys(self, input_keys: List[str]) -> List[str]:
    method parse_output_keys (line 85) | def parse_output_keys(self, output_keys: List[str]) -> List[str]:
    method make_mup_base_kwargs (line 93) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...

FILE: graphium/nn/encoders/bessel_pos_encoder.py
  class BesselSphericalPosEncoder (line 16) | class BesselSphericalPosEncoder(BaseEncoder):
    method __init__ (line 17) | def __init__(
    method forward (line 88) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D...
    method make_mup_base_kwargs (line 148) | def make_mup_base_kwargs(
    method parse_input_keys (line 176) | def parse_input_keys(
    method parse_output_keys (line 198) | def parse_output_keys(

FILE: graphium/nn/encoders/gaussian_kernel_pos_encoder.py
  class GaussianKernelPosEncoder (line 9) | class GaussianKernelPosEncoder(BaseEncoder):
    method __init__ (line 10) | def __init__(
    method parse_input_keys (line 66) | def parse_input_keys(
    method parse_output_keys (line 92) | def parse_output_keys(
    method forward (line 108) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D...
    method make_mup_base_kwargs (line 141) | def make_mup_base_kwargs(

FILE: graphium/nn/encoders/laplace_pos_encoder.py
  class LapPENodeEncoder (line 10) | class LapPENodeEncoder(BaseEncoder):
    method __init__ (line 11) | def __init__(
    method parse_input_keys (line 120) | def parse_input_keys(
    method parse_output_keys (line 145) | def parse_output_keys(
    method forward (line 168) | def forward(
    method make_mup_base_kwargs (line 228) | def make_mup_base_kwargs(

FILE: graphium/nn/encoders/mlp_encoder.py
  class MLPEncoder (line 10) | class MLPEncoder(BaseEncoder):
    method __init__ (line 11) | def __init__(
    method parse_input_keys (line 72) | def parse_input_keys(
    method parse_output_keys (line 85) | def parse_output_keys(
    method forward (line 114) | def forward(
    method make_mup_base_kwargs (line 141) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...
  class CatMLPEncoder (line 164) | class CatMLPEncoder(BaseEncoder):
    method __init__ (line 165) | def __init__(
    method parse_input_keys (line 230) | def parse_input_keys(
    method parse_output_keys (line 243) | def parse_output_keys(
    method forward (line 257) | def forward(
    method make_mup_base_kwargs (line 284) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...

FILE: graphium/nn/encoders/signnet_pos_encoder.py
  class SimpleGIN (line 18) | class SimpleGIN(nn.Module):
    method __init__ (line 19) | def __init__(
    method forward (line 71) | def forward(self, x, edge_index):
  class GINDeepSigns (line 77) | class GINDeepSigns(nn.Module):
    method __init__ (line 82) | def __init__(
    method forward (line 117) | def forward(self, x, edge_index, batch_index):
  class MaskedGINDeepSigns (line 126) | class MaskedGINDeepSigns(nn.Module):
    method __init__ (line 131) | def __init__(
    method batched_n_nodes (line 164) | def batched_n_nodes(self, batch_index):
    method forward (line 173) | def forward(self, x, edge_index, batch_index):
  class SignNetNodeEncoder (line 189) | class SignNetNodeEncoder(BaseEncoder):
    method __init__ (line 207) | def __init__(
    method parse_input_keys (line 283) | def parse_input_keys(self, input_keys):
    method parse_output_keys (line 295) | def parse_output_keys(self, output_keys):
    method forward (line 305) | def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> D...
    method make_mup_base_kwargs (line 320) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...

FILE: graphium/nn/ensemble_layers.py
  class EnsembleLinear (line 26) | class EnsembleLinear(nn.Module):
    method __init__ (line 27) | def __init__(
    method reset_parameters (line 61) | def reset_parameters(self):
    method forward (line 73) | def forward(self, h: torch.Tensor) -> torch.Tensor:
  class EnsembleFCLayer (line 98) | class EnsembleFCLayer(FCLayer):
    method __init__ (line 99) | def __init__(
    method reset_parameters (line 189) | def reset_parameters(self, init_fn=None):
    method __repr__ (line 196) | def __repr__(self):
  class EnsembleMuReadoutGraphium (line 202) | class EnsembleMuReadoutGraphium(EnsembleLinear):
    method __init__ (line 208) | def __init__(
    method reset_parameters (line 230) | def reset_parameters(self) -> None:
    method width_mult (line 238) | def width_mult(self):
    method _rescale_parameters (line 246) | def _rescale_parameters(self):
    method forward (line 266) | def forward(self, x):
    method absolute_width (line 270) | def absolute_width(self):
    method base_width (line 274) | def base_width(self):
    method base_width (line 278) | def base_width(self, val):
    method width_mult (line 286) | def width_mult(self):
  class EnsembleMLP (line 290) | class EnsembleMLP(MLP):
    method __init__ (line 291) | def __init__(
    method _parse_reduction (line 392) | def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) ...
    method forward (line 428) | def forward(self, h: torch.Tensor) -> torch.Tensor:
    method __repr__ (line 452) | def __repr__(self):

FILE: graphium/nn/pyg_layers/dimenet_pyg.py
  class ResidualLayer (line 30) | class ResidualLayer(torch.nn.Module):
    method __init__ (line 36) | def __init__(self, hidden_channels: int, activation: Union[Callable, s...
    method forward (line 41) | def forward(self, x: Tensor) -> Tensor:
  class OutputBlock (line 45) | class OutputBlock(torch.nn.Module):
    method __init__ (line 51) | def __init__(
    method forward (line 67) | def forward(self, x: Tensor, rbf: Tensor, i: Tensor, num_nodes: Option...
  class InteractionBlock (line 75) | class InteractionBlock(nn.Module):
    method __init__ (line 81) | def __init__(
    method forward (line 113) | def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,...
  class DimeNetPyg (line 141) | class DimeNetPyg(BaseGraphModule):
    method __init__ (line 142) | def __init__(
    method forward (line 255) | def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]:
    method layer_supports_edges (line 287) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 299) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 314) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 329) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/gated_gcn_pyg.py
  class GatedGCNPyg (line 28) | class GatedGCNPyg(MessagePassing, BaseGraphStructure):
    method __init__ (line 29) | def __init__(
    method forward (line 106) | def forward(
    method message (line 143) | def message(self, Dx_i: torch.Tensor, Ex_j: torch.Tensor, Ce: torch.Te...
    method aggregate (line 159) | def aggregate(
    method update (line 196) | def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor):
    method layer_supports_edges (line 212) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 223) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 237) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 251) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/gcn_pyg.py
  class GCNConvPyg (line 23) | class GCNConvPyg(BaseGraphModule):
    method __init__ (line 24) | def __init__(
    method forward (line 47) | def forward(
    method layer_supports_edges (line 64) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 76) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 91) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 106) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/gin_pyg.py
  class GINConvPyg (line 24) | class GINConvPyg(BaseGraphModule):
    method __init__ (line 25) | def __init__(
    method forward (line 94) | def forward(
    method layer_supports_edges (line 111) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 123) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 138) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 153) | def out_dim_factor(self) -> int:
  class GINEConvPyg (line 173) | class GINEConvPyg(BaseGraphModule):
    method __init__ (line 174) | def __init__(
    method forward (line 244) | def forward(
    method layer_supports_edges (line 261) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 273) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 288) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 303) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/gps_pyg.py
  class GPSLayerPyg (line 53) | class GPSLayerPyg(BaseGraphModule):
    method __init__ (line 54) | def __init__(
    method residual_add (line 217) | def residual_add(self, feature: Tensor, input_feature: Tensor) -> Tensor:
    method scale_activations (line 230) | def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> ...
    method forward (line 245) | def forward(self, batch: Batch) -> Batch:
    method _parse_mpnn_layer (line 299) | def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) ->...
    method _parse_attn_layer (line 324) | def _parse_attn_layer(
    method _use_packing (line 365) | def _use_packing(self, batch: Batch) -> bool:
    method _to_dense_batch (line 372) | def _to_dense_batch(
    method _to_sparse_batch (line 406) | def _to_sparse_batch(self, batch: Batch, h_dense: Tensor, idx: Tensor)...
    method _self_attention_block (line 422) | def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: ...
    method _sa_block (line 472) | def _sa_block(
    method layer_supports_edges (line 498) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 510) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 525) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 542) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/mpnn_pyg.py
  class MPNNPlusPyg (line 25) | class MPNNPlusPyg(BaseGraphModule):
    method __init__ (line 26) | def __init__(
    method gather_features (line 188) | def gather_features(
    method aggregate_features (line 235) | def aggregate_features(
    method forward (line 297) | def forward(self, batch: Batch) -> Batch:
    method layer_supports_edges (line 343) | def layer_supports_edges(cls) -> bool:
    method layer_inputs_edges (line 355) | def layer_inputs_edges(self) -> bool:
    method layer_outputs_edges (line 370) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 385) | def out_dim_factor(self) -> int:

FILE: graphium/nn/pyg_layers/pna_pyg.py
  class PNAMessagePassingPyg (line 30) | class PNAMessagePassingPyg(MessagePassing, BaseGraphStructure):
    method __init__ (line 31) | def __init__(
    method forward (line 167) | def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]:
    method message (line 182) | def message(self, x_i: Tensor, x_j: Tensor, edge_feat: OptTensor) -> T...
    method aggregate (line 202) | def aggregate(
    method layer_outputs_edges (line 263) | def layer_outputs_edges(self) -> bool:
    method out_dim_factor (line 278) | def out_dim_factor(self) -> int:
    method layer_inputs_edges (line 298) | def layer_inputs_edges(self) -> bool:
    method layer_supports_edges (line 313) | def layer_supports_edges(cls) -> bool:

FILE: graphium/nn/pyg_layers/pooling_pyg.py
  function scatter_logsum_pool (line 30) | def scatter_logsum_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_...
  function scatter_std_pool (line 62) | def scatter_std_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_siz...
  class PoolingWrapperPyg (line 88) | class PoolingWrapperPyg(ModuleWrap):
    method __init__ (line 89) | def __init__(self, func, feat_type, *args, **kwargs) -> None:
    method forward (line 93) | def forward(self, g: Batch, feature: Tensor, *args, **kwargs):
  function parse_pooling_layer_pyg (line 112) | def parse_pooling_layer_pyg(in_dim: int, pooling: Union[str, List[str]],...
  class VirtualNodePyg (line 164) | class VirtualNodePyg(nn.Module):
    method __init__ (line 165) | def __init__(
    method forward (line 286) | def forward(

FILE: graphium/nn/pyg_layers/utils.py
  class PreprocessPositions (line 27) | class PreprocessPositions(nn.Module):
    method __init__ (line 32) | def __init__(
    method forward (line 75) | def forward(
  class GaussianLayer (line 160) | class GaussianLayer(nn.Module):
    method __init__ (line 161) | def __init__(self, num_kernels=128, in_dim=3):
    method forward (line 175) | def forward(self, input: Tensor) -> Tensor:
  function triplets (line 190) | def triplets(

FILE: graphium/nn/residual_connections.py
  class ResidualConnectionBase (line 27) | class ResidualConnectionBase(nn.Module):
    method __init__ (line 28) | def __init__(self, skip_steps: int = 1):
    method _bool_apply_skip_step (line 50) | def _bool_apply_skip_step(self, step_idx: int):
    method __repr__ (line 63) | def __repr__(self):
    method h_dim_increase_type (line 71) | def h_dim_increase_type(cls):
    method get_true_out_dims (line 90) | def get_true_out_dims(self, out_dims: List) -> List:
    method has_weights (line 127) | def has_weights(cls):
  class ResidualConnectionNone (line 138) | class ResidualConnectionNone(ResidualConnectionBase):
    method __init__ (line 144) | def __init__(self, skip_steps: int = 1):
    method __repr__ (line 147) | def __repr__(self):
    method h_dim_increase_type (line 154) | def h_dim_increase_type(cls):
    method has_weights (line 165) | def has_weights(cls):
    method forward (line 175) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
  class ResidualConnectionSimple (line 191) | class ResidualConnectionSimple(ResidualConnectionBase):
    method __init__ (line 192) | def __init__(self, skip_steps: int = 1):
    method h_dim_increase_type (line 208) | def h_dim_increase_type(cls):
    method has_weights (line 219) | def has_weights(cls):
    method forward (line 229) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
  class ResidualConnectionWeighted (line 265) | class ResidualConnectionWeighted(ResidualConnectionBase):
    method __init__ (line 266) | def __init__(
    method h_dim_increase_type (line 333) | def h_dim_increase_type(cls):
    method has_weights (line 343) | def has_weights(cls):
    method forward (line 353) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
    method _bool_apply_skip_step (line 390) | def _bool_apply_skip_step(self, step_idx: int):
  class ResidualConnectionConcat (line 394) | class ResidualConnectionConcat(ResidualConnectionBase):
    method __init__ (line 395) | def __init__(self, skip_steps: int = 1):
    method h_dim_increase_type (line 412) | def h_dim_increase_type(cls):
    method has_weights (line 423) | def has_weights(cls):
    method forward (line 433) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
  class ResidualConnectionDenseNet (line 471) | class ResidualConnectionDenseNet(ResidualConnectionBase):
    method __init__ (line 472) | def __init__(self, skip_steps: int = 1):
    method h_dim_increase_type (line 489) | def h_dim_increase_type(cls):
    method has_weights (line 500) | def has_weights(cls):
    method forward (line 510) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
  class ResidualConnectionRandom (line 548) | class ResidualConnectionRandom(ResidualConnectionBase):
    method __init__ (line 549) | def __init__(self, skip_steps=1, out_dims: List[int] = None, num_layer...
    method h_dim_increase_type (line 580) | def h_dim_increase_type(cls):
    method has_weights (line 590) | def has_weights(cls):
    method forward (line 598) | def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):

FILE: graphium/nn/utils.py
  class MupMixin (line 20) | class MupMixin(abc.ABC):
    method make_mup_base_kwargs (line 22) | def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_d...
    method scale_kwargs (line 38) | def scale_kwargs(self, scale_factor: Real, scale_in_dim: bool = False):

FILE: graphium/trainer/losses.py
  class HybridCELoss (line 22) | class HybridCELoss(_WeightedLoss):
    method __init__ (line 23) | def __init__(
    method forward (line 68) | def forward(self, input: Tensor, target: Tensor, nan_targets: Tensor =...

FILE: graphium/trainer/metrics.py
  class Thresholder (line 36) | class Thresholder:
    method __init__ (line 37) | def __init__(
    method compute (line 50) | def compute(self, preds: Tensor, target: Tensor):
    method __call__ (line 61) | def __call__(self, preds: Tensor, target: Tensor):
    method __repr__ (line 64) | def __repr__(self):
    method _get_operator (line 72) | def _get_operator(operator):
    method __getstate__ (line 97) | def __getstate__(self):
    method __setstate__ (line 115) | def __setstate__(self, state: dict):
    method __eq__ (line 120) | def __eq__(self, obj) -> bool:
  class MetricWrapper (line 131) | class MetricWrapper:
    method __init__ (line 137) | def __init__(
    method _parse_target_nan_mask (line 201) | def _parse_target_nan_mask(target_nan_mask):
    method _parse_multitask_handling (line 227) | def _parse_multitask_handling(multitask_handling, target_nan_mask):
    method _get_metric (line 256) | def _get_metric(metric):
    method compute (line 267) | def compute(self, preds: Tensor, target: Tensor) -> Tensor:
    method _filter_nans (line 343) | def _filter_nans(self, preds: Tensor, target: Tensor):
    method __call__ (line 359) | def __call__(self, preds: Tensor, target: Tensor) -> Tensor:
    method __repr__ (line 365) | def __repr__(self):
    method __eq__ (line 375) | def __eq__(self, obj) -> bool:
    method __getstate__ (line 388) | def __getstate__(self):
    method __setstate__ (line 405) | def __setstate__(self, state: dict):

FILE: graphium/trainer/predictor.py
  class PredictorModule (line 42) | class PredictorModule(lightning.LightningModule):
    method __init__ (line 43) | def __init__(
    method forward (line 197) | def forward(
    method _convert_features_dtype (line 221) | def _convert_features_dtype(self, feats):
    method _get_task_key (line 231) | def _get_task_key(self, task_level: str, task: str):
    method configure_optimizers (line 237) | def configure_optimizers(self, impl=None):
    method compute_loss (line 254) | def compute_loss(
    method _general_step (line 327) | def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_c...
    method flag_step (line 402) | def flag_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: ...
    method on_train_batch_start (line 471) | def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional...
    method on_train_batch_end (line 478) | def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> N...
    method training_step (line 537) | def training_step(self, batch: Dict[str, Tensor], to_cpu: bool = True)...
    method get_gradient_norm (line 554) | def get_gradient_norm(self):
    method validation_step (line 564) | def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = Tru...
    method test_step (line 567) | def test_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> ...
    method _general_epoch_end (line 570) | def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, ...
    method on_train_epoch_start (line 607) | def on_train_epoch_start(self) -> None:
    method on_train_epoch_end (line 610) | def on_train_epoch_end(self) -> None:
    method on_validation_epoch_start (line 618) | def on_validation_epoch_start(self) -> None:
    method on_validation_batch_start (line 623) | def on_validation_batch_start(self, batch: Any, batch_idx: int, datalo...
    method on_validation_batch_end (line 627) | def on_validation_batch_end(
    method on_validation_epoch_end (line 637) | def on_validation_epoch_end(self) -> None:
    method on_test_batch_end (line 651) | def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, ...
    method on_test_epoch_end (line 654) | def on_test_epoch_end(self) -> None:
    method on_train_start (line 665) | def on_train_start(self):
    method get_progress_bar_dict (line 671) | def get_progress_bar_dict(self) -> Dict[str, float]:
    method __repr__ (line 682) | def __repr__(self) -> str:
    method list_pretrained_models (line 692) | def list_pretrained_models():
    method load_pretrained_model (line 697) | def load_pretrained_model(name_or_path: str, device: str = None):
    method set_max_nodes_edges_per_graph (line 720) | def set_max_nodes_edges_per_graph(self, datamodule: BaseDataModule, st...
    method get_num_graphs (line 728) | def get_num_graphs(self, data: Batch):

FILE: graphium/trainer/predictor_options.py
  class ModelOptions (line 37) | class ModelOptions:
  class OptimOptions (line 54) | class OptimOptions:
    method set_kwargs (line 93) | def set_kwargs(self):
  class EvalOptions (line 137) | class EvalOptions:
    method check_metrics_validity (line 167) | def check_metrics_validity(self):
    method parse_loss_fun (line 185) | def parse_loss_fun(loss_fun: Union[str, Dict, Callable]) -> Callable:
  class FlagOptions (line 226) | class FlagOptions:
    method set_kwargs (line 245) | def set_kwargs(self):

FILE: graphium/trainer/predictor_summaries.py
  class SummaryInterface (line 26) | class SummaryInterface(object):
    method set_results (line 31) | def set_results(self, **kwargs):
    method get_dict_summary (line 34) | def get_dict_summary(self):
    method update_predictor_state (line 37) | def update_predictor_state(self, **kwargs):
    method get_metrics_logs (line 40) | def get_metrics_logs(self, **kwargs):
  class Summary (line 44) | class Summary(SummaryInterface):
    method __init__ (line 46) | def __init__(
    method update_predictor_state (line 103) | def update_predictor_state(
    method set_results (line 121) | def set_results(
    method is_best_epoch (line 145) | def is_best_epoch(self, step_name: str, loss: Tensor, metrics: Dict[st...
    method get_results (line 173) | def get_results(
    method get_best_results (line 186) | def get_best_results(
    method get_results_on_progress_bar (line 199) | def get_results_on_progress_bar(
    method get_dict_summary (line 220) | def get_dict_summary(self) -> Dict[str, Any]:
    method get_metrics_logs (line 241) | def get_metrics_logs(self) -> Dict[str, Any]:
    method metric_log_name (line 295) | def metric_log_name(self, task_name, metric_name, step_name):
    class Results (line 301) | class Results:
      method __init__ (line 302) | def __init__(
  class TaskSummaries (line 334) | class TaskSummaries(SummaryInterface):
    method __init__ (line 335) | def __init__(
    method update_predictor_state (line 380) | def update_predictor_state(
    method set_results (line 410) | def set_results(self, task_metrics: Dict[str, Dict[str, Tensor]]):
    method get_results (line 425) | def get_results(
    method get_best_results (line 441) | def get_best_results(
    method get_results_on_progress_bar (line 457) | def get_results_on_progress_bar(
    method get_dict_summary (line 475) | def get_dict_summary(
    method get_metrics_logs (line 488) | def get_metrics_logs(
    method concatenate_metrics_logs (line 513) | def concatenate_metrics_logs(
    method metric_log_name (line 530) | def metric_log_name(

FILE: graphium/utils/arg_checker.py
  function _parse_type (line 32) | def _parse_type(type_to_validate, accepted_types):
  function _enforce_iter_type (line 50) | def _enforce_iter_type(arg, enforce_type):
  function check_arg_iterator (line 62) | def check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cas...
  function check_list1_in_list2 (line 149) | def check_list1_in_list2(list1, list2, throw_error=True):
  function check_columns_choice (line 184) | def check_columns_choice(dataframe, columns_choice, extra_accepted_cols=...

FILE: graphium/utils/command_line_utils.py
  function get_anchors_and_aliases (line 19) | def get_anchors_and_aliases(filepath):
  function update_config (line 67) | def update_config(cfg: Dict, unknown: List, anchors: List):

FILE: graphium/utils/custom_lr.py
  class WarmUpLinearLR (line 18) | class WarmUpLinearLR(_LRScheduler):
    method __init__ (line 31) | def __init__(self, optimizer, max_num_epochs, warmup_epochs=0, min_lr=...
    method get_lr (line 38) | def get_lr(self):
    method _get_closed_form_lr (line 55) | def _get_closed_form_lr(self):

FILE: graphium/utils/decorators.py
  class classproperty (line 14) | class classproperty(property):
    method __get__ (line 28) | def __get__(self, cls, owner):

FILE: graphium/utils/fs.py
  function get_cache_dir (line 26) | def get_cache_dir(suffix: str = None, create: bool = True) -> pathlib.Path:
  function get_mapper (line 42) | def get_mapper(path: Union[str, os.PathLike]):
  function get_basename (line 50) | def get_basename(path: Union[str, os.PathLike]):
  function get_extension (line 61) | def get_extension(path: Union[str, os.PathLike]):
  function exists (line 70) | def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]):
  function exists_and_not_empty (line 88) | def exists_and_not_empty(path: Union[str, os.PathLike]):
  function mkdir (line 99) | def mkdir(path: Union[str, os.PathLike], exist_ok: bool = True):
  function rm (line 105) | def rm(path: Union[str, os.PathLike], recursive=False, maxdepth=None):
  function join (line 111) | def join(*paths):
  function get_size (line 124) | def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFi...
  function copy (line 146) | def copy(

FILE: graphium/utils/hashing.py
  function get_md5_hash (line 19) | def get_md5_hash(object: Any) -> str:

FILE: graphium/utils/moving_average_tracker.py
  class MovingAverageTracker (line 18) | class MovingAverageTracker:
    method update (line 22) | def update(self, value: float):
    method reset (line 28) | def reset(self):

FILE: graphium/utils/mup.py
  function apply_infshapes (line 24) | def apply_infshapes(model, infshapes):
  function set_base_shapes (line 36) | def set_base_shapes(model, base, rescale_params=True, delta=None, savefi...

FILE: graphium/utils/packing.py
  class MolPack (line 19) | class MolPack:
    method __init__ (line 26) | def __init__(self):
    method add_mol (line 32) | def add_mol(self, num_nodes: int, idx: int) -> "MolPack":
    method expected_atoms (line 47) | def expected_atoms(self, remaining_mean_num_nodes: float, batch_size: ...
    method __repr__ (line 64) | def __repr__(self) -> str:
  function smart_packing (line 71) | def smart_packing(num_nodes: List[int], batch_size: int) -> List[List[in...
  function fast_packing (line 128) | def fast_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]:
  function hybrid_packing (line 161) | def hybrid_packing(num_nodes: List[int], batch_size: int) -> List[List[i...
  function get_pack_sizes (line 227) | def get_pack_sizes(packed_indices, num_nodes):
  function estimate_max_pack_node_size (line 240) | def estimate_max_pack_node_size(num_nodes: Iterable[int], batch_size: in...
  function node_to_pack_indices_mask (line 268) | def node_to_pack_indices_mask(

FILE: graphium/utils/safe_run.py
  class SafeRun (line 18) | class SafeRun:
    method __init__ (line 19) | def __init__(self, name: str, raise_error: bool = True, verbose: int =...
    method __enter__ (line 42) | def __enter__(self):
    method __exit__ (line 49) | def __exit__(self, type, value, traceback):

FILE: graphium/utils/tensor.py
  function save_im (line 29) | def save_im(im_dir, im_name: str, ext: List[str] = ["svg", "png"], dpi: ...
  function is_dtype_torch_tensor (line 43) | def is_dtype_torch_tensor(dtype: Union[np.dtype, torch.dtype]) -> bool:
  function is_dtype_numpy_array (line 57) | def is_dtype_numpy_array(dtype: Union[np.dtype, torch.dtype]) -> bool:
  function one_of_k_encoding (line 78) | def one_of_k_encoding(val: Any, classes: Iterable[Any]) -> List[int]:
  function is_device_cuda (line 103) | def is_device_cuda(device: torch.device, ignore_errors: bool = False) ->...
  function nan_mean (line 127) | def nan_mean(input: Tensor, *args, **kwargs) -> Tensor:
  function nan_median (line 154) | def nan_median(input: Tensor, **kwargs) -> Tensor:
  function nan_var (line 202) | def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor:
  function nan_std (line 242) | def nan_std(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor:
  function nan_mad (line 270) | def nan_mad(input: Tensor, normal: bool = True, **kwargs) -> Tensor:
  class ModuleWrap (line 304) | class ModuleWrap(torch.nn.Module):
    method __init__ (line 312) | def __init__(self, func, *args, **kwargs) -> None:
    method forward (line 319) | def forward(self, *args, **kwargs):
    method __repr__ (line 325) | def __repr__(self):
  class ModuleListConcat (line 329) | class ModuleListConcat(torch.nn.ModuleList):
    method __init__ (line 339) | def __init__(self, dim: int = -1):
    method forward (line 343) | def forward(self, *args, **kwargs) -> Tensor:
  function parse_valid_args (line 355) | def parse_valid_args(param_dict, fn):
  function arg_in_func (line 382) | def arg_in_func(fn, arg):
  function tensor_fp16_to_fp32 (line 402) | def tensor_fp16_to_fp32(tensor: Tensor) -> Tensor:
  function dict_tensor_fp16_to_fp32 (line 416) | def dict_tensor_fp16_to_fp32(

FILE: profiling/profile_mol_to_graph.py
  function main (line 23) | def main():

FILE: profiling/profile_one_of_k_encoding.py
  function main (line 18) | def main():

FILE: profiling/profile_predictor.py
  function main (line 29) | def main():

FILE: tests/conftest.py
  function datadir (line 22) | def datadir(request):

FILE: tests/test_architectures.py
  class test_FeedForwardNN (line 40) | class test_FeedForwardNN(ut.TestCase):
    method test_forward_no_residual (line 52) | def test_forward_no_residual(self):
    method test_forward_simple_residual_1 (line 78) | def test_forward_simple_residual_1(self):
    method test_forward_norms (line 104) | def test_forward_norms(self):
    method test_forward_simple_residual_2 (line 134) | def test_forward_simple_residual_2(self):
    method test_forward_concat_residual_1 (line 162) | def test_forward_concat_residual_1(self):
    method test_forward_concat_residual_2 (line 190) | def test_forward_concat_residual_2(self):
    method test_forward_densenet_residual_1 (line 218) | def test_forward_densenet_residual_1(self):
    method test_forward_densenet_residual_2 (line 246) | def test_forward_densenet_residual_2(self):
    method test_forward_weighted_residual_1 (line 274) | def test_forward_weighted_residual_1(self):
    method test_forward_weighted_residual_2 (line 304) | def test_forward_weighted_residual_2(self):
  class test_FeedForwardGraph (line 335) | class test_FeedForwardGraph(ut.TestCase):
    method test_forward_no_residual (line 379) | def test_forward_no_residual(self):
    method test_forward_simple_residual (line 437) | def test_forward_simple_residual(self):
    method test_forward_weighted_residual (line 494) | def test_forward_weighted_residual(self):
    method test_forward_concat_residual (line 551) | def test_forward_concat_residual(self):
    method test_forward_densenet_residual (line 609) | def test_forward_densenet_residual(self):

FILE: tests/test_attention.py
  function seed_everything (line 29) | def seed_everything(seed: int):
  class test_MultiHeadAttention (line 40) | class test_MultiHeadAttention(ut.TestCase):
    method test_attention_class (line 58) | def test_attention_class(self):

FILE: tests/test_base_layers.py
  class test_Base_Layers (line 27) | class test_Base_Layers(ut.TestCase):
    method test_droppath_layer_0p5 (line 48) | def test_droppath_layer_0p5(self):
    method test_droppath_layer_1p0 (line 56) | def test_droppath_layer_1p0(self):
    method test_droppath_layer_0p0 (line 65) | def test_droppath_layer_0p0(self):
    method test_transformer_encoder_layer_mup (line 73) | def test_transformer_encoder_layer_mup(self):

FILE: tests/test_collate.py
  class test_Collate (line 27) | class test_Collate(ut.TestCase):
    method test_collate_labels (line 28) | def test_collate_labels(self):

FILE: tests/test_data_utils.py
  class TestDataUtils (line 20) | class TestDataUtils(ut.TestCase):
    method test_list_datasets (line 21) | def test_list_datasets(
    method test_download_datasets (line 28) | def test_download_datasets(self):

FILE: tests/test_datamodule.py
  class Test_DataModule (line 27) | class Test_DataModule(ut.TestCase):
    method test_ogb_datamodule (line 28) | def test_ogb_datamodule(self):
    method test_none_filtering (line 90) | def test_none_filtering(self):
    method test_caching (line 181) | def test_caching(self):
    method test_datamodule_with_none_molecules (line 336) | def test_datamodule_with_none_molecules(self):
    method test_datamodule_multiple_data_files (line 444) | def test_datamodule_multiple_data_files(self):
    method test_splits_file (line 497) | def test_splits_file(self):

FILE: tests/test_dataset.py
  class Test_Multitask_Dataset (line 22) | class Test_Multitask_Dataset(ut.TestCase):
    method test_multitask_dataset_case_1 (line 29) | def test_multitask_dataset_case_1(self):
    method test_multitask_dataset_case_2 (line 89) | def test_multitask_dataset_case_2(self):
    method test_multitask_dataset_case_3 (line 168) | def test_multitask_dataset_case_3(self):

FILE: tests/test_ensemble_layers.py
  class test_Ensemble_Layers (line 33) | class test_Ensemble_Layers(ut.TestCase):
    method check_ensemble_linear (line 34) | def check_ensemble_linear(
    method test_ensemble_linear (line 99) | def test_ensemble_linear(self):
    method test_ensemble_mureadout_graphium (line 115) | def test_ensemble_mureadout_graphium(self):
    method check_ensemble_fclayer (line 150) | def check_ensemble_fclayer(
    method test_ensemble_fclayer (line 209) | def test_ensemble_fclayer(self):
    method check_ensemble_mlp (line 236) | def check_ensemble_mlp(
    method test_ensemble_mlp (line 302) | def test_ensemble_mlp(self):
    method check_ensemble_feedforwardnn (line 329) | def check_ensemble_feedforwardnn(
    method check_ensemble_feedforwardnn_mean (line 404) | def check_ensemble_feedforwardnn_mean(
    method check_ensemble_feedforwardnn_simple (line 476) | def check_ensemble_feedforwardnn_simple(
    method test_ensemble_feedforwardnn (line 507) | def test_ensemble_feedforwardnn(self):

FILE: tests/test_featurizer.py
  class test_featurizer (line 33) | class test_featurizer(ut.TestCase):
    method test_get_mol_atomic_features_onehot (line 99) | def test_get_mol_atomic_features_onehot(self):
    method test_get_mol_atomic_features_float (line 123) | def test_get_mol_atomic_features_float(self):
    method test_get_mol_atomic_features_float_nan_mask (line 145) | def test_get_mol_atomic_features_float_nan_mask(self):
    method test_get_mol_edge_features (line 178) | def test_get_mol_edge_features(self):
    method test_mol_to_adj_and_features (line 199) | def test_mol_to_adj_and_features(self):
    method test_mol_to_pyggraph (line 236) | def test_mol_to_pyggraph(self):

FILE: tests/test_finetuning.py
  class Test_Finetuning (line 42) | class Test_Finetuning(ut.TestCase):
    method test_finetuning_from_task_head (line 43) | def test_finetuning_from_task_head(self):
    method test_finetuning_from_gnn (line 232) | def test_finetuning_from_gnn(self):

FILE: tests/test_ipu_dataloader.py
  function random_packing (line 31) | def random_packing(num_nodes, batch_size):
  function global_batch_collator (line 39) | def global_batch_collator(batch_size, batches):
  class test_DataLoading (line 49) | class test_DataLoading(ut.TestCase):
    class TestSimpleLightning (line 50) | class TestSimpleLightning(LightningModule):
      method __init__ (line 52) | def __init__(self, batch_size, node_feat_size, edge_feat_size, num_b...
      method validation_step (line 61) | def validation_step(self, batch, batch_idx):
      method training_step (line 66) | def training_step(self, batch, batch_idx):
      method forward (line 71) | def forward(self, batch):
      method assert_shapes (line 76) | def assert_shapes(self, batch, batch_idx, step):
      method configure_optimizers (line 98) | def configure_optimizers(self):
    class TestDataset (line 101) | class TestDataset(torch.utils.data.Dataset):
      method __init__ (line 103) | def __init__(self, labels, node_features, edge_features):
      method __len__ (line 108) | def __len__(self):
      method __getitem__ (line 111) | def __getitem__(self, idx):
    method test_poptorch_simple_deviceiterations_gradient_accumulation (line 116) | def test_poptorch_simple_deviceiterations_gradient_accumulation(self):
    method test_poptorch_graphium_deviceiterations_gradient_accumulation_full (line 193) | def test_poptorch_graphium_deviceiterations_gradient_accumulation_full...

FILE: tests/test_ipu_losses.py
  class test_Losses (line 25) | class test_Losses(ut.TestCase):
    method test_bce (line 40) | def test_bce(self):
    method test_mse (line 81) | def test_mse(self):
    method test_l1 (line 104) | def test_l1(self):
    method test_bce_logits (line 127) | def test_bce_logits(self):

FILE: tests/test_ipu_metrics.py
  class test_Metrics (line 50) | class test_Metrics(ut.TestCase):
    method test_auroc (line 65) | def test_auroc(self):
    method test_average_precision (line 116) | def test_average_precision(self):  # TODO: Make work with multi-class
    method test_precision (line 148) | def test_precision(self):
    method test_accuracy (line 254) | def test_accuracy(self):
    method test_recall (line 341) | def test_recall(self):
    method test_pearsonr (line 429) | def test_pearsonr(self):
    method test_spearmanr (line 451) | def test_spearmanr(self):
    method test_r2_score (line 473) | def test_r2_score(self):
    method test_fbeta_score (line 498) | def test_fbeta_score(self):
    method test_f1_score (line 602) | def test_f1_score(self):
    method test_mse (line 695) | def test_mse(self):
    method test_mae (line 745) | def test_mae(self):

FILE: tests/test_ipu_options.py
  function test_ipu_options (line 65) | def test_ipu_options():
  function test_ipu_options_list_to_file (line 122) | def test_ipu_options_list_to_file():

FILE: tests/test_ipu_poptorch.py
  function test_poptorch (line 18) | def test_poptorch():

FILE: tests/test_ipu_to_dense_batch.py
  class TestIPUBatch (line 39) | class TestIPUBatch:
    method setup_class (line 41) | def setup_class(self):
    method test_ipu_to_dense_batch (line 63) | def test_ipu_to_dense_batch(self, max_num_nodes_per_graph, batch_size):
    method test_ipu_to_dense_batch_no_batch_no_max_nodes (line 106) | def test_ipu_to_dense_batch_no_batch_no_max_nodes(self):
    method test_ipu_to_dense_batch_no_batch (line 119) | def test_ipu_to_dense_batch_no_batch(self):
    method test_ipu_to_dense_batch_drop_last (line 133) | def test_ipu_to_dense_batch_drop_last(self):

FILE: tests/test_loaders.py
  class TestLoader (line 19) | class TestLoader(ut.TestCase):
    method test_merge_dicts (line 20) | def test_merge_dicts(self):

FILE: tests/test_losses.py
  function _parse (line 26) | def _parse(loss_fun):
  class test_HybridCELoss (line 31) | class test_HybridCELoss(ut.TestCase):
    method test_pure_ce_loss (line 38) | def test_pure_ce_loss(self):
    method test_pure_mae_loss (line 46) | def test_pure_mae_loss(self):
    method test_pure_mse_loss (line 61) | def test_pure_mse_loss(self):
    method test_hybrid_loss (line 77) | def test_hybrid_loss(self):
    method test_loss_parser (line 88) | def test_loss_parser(self):
  class test_BCELoss (line 112) | class test_BCELoss(ut.TestCase):
    method test_loss_parser (line 113) | def test_loss_parser(self):

FILE: tests/test_metrics.py
  class test_Metrics (line 32) | class test_Metrics(ut.TestCase):
    method test_thresholder (line 33) | def test_thresholder(self):
  class test_MetricWrapper (line 75) | class test_MetricWrapper(ut.TestCase):
    method test_target_nan_mask (line 76) | def test_target_nan_mask(self):
    method test_pickling (line 142) | def test_pickling(self):
    method test_classifigression_target_squeezing (line 206) | def test_classifigression_target_squeezing(self):

FILE: tests/test_mtl_architecture.py
  function toy_test_data (line 133) | def toy_test_data(in_dim=7, in_dim_edges=3):
  class test_GraphOutputNN (line 153) | class test_GraphOutputNN(ut.TestCase):
    method generate_test_data (line 154) | def generate_test_data(self):
    method test_nodepair_max_num_nodes_not_set (line 207) | def test_nodepair_max_num_nodes_not_set(self):
    method test_nodepair_with_max_num_nodes (line 231) | def test_nodepair_with_max_num_nodes(self):
  class test_TaskHeads (line 257) | class test_TaskHeads(ut.TestCase):
    method test_task_heads_forward (line 258) | def test_task_heads_forward(self):
    method test_task_heads_non_supported_level (line 342) | def test_task_heads_non_supported_level(self):
  class test_Multitask_NN (line 370) | class test_Multitask_NN(ut.TestCase):
    method test_full_graph_multitask_forward (line 396) | def test_full_graph_multitask_forward(self):
    method test_full_graph_multi_task_from_config (line 559) | def test_full_graph_multi_task_from_config(self):
    method test_full_graph_multi_task_set_max_num_nodes (line 591) | def test_full_graph_multi_task_set_max_num_nodes(self):

FILE: tests/test_multitask_datamodule.py
  class Test_Multitask_DataModule (line 25) | class Test_Multitask_DataModule(ut.TestCase):
    method setUp (line 26) | def setUp(self):
    method tearDown (line 30) | def tearDown(self):
    method test_multitask_fromsmiles_dm (line 34) | def test_multitask_fromsmiles_dm(
    method test_multitask_fromsmiles_from_config (line 148) | def test_multitask_fromsmiles_from_config(self):
    method test_multitask_fromsmiles_from_config_csv (line 203) | def test_multitask_fromsmiles_from_config_csv(self):
    method test_multitask_fromsmiles_from_config_parquet (line 230) | def test_multitask_fromsmiles_from_config_parquet(self):
    method test_multitask_with_missing_fromsmiles_from_config_parquet (line 258) | def test_multitask_with_missing_fromsmiles_from_config_parquet(self):
    method test_extract_graph_level_singletask (line 286) | def test_extract_graph_level_singletask(self):
    method test_extract_graph_level_multitask (line 297) | def test_extract_graph_level_multitask(self):
    method test_extract_graph_level_multitask_missing_cols (line 308) | def test_extract_graph_level_multitask_missing_cols(self):
    method test_non_graph_level_extract_labels (line 325) | def test_non_graph_level_extract_labels(self):
    method test_non_graph_level_extract_labels_missing_cols (line 336) | def test_non_graph_level_extract_labels_missing_cols(self):
    method test_tdc_admet_benchmark_data_module (line 357) | def test_tdc_admet_benchmark_data_module(self):

FILE: tests/test_mup.py
  function get_pyg_graphs (line 29) | def get_pyg_graphs(in_dim, in_dim_edges):
  class test_mup (line 43) | class test_mup(ut.TestCase):
    method test_feedforwardnn_mup (line 63) | def test_feedforwardnn_mup(self):
    method test_feedforwardgraph_mup (line 106) | def test_feedforwardgraph_mup(self):
    method test_fullgraphmultitasknetwork (line 151) | def test_fullgraphmultitasknetwork(self):

FILE: tests/test_packing.py
  function random_packing (line 31) | def random_packing(num_nodes, batch_size):
  class test_Packing (line 39) | class test_Packing(ut.TestCase):
    method test_smart_packing (line 40) | def test_smart_packing(self):
    method test_fast_packing (line 79) | def test_fast_packing(self):
    method test_hybrid_packing (line 119) | def test_hybrid_packing(self):
    method test_node_to_pack_indices_mask (line 158) | def test_node_to_pack_indices_mask(self):

FILE: tests/test_pe_nodepair.py
  class test_positional_encodings (line 27) | class test_positional_encodings(ut.TestCase):
    method test_dimensions (line 63) | def test_dimensions(self):
    method test_symmetry (line 74) | def test_symmetry(self):
    method test_max_dist (line 82) | def test_max_dist(self):

FILE: tests/test_pe_rw.py
  class test_pe_spectral (line 25) | class test_pe_spectral(ut.TestCase):
    method test_caching_and_outputs (line 26) | def test_caching_and_outputs(self):

FILE: tests/test_pe_spectral.py
  class test_pe_spectral (line 25) | class test_pe_spectral(ut.TestCase):
    method test_for_connected_vs_disconnected_graph (line 35) | def test_for_connected_vs_disconnected_graph(self):

FILE: tests/test_pos_transfer_funcs.py
  class test_pos_transfer_funcs (line 33) | class test_pos_transfer_funcs(ut.TestCase):
    method test_different_pathways_from_node_to_edge (line 40) | def test_different_pathways_from_node_to_edge(self):

FILE: tests/test_positional_encoders.py
  class test_positional_encoder (line 33) | class test_positional_encoder(ut.TestCase):
    method test_laplacian_eigvec_eigval (line 45) | def test_laplacian_eigvec_eigval(self):
    method test_rwse (line 89) | def test_rwse(self):
    method test_laplacian_eigvec_with_encoder (line 106) | def test_laplacian_eigvec_with_encoder(self):

FILE: tests/test_positional_encodings.py
  class test_positional_encodings (line 29) | class test_positional_encodings(ut.TestCase):
    method test_dimensions (line 65) | def test_dimensions(self):
    method test_symmetry (line 76) | def test_symmetry(self):
    method test_max_dist (line 84) | def test_max_dist(self):

FILE: tests/test_predictor.py
  class test_Predictor (line 25) | class test_Predictor(ut.TestCase):
    method test_parse_loss_fun (line 26) | def test_parse_loss_fun(self):

FILE: tests/test_pyg_layers.py
  class test_Pyg_Layers (line 44) | class test_Pyg_Layers(ut.TestCase):
    method test_gpslayer (line 71) | def test_gpslayer(self):
    method test_ginlayer (line 98) | def test_ginlayer(self):
    method test_ginelayer (line 114) | def test_ginelayer(self):
    method test_mpnnlayer (line 135) | def test_mpnnlayer(self):
    method test_gatedgcnlayer (line 161) | def test_gatedgcnlayer(self):
    method test_pnamessagepassinglayer (line 187) | def test_pnamessagepassinglayer(self):
    method test_dimenetlayer (line 234) | def test_dimenetlayer(self):
    method test_preprocess3Dfeaturelayer (line 297) | def test_preprocess3Dfeaturelayer(self):
    method test_gaussianlayer (line 319) | def test_gaussianlayer(self):
    method test_pooling_virtual_node (line 327) | def test_pooling_virtual_node(self):

FILE: tests/test_residual_connections.py
  class test_ResidualConnectionNone (line 31) | class test_ResidualConnectionNone(ut.TestCase):
    method test_get_true_out_dims_none (line 32) | def test_get_true_out_dims_none(self):
    method test_forward_none (line 41) | def test_forward_none(self):
  class test_ResidualConnectionSimple (line 54) | class test_ResidualConnectionSimple(ut.TestCase):
    method test_get_true_out_dims_simple (line 55) | def test_get_true_out_dims_simple(self):
    method test_forward_simple (line 64) | def test_forward_simple(self):
  class test_ResidualConnectionRandom (line 93) | class test_ResidualConnectionRandom(ut.TestCase):
    method test_get_true_out_dims_random (line 94) | def test_get_true_out_dims_random(self):
    method test_forward_random (line 112) | def test_forward_random(self):
  class test_ResidualConnectionWeighted (line 132) | class test_ResidualConnectionWeighted(ut.TestCase):
    method test_get_true_out_dims_weighted (line 133) | def test_get_true_out_dims_weighted(self):
    method test_forward_weighted (line 142) | def test_forward_weighted(self):
  class test_ResidualConnectionConcat (line 184) | class test_ResidualConnectionConcat(ut.TestCase):
    method test_get_true_out_dims_concat (line 185) | def test_get_true_out_dims_concat(self):
    method test_forward_concat (line 207) | def test_forward_concat(self):
  class test_ResidualConnectionDenseNet (line 236) | class test_ResidualConnectionDenseNet(ut.TestCase):
    method test_get_true_out_dims_densenet (line 237) | def test_get_true_out_dims_densenet(self):
    method test_forward_densenet (line 259) | def test_forward_densenet(self):

FILE: tests/test_training.py
  class TestCLITraining (line 22) | class TestCLITraining:
    method setup_class (line 24) | def setup_class(cls):
    method call_cli_with_overrides (line 51) | def call_cli_with_overrides(self, acc_type: str, acc_prec: str, load_t...
    method test_cpu_cli_training (line 96) | def test_cpu_cli_training(self, load_type):
    method test_ipu_cli_training (line 102) | def test_ipu_cli_training(self, load_type):

FILE: tests/test_utils.py
  class test_nan_statistics (line 37) | class test_nan_statistics(ut.TestCase):
    method test_nan_mean (line 57) | def test_nan_mean(self):
    method test_nan_std_var (line 77) | def test_nan_std_var(self):
    method test_nan_median (line 103) | def test_nan_median(self):
    method test_nan_mad (line 128) | def test_nan_mad(self):
  class test_SafeRun (line 152) | class test_SafeRun(ut.TestCase):
    method test_safe_run (line 153) | def test_safe_run(self):
  class TestTensorFp16ToFp32 (line 181) | class TestTensorFp16ToFp32(ut.TestCase):
    method test_tensor_fp16_to_fp32 (line 182) | def test_tensor_fp16_to_fp32(self):
    method test_dict_tensor_fp16_to_fp32 (line 200) | def test_dict_tensor_fp16_to_fp32(self):
Condensed preview — 369 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (5,205K chars).
[
  {
    "path": ".github/CODEOWNERS",
    "chars": 13,
    "preview": "* @DomInvivo\n"
  },
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "chars": 5202,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "chars": 625,
    "preview": "## Changelogs\n\n- _enumerate the changes of that PR._\n\n---\n\n_Checklist:_\n\n- [ ] _Was this PR discussed in an issue? It is"
  },
  {
    "path": ".github/workflows/code-check.yml",
    "chars": 534,
    "preview": "name: code-check\n\non:\n  push:\n    branches: [\"main\"]\n    tags: [\"*\"]\n  pull_request:\n    branches:\n      - \"*\"\n      - \""
  },
  {
    "path": ".github/workflows/doc.yml",
    "chars": 1267,
    "preview": "name: doc\n\non:\n  push:\n    branches: [\"main\"]\n\n# Prevent doc action on `main` to conflict with each others.\nconcurrency:"
  },
  {
    "path": ".github/workflows/release.yml",
    "chars": 3766,
    "preview": "name: release\n\non:\n  workflow_dispatch:\n    inputs:\n      release-version:\n        description: \"A valid Semver version "
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 1580,
    "preview": "name: test\n\non:\n  push:\n    branches: [\"main\"]\n    tags: [\"*\"]\n  pull_request:\n    branches:\n      - \"*\"\n      - \"!gh-pa"
  },
  {
    "path": ".github/workflows/test_ipu.yml",
    "chars": 2325,
    "preview": "name: test-ipu\n\non:\n  push:\n    branches: [\"main\"]\n    tags: [\"*\"]\n  pull_request:\n    branches:\n      - \"*\"\n      - \"!g"
  },
  {
    "path": ".gitignore",
    "chars": 2885,
    "preview": "############ graphium Custom GitIgnore ##############\n\n# Workspace\n*.code-workspace\n.vscode/\n.idea/\n\n# Training logs and"
  },
  {
    "path": "LICENSE",
    "chars": 11976,
    "preview": " Apache License\r\n                           Version 2.0, January 2004\r\n                        http://www.apache.org/lic"
  },
  {
    "path": "README.md",
    "chars": 7071,
    "preview": "<div align=\"center\">\n    <img src=\"docs/images/banner-tight.png\" height=\"200px\">\n    <h3>Scaling molecular GNNs to infin"
  },
  {
    "path": "cleanup_files.sh",
    "chars": 846,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "codecov.yml",
    "chars": 694,
    "preview": "coverage:\n  range: \"50...80\"\n  status:\n    project:\n      default:\n        threshold: 1%\n    patch: false\n\ncomment:\n  la"
  },
  {
    "path": "docs/_assets/css/custom-graphium.css",
    "chars": 2324,
    "preview": ":root {\n  --graphium-primary: #548235;\n  --graphium-secondary: #343a40;\n\n  /* Primary color shades */\n  --md-primary-fg-"
  },
  {
    "path": "docs/_assets/css/custom.css",
    "chars": 677,
    "preview": "/* Indentation. */\ndiv.doc-contents:not(.first) {\n  padding-left: 25px;\n  border-left: 4px solid #e6e6e6;\n  margin-botto"
  },
  {
    "path": "docs/_assets/js/google-analytics.js",
    "chars": 326,
    "preview": "var gtag_id = \"G-K76VNHBRM4\";\n\nvar script = document.createElement(\"script\");\nscript.src = \"https://www.googletagmanager"
  },
  {
    "path": "docs/api/graphium.config.md",
    "chars": 177,
    "preview": "graphium.config\n====================\nhelper functions to load and convert config files\n\n::: graphium.config._loader\n::: "
  },
  {
    "path": "docs/api/graphium.data.md",
    "chars": 390,
    "preview": "graphium.data\n====================\nmodule for loading datasets and collating batches\n\n=== \"Contents\"\n\n    * [Data Module"
  },
  {
    "path": "docs/api/graphium.features.md",
    "chars": 668,
    "preview": "graphium.features\n====================\nFeature extraction and manipulation\n\n=== \"Contents\"\n\n    * [Featurizer](#featuriz"
  },
  {
    "path": "docs/api/graphium.finetuning.md",
    "chars": 424,
    "preview": "graphium.finetuning\n====================\nModule for finetuning models and doing linear probing (fingerprinting).\n\n::: gr"
  },
  {
    "path": "docs/api/graphium.ipu.md",
    "chars": 782,
    "preview": "graphium.ipu\n====================\nCode for adapting to run on IPU\n\n=== \"Contents\"\n\n    * [IPU Dataloader](#ipu-dataloade"
  },
  {
    "path": "docs/api/graphium.nn/architectures.md",
    "chars": 499,
    "preview": "graphium.nn.architectures\n====================\n\nHigh level architectures in the library\n\n=== \"Contents\"\n\n    * [Global A"
  },
  {
    "path": "docs/api/graphium.nn/encoders.md",
    "chars": 820,
    "preview": "graphium.nn.encoders\n====================\n\nImplementations of positional encoders in the library\n\n=== \"Contents\"\n\n    * "
  },
  {
    "path": "docs/api/graphium.nn/graphium.nn.md",
    "chars": 469,
    "preview": "graphium.nn\n====================\n\nBase structures for neural networks in the library (to be inherented by other modules)"
  },
  {
    "path": "docs/api/graphium.nn/pyg_layers.md",
    "chars": 809,
    "preview": "graphium.nn.pyg_layers\n====================\nImplementations of GNN layers based on PyG in the library\n\n=== \"Contents\"\n\n "
  },
  {
    "path": "docs/api/graphium.trainer.md",
    "chars": 502,
    "preview": "graphium.trainer\n====================\n\nCode for training models\n\n=== \"Contents\"\n    * [Predictor](#predictor)\n    * [Met"
  },
  {
    "path": "docs/api/graphium.utils.md",
    "chars": 1062,
    "preview": "graphium.utils\n====================\nmodule for utility functions\n\n=== \"Contents\"\n    * [Argument Checker](#argument-chec"
  },
  {
    "path": "docs/baseline.md",
    "chars": 14551,
    "preview": "# ToyMix Baseline - Test set metrics\n\nFrom the paper to be released soon. Below, you can see the baselines for the `ToyM"
  },
  {
    "path": "docs/cli/graphium-train.md",
    "chars": 3337,
    "preview": "# `graphium-train`\n\nTo support advanced configuration, Graphium uses [`hydra`](https://hydra.cc/) to manage and write co"
  },
  {
    "path": "docs/cli/graphium.md",
    "chars": 2972,
    "preview": "# `graphium`\n\n**Usage**:\n\n```console\n$ graphium [OPTIONS] COMMAND [ARGS]...\n```\n\n**Options**:\n\n* `--help`: Show this mes"
  },
  {
    "path": "docs/cli/reference.md",
    "chars": 942,
    "preview": "# CLI Reference\n\nInstalling the Graphium library, makes two CLI tools available. \n\n- [`graphium-train`](./graphium-train"
  },
  {
    "path": "docs/contribute.md",
    "chars": 1074,
    "preview": "# Contribute\n\nWe are happy to see that you want to contribute 🤗.\nFeel free to open an issue or pull request at any time."
  },
  {
    "path": "docs/datasets.md",
    "chars": 7990,
    "preview": "# Graphium Datasets\n\nGraphium datasets are hosted at on Zenodo \n- ***ToyMix*** and  ***LargeMix*** dataseets are hosted "
  },
  {
    "path": "docs/design.md",
    "chars": 3818,
    "preview": "# Graphium Library Design\n\n---\n\n\nThe library is designed with 3 things in mind:\n\n- High modularity and configurability w"
  },
  {
    "path": "docs/index.md",
    "chars": 1232,
    "preview": "# Overview\n\nA deep learning library focused on graph representation learning for real-world chemical tasks.\n\n- ✅ State-o"
  },
  {
    "path": "docs/license.md",
    "chars": 20,
    "preview": "```\n{!LICENSE!}\n```\n"
  },
  {
    "path": "docs/pretrained_models.md",
    "chars": 780,
    "preview": "# Graphium pretrained models\n\nGraphium aims to provide a set of pretrained models that you can use for inference or tran"
  },
  {
    "path": "docs/tutorials/feature_processing/add_new_positional_encoding.ipynb",
    "chars": 12699,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Add new positional encoding\\n\",\n "
  },
  {
    "path": "docs/tutorials/feature_processing/choosing_parallelization.ipynb",
    "chars": 14067,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0693c5c9\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Choosing the"
  },
  {
    "path": "docs/tutorials/feature_processing/csv_to_parquet.ipynb",
    "chars": 2372,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "docs/tutorials/feature_processing/timing_parallel.ipynb",
    "chars": 14123,
    "preview": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"0693c5c9\",\n   \"metadata\": {},\n   \"source\":"
  },
  {
    "path": "docs/tutorials/gnn/add_new_gnn_layers.ipynb",
    "chars": 23274,
    "preview": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Creating GN"
  },
  {
    "path": "docs/tutorials/gnn/making_gnn_networks.ipynb",
    "chars": 14482,
    "preview": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Making GNN "
  },
  {
    "path": "docs/tutorials/gnn/using_gnn_layers.ipynb",
    "chars": 12991,
    "preview": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Using GNN l"
  },
  {
    "path": "docs/tutorials/model_training/simple-molecular-model.ipynb",
    "chars": 20534,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Building and training a simple mo"
  },
  {
    "path": "enable_ipu.sh",
    "chars": 964,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Graphcore Limite"
  },
  {
    "path": "env.yml",
    "chars": 1370,
    "preview": "channels:\n  - conda-forge\n  # - pyg # Add for Windows\n\ndependencies:\n  - python >=3.8\n  - pip\n  - typer\n  - loguru\n  - o"
  },
  {
    "path": "expts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "expts/configs/config_gps_10M_pcqm4m.yaml",
    "chars": 9051,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name pcqm4mv2_mpnn_4layer\n  seed: &see"
  },
  {
    "path": "expts/configs/config_gps_10M_pcqm4m_mod.yaml",
    "chars": 9271,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name pcqm4mv2_mpnn_4layer\n  seed: &see"
  },
  {
    "path": "expts/configs/config_mpnn_10M_b3lyp.yaml",
    "chars": 10169,
    "preview": "# Testing the mpnn only model with the b3lyp dataset on IPU.\nconstants:\n  name: &name b3lyp_mpnn_4layer\n  seed: &seed 42"
  },
  {
    "path": "expts/configs/config_mpnn_pcqm4m.yaml",
    "chars": 7623,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name pcqm4mv2_mpnn_4layer\n  seed: &see"
  },
  {
    "path": "expts/data/micro_zinc_splits.csv",
    "chars": 5505,
    "preview": "train,val,test\n957,392.0,654.0\n223,588.0,430.0\n916,538.0,317.0\n410,903.0,629.0\n287,923.0,927.0\n496,677.0,641.0\n456,144.0"
  },
  {
    "path": "expts/data/tiny_zinc_splits.csv",
    "chars": 465,
    "preview": "train,val,test\n3,61.0,65.0\n21,64.0,24.0\n26,89.0,90.0\n42,28.0,37.0\n23,95.0,83.0\n87,46.0,63.0\n68,30.0,58.0\n99,91.0,85.0\n77"
  },
  {
    "path": "expts/dataset_benchmark.py",
    "chars": 4735,
    "preview": "import os\nfrom os.path import dirname, abspath\nimport yaml\nfrom omegaconf import DictConfig\nfrom datetime import datetim"
  },
  {
    "path": "expts/debug_yaml.py",
    "chars": 2996,
    "preview": "# test_yaml.py\n\n\nimport yaml\n\nCONFIG_FILE = \"neurips2023_configs/config_small_mpnn.yaml\"\n\nimport re\nfrom collections imp"
  },
  {
    "path": "expts/hydra-configs/README.md",
    "chars": 7445,
    "preview": "# Configuring Graphium with Hydra\nThis document provides users with a point of entry to composing configs in Graphium. A"
  },
  {
    "path": "expts/hydra-configs/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "expts/hydra-configs/accelerator/cpu.yaml",
    "chars": 9,
    "preview": "type: cpu"
  },
  {
    "path": "expts/hydra-configs/accelerator/gpu.yaml",
    "chars": 9,
    "preview": "type: gpu"
  },
  {
    "path": "expts/hydra-configs/accelerator/ipu.yaml",
    "chars": 711,
    "preview": "type: ipu\nipu_config:\n    - deviceIterations(60) # IPU would require large batches to be ready for the model.\n    # 60 f"
  },
  {
    "path": "expts/hydra-configs/accelerator/ipu_pipeline.yaml",
    "chars": 792,
    "preview": "type: ipu\nipu_config:\n    - deviceIterations(60) # IPU would require large batches to be ready for the model.\n    # 60 f"
  },
  {
    "path": "expts/hydra-configs/architecture/largemix.yaml",
    "chars": 3760,
    "preview": "# @package _global_\n\narchitecture:\n  model_type: FullGraphMultiTaskNetwork\n  mup_base_path: null\n  pre_nn:   # Set as nu"
  },
  {
    "path": "expts/hydra-configs/architecture/pcqm4m.yaml",
    "chars": 4232,
    "preview": "# @package _global_\n\narchitecture:\n  model_type: FullGraphMultiTaskNetwork\n  mup_base_path: null\n  pre_nn:   # Set as nu"
  },
  {
    "path": "expts/hydra-configs/architecture/toymix.yaml",
    "chars": 3657,
    "preview": "# @package _global_\n\narchitecture:\n  model_type: FullGraphMultiTaskNetwork\n  mup_base_path: null\n  pre_nn:\n    out_dim: "
  },
  {
    "path": "expts/hydra-configs/experiment/toymix_mpnn.yaml",
    "chars": 308,
    "preview": "# @package _global_\n\nconstants:\n  name: neurips2023_small_data_mpnn\n  entity: \"multitask-gnn\"\n  seed: 42\n  max_epochs: 1"
  },
  {
    "path": "expts/hydra-configs/finetuning/admet.yaml",
    "chars": 2801,
    "preview": "# @package _global_\n\n# == Fine-tuning configs in Graphium ==\n# \n# A fine-tuning config is a appendum to a (pre-)training"
  },
  {
    "path": "expts/hydra-configs/finetuning/admet_baseline.yaml",
    "chars": 1769,
    "preview": "# @package _global_\n\ndefaults:\n  - override /tasks/loss_metrics_datamodule: admet\n\nconstants:\n  task: tbd\n  name: finetu"
  },
  {
    "path": "expts/hydra-configs/hparam_search/optuna.yaml",
    "chars": 1833,
    "preview": "# @package _global_\n#\n# For running a hyper-parameter search, we use the Optuna plugin for hydra.\n# This makes optuna av"
  },
  {
    "path": "expts/hydra-configs/main.yaml",
    "chars": 291,
    "preview": "defaults:\n\n  # Accelerators\n  - accelerator: cpu\n\n  # Pre-training/fine-tuning\n  - architecture: toymix\n  - tasks: toymi"
  },
  {
    "path": "expts/hydra-configs/model/gated_gcn.yaml",
    "chars": 704,
    "preview": "# @package _global_\n\narchitecture:\n  pre_nn_edges:   # Set as null to avoid a pre-nn network\n    out_dim: 32\n    hidden_"
  },
  {
    "path": "expts/hydra-configs/model/gcn.yaml",
    "chars": 67,
    "preview": "# @package _global_\n\narchitecture:\n  gnn:\n    layer_type: 'pyg:gcn'"
  },
  {
    "path": "expts/hydra-configs/model/gin.yaml",
    "chars": 67,
    "preview": "# @package _global_\n\narchitecture:\n  gnn:\n    layer_type: 'pyg:gin'"
  },
  {
    "path": "expts/hydra-configs/model/gine.yaml",
    "chars": 647,
    "preview": "# @package _global_\n\narchitecture:\n  pre_nn_edges:   # Set as null to avoid a pre-nn network\n    out_dim: ${constants.gn"
  },
  {
    "path": "expts/hydra-configs/model/gpspp.yaml",
    "chars": 1196,
    "preview": "# @package _global_\n\ndatamodule:\n  args:\n    batch_size_training: 32\n    featurization:\n      conformer_property_list: ["
  },
  {
    "path": "expts/hydra-configs/model/mpnn.yaml",
    "chars": 677,
    "preview": "# @package _global_\n\narchitecture:\n  pre_nn_edges:   # Set as null to avoid a pre-nn network\n    out_dim: 32\n    hidden_"
  },
  {
    "path": "expts/hydra-configs/tasks/admet.yaml",
    "chars": 247,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/l1000_mcf7.yaml",
    "chars": 257,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/l1000_vcap.yaml",
    "chars": 257,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/largemix.yaml",
    "chars": 253,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml",
    "chars": 4230,
    "preview": "# @package _global_\n\n#Task-specific\npredictor:\n  metrics_on_progress_bar:\n    # All below metrics are directly copied fr"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml",
    "chars": 1629,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    l1000_mcf7: []\n  metrics_on_training_set:\n    l1000_mcf7:"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml",
    "chars": 1629,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    l1000_vcap: []\n  metrics_on_training_set:\n    l1000_vcap:"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml",
    "chars": 6144,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    l1000_vcap: []\n    l1000_mcf7: []\n    pcba_1328: []\n    p"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml",
    "chars": 1500,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    pcba_1328: []\n  metrics_on_training_set:\n    pcba_1328: ["
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml",
    "chars": 1684,
    "preview": "# @package _global_\n\n#Task-specific\npredictor:\n  metrics_on_progress_bar:\n    homolumo: []\n  metrics_on_training_set:\n  "
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml",
    "chars": 1606,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    pcqm4m_g25: []\n  metrics_on_training_set:\n    pcqm4m_g25:"
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml",
    "chars": 1474,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    pcqm4m_n4: []\n  metrics_on_training_set:\n    pcqm4m_n4: ["
  },
  {
    "path": "expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml",
    "chars": 3700,
    "preview": "# @package _global_\n\npredictor:\n  metrics_on_progress_bar:\n    qm9: [\"mae\"]\n    tox21: [\"auroc\"]\n    zinc: [\"mae\"]\n  los"
  },
  {
    "path": "expts/hydra-configs/tasks/pcba_1328.yaml",
    "chars": 255,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/pcqm4m.yaml",
    "chars": 249,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/pcqm4m_g25.yaml",
    "chars": 257,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/pcqm4m_n4.yaml",
    "chars": 255,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/admet.yaml",
    "chars": 1457,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    caco2_wang: &regression_head\n      task_level: graph\n      out_dim:"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/l1000_mcf7.yaml",
    "chars": 360,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    l1000_mcf7:\n      task_level: graph\n      out_dim: 2934\n      hidde"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/l1000_vcap.yaml",
    "chars": 360,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    l1000_vcap:\n      task_level: graph\n      out_dim: 2934\n      hidde"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/largemix.yaml",
    "chars": 1598,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    l1000_vcap:\n      task_level: graph\n      out_dim: 2934\n      hidde"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/pcba_1328.yaml",
    "chars": 358,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    pcba_1328:\n      task_level: graph\n      out_dim: 1328\n      hidden"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/pcqm4m.yaml",
    "chars": 365,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    homolumo:\n      task_level: graph\n      out_dim: 1\n      hidden_dim"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/pcqm4m_g25.yaml",
    "chars": 357,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    pcqm4m_g25:\n      task_level: graph\n      out_dim: 25\n      hidden_"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/pcqm4m_n4.yaml",
    "chars": 354,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    pcqm4m_n4:\n      task_level: node\n      out_dim: 4\n      hidden_dim"
  },
  {
    "path": "expts/hydra-configs/tasks/task_heads/toymix.yaml",
    "chars": 958,
    "preview": "# @package _global_\n\narchitecture:\n  task_heads:\n    qm9:\n      task_level: graph\n      out_dim: 19\n      hidden_dims: 1"
  },
  {
    "path": "expts/hydra-configs/tasks/toymix.yaml",
    "chars": 249,
    "preview": "# NOTE: We cannot have a single config, since for fine-tuning we will\n#  only want to override the loss_metrics_datamodu"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/largemix_cpu.yaml",
    "chars": 376,
    "preview": "# @package _global_\n\ndatamodule:\n  args:\n    batch_size_training: 200\n    batch_size_inference: 200\n    featurization_n_"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/largemix_gpu.yaml",
    "chars": 400,
    "preview": "# @package _global_\n\naccelerator:\n  float32_matmul_precision: medium\n\ndatamodule:\n  args:\n    batch_size_training: 2048\n"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/largemix_ipu.yaml",
    "chars": 597,
    "preview": "# @package _global_\n\ndatamodule:\n    args:\n      ipu_dataloader_training_opts:\n        mode: async\n        max_num_nodes"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml",
    "chars": 545,
    "preview": "# @package _global_\n\ndatamodule:\n  args:\n    ipu_dataloader_training_opts:\n      mode: async\n      max_num_nodes_per_gra"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/toymix_cpu.yaml",
    "chars": 392,
    "preview": "# @package _global_\n\ndatamodule:\n  args:\n    batch_size_training: 200\n    batch_size_inference: 200\n    featurization_n_"
  },
  {
    "path": "expts/hydra-configs/training/accelerator/toymix_gpu.yaml",
    "chars": 423,
    "preview": "# @package _global_\n\naccelerator:\n  float32_matmul_precision: medium\n\ndatamodule:\n  args:\n    batch_size_training: 200\n "
  },
  {
    "path": "expts/hydra-configs/training/accelerator/toymix_ipu.yaml",
    "chars": 544,
    "preview": "# @package _global_\n\ndatamodule:\n  args:\n    ipu_dataloader_training_opts:\n      mode: async\n      max_num_nodes_per_gra"
  },
  {
    "path": "expts/hydra-configs/training/largemix.yaml",
    "chars": 1040,
    "preview": "# @package _global_\n\npredictor:\n  random_seed: ${constants.seed}\n  optim_kwargs:\n    lr: 1.e-4 # warmup can be scheduled"
  },
  {
    "path": "expts/hydra-configs/training/model/largemix_gated_gcn.yaml",
    "chars": 490,
    "preview": "# @package _global_\n\nconstants:\n  name: large_data_gated_gcn\n  wandb:\n    name: ${constants.name}\n    project: neurips20"
  },
  {
    "path": "expts/hydra-configs/training/model/largemix_gcn.yaml",
    "chars": 443,
    "preview": "# @package _global_\n\nconstants:\n  name: large_data_gcn\n  wandb:\n    name: ${constants.name}\n    project: neurips2023-exp"
  },
  {
    "path": "expts/hydra-configs/training/model/largemix_gin.yaml",
    "chars": 443,
    "preview": "# @package _global_\n\nconstants:\n  name: large_data_gin\n  wandb:\n    name: ${constants.name}\n    project: neurips2023-exp"
  },
  {
    "path": "expts/hydra-configs/training/model/largemix_gine.yaml",
    "chars": 479,
    "preview": "# @package _global_\n\nconstants:\n  name: large_data_gine\n  wandb:\n    name: ${constants.name}\n    project: neurips2023-ex"
  },
  {
    "path": "expts/hydra-configs/training/model/largemix_mpnn.yaml",
    "chars": 481,
    "preview": "# @package _global_\n\nconstants:\n  name: large_data_mpnn\n  wandb:\n    name: ${constants.name}\n    project: neurips2023-ex"
  },
  {
    "path": "expts/hydra-configs/training/model/pcqm4m_gpspp.yaml",
    "chars": 373,
    "preview": "# @package _global_\n\n# GPS++ model with the PCQMv2 dataset.\nconstants:\n  name: pcqm4mv2_gpspp_4layer\n  seed: 42\n  max_ep"
  },
  {
    "path": "expts/hydra-configs/training/model/pcqm4m_mpnn.yaml",
    "chars": 370,
    "preview": "# @package _global_\n\n# MPNN model with the PCQMv2 dataset.\nconstants:\n  name: pcqm4mv2_mpnn_4layer\n  seed: 42\n  max_epoc"
  },
  {
    "path": "expts/hydra-configs/training/model/toymix_gcn.yaml",
    "chars": 326,
    "preview": "# @package _global_\n\nconstants:\n  name: neurips2023_small_data_gcn\n  seed: 42\n  max_epochs: 100\n  data_dir: expts/data/n"
  },
  {
    "path": "expts/hydra-configs/training/model/toymix_gin.yaml",
    "chars": 330,
    "preview": "# @package _global_\n\nconstants:\n  name: neurips2023_small_data_gin\n  seed: 42\n  max_epochs: 100\n  data_dir: expts/data/n"
  },
  {
    "path": "expts/hydra-configs/training/pcqm4m.yaml",
    "chars": 996,
    "preview": "# @package _global_\n\npredictor:\n  random_seed: ${constants.seed}\n  optim_kwargs:\n    lr: 4.e-4 # warmup can be scheduled"
  },
  {
    "path": "expts/hydra-configs/training/toymix.yaml",
    "chars": 650,
    "preview": "# @package _global_\n\npredictor:\n  random_seed: ${constants.seed}\n  optim_kwargs:\n    lr: 4.e-5 # warmup can be scheduled"
  },
  {
    "path": "expts/main_run_get_fingerprints.py",
    "chars": 2902,
    "preview": "# General imports\nimport os\nfrom os.path import dirname, abspath\nimport yaml\nimport numpy as np\nimport pandas as pd\nimpo"
  },
  {
    "path": "expts/main_run_multitask.py",
    "chars": 351,
    "preview": "import hydra\nfrom omegaconf import DictConfig\n\n\n@hydra.main(version_base=None, config_path=\"hydra-configs\", config_name="
  },
  {
    "path": "expts/main_run_predict.py",
    "chars": 2033,
    "preview": "# General imports\nimport os\nfrom os.path import dirname, abspath\nimport yaml\nfrom copy import deepcopy\nfrom omegaconf im"
  },
  {
    "path": "expts/main_run_test.py",
    "chars": 1283,
    "preview": "# General imports\nimport os\nfrom os.path import dirname, abspath\nimport yaml\nfrom copy import deepcopy\nfrom omegaconf im"
  },
  {
    "path": "expts/neurips2023_configs/base_config/large.yaml",
    "chars": 14502,
    "preview": "# @package _global_\n\nconstants:\n  seed: 42\n  raise_train_error: true   # Whether the code should raise an error if it cr"
  },
  {
    "path": "expts/neurips2023_configs/base_config/large_pcba.yaml",
    "chars": 14737,
    "preview": "# @package _global_\n\nconstants:\n  seed: 42\n  raise_train_error: true   # Whether the code should raise an error if it cr"
  },
  {
    "path": "expts/neurips2023_configs/base_config/large_pcqm_g25.yaml",
    "chars": 14754,
    "preview": "# @package _global_\n\nconstants:\n  seed: 42\n  raise_train_error: true   # Whether the code should raise an error if it cr"
  },
  {
    "path": "expts/neurips2023_configs/base_config/large_pcqm_n4.yaml",
    "chars": 14757,
    "preview": "# @package _global_\n\nconstants:\n  seed: 42\n  raise_train_error: true   # Whether the code should raise an error if it cr"
  },
  {
    "path": "expts/neurips2023_configs/base_config/small.yaml",
    "chars": 11236,
    "preview": "# @package _global_\n\nconstants:\n  seed: &seed 42\n  raise_train_error: true   # Whether the code should raise an error if"
  },
  {
    "path": "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml",
    "chars": 11359,
    "preview": "# Testing the gcn model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_small_data_gcn\n  seed: &see"
  },
  {
    "path": "expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml",
    "chars": 622,
    "preview": "# Testing the gin model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_small_data_gin\n  config_ove"
  },
  {
    "path": "expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml",
    "chars": 859,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_small_data_gine\n  config_o"
  },
  {
    "path": "expts/neurips2023_configs/config_classifigression_l1000.yaml",
    "chars": 10186,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_small_data_mpnn\n  see"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gcn.yaml",
    "chars": 399,
    "preview": "# Running the gcn model with the largemix dataset on IPU.\n\ndefaults:\n  - base_config: large\n  - _self_\n\nconstants:\n  nam"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gcn_g25.yaml",
    "chars": 466,
    "preview": "# Running the gcn model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  - base_config: large_pcq"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gcn_gpu.yaml",
    "chars": 3162,
    "preview": "# Testing GCN on LargeMix with FP16/32 on GPU\n\ndefaults:\n  - base_config: large\n  - _self_\n\nconstants:\n  name: neurips20"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gcn_n4.yaml",
    "chars": 466,
    "preview": "# Running the gcn model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  # - base_config: large_p"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gcn_pcba.yaml",
    "chars": 496,
    "preview": "# Running the gcn model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  # - base_config: large_p"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gin.yaml",
    "chars": 342,
    "preview": "# Running the gin model with the largemix dataset on IPU.\ndefaults:\n  - base_config: large\n  - _self_\n\nconstants:\n  name"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gin_g25.yaml",
    "chars": 410,
    "preview": "# Running the gin model with the largemix dataset on IPU.\ndefaults:\n  # - base_config: large\n  - base_config: large_pcqm"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gin_n4.yaml",
    "chars": 410,
    "preview": "# Running the gin model with the largemix dataset on IPU.\ndefaults:\n  # - base_config: large\n  # - base_config: large_pc"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gin_pcba.yaml",
    "chars": 440,
    "preview": "# Running the gin model with the largemix dataset on IPU.\ndefaults:\n  # - base_config: large\n  # - base_config: large_pc"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gine.yaml",
    "chars": 910,
    "preview": "# Running the gine model with the largemix dataset on IPU.\n\ndefaults:\n  - base_config: large\n  # - base_config: large_pc"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gine_g25.yaml",
    "chars": 910,
    "preview": "# Running the gine model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  - base_config: large_pc"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gine_n4.yaml",
    "chars": 910,
    "preview": "# Running the gine model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  # - base_config: large_"
  },
  {
    "path": "expts/neurips2023_configs/config_large_gine_pcba.yaml",
    "chars": 940,
    "preview": "# Running the gine model with the largemix dataset on IPU.\n\ndefaults:\n  # - base_config: large\n  # - base_config: large_"
  },
  {
    "path": "expts/neurips2023_configs/config_large_mpnn.yaml",
    "chars": 1797,
    "preview": "# Running the mpnn model with the largemix dataset on IPU.\n\ndefaults:\n - base_config: large\n\nconstants:\n  name: neurips2"
  },
  {
    "path": "expts/neurips2023_configs/config_luis_jama.yaml",
    "chars": 9849,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_small_data_mpnn\n  see"
  },
  {
    "path": "expts/neurips2023_configs/config_small_gated_gcn.yaml",
    "chars": 652,
    "preview": "# Testing the gated_gcn model with the PCQMv2 dataset on IPU.\n\ndefaults:\n  - base_config: small\n  - _self_\n\nconstants:\n "
  },
  {
    "path": "expts/neurips2023_configs/config_small_gcn.yaml",
    "chars": 293,
    "preview": "# Testing the gcn model with the toymix dataset on IPU.\n\ndefaults:\n  - base_config: small\n  - _self_\n\nconstants:\n  name:"
  },
  {
    "path": "expts/neurips2023_configs/config_small_gcn_gpu.yaml",
    "chars": 1978,
    "preview": "# Testing GCN on ToyMix with FP16/32 on GPU\n\ndefaults:\n  - base_config: small\n  - _self_\n\nconstants:\n  name: neurips2023"
  },
  {
    "path": "expts/neurips2023_configs/config_small_gin.yaml",
    "chars": 238,
    "preview": "# Testing the gin model with the PCQMv2 dataset on IPU.\n\ndefaults:\n  - base_config: small\n  - _self_\n\nconstants:\n  name:"
  },
  {
    "path": "expts/neurips2023_configs/config_small_gine.yaml",
    "chars": 636,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\n\ndefaults:\n  - base_config: small\n  - _self_\n\nconstants:\n  name"
  },
  {
    "path": "expts/neurips2023_configs/config_small_mpnn.yaml",
    "chars": 825,
    "preview": "# Testing the mpnn only model with the PCQMv2 dataset on IPU.\n\ndefaults:\n - base_config: small\n\nconstants:\n  name: neuri"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml",
    "chars": 8386,
    "preview": "# Testing the gcn model\nconstants:\n  name: &name neurips2023_large_data_gcn_mcf7\n  seed: &seed 42\n  raise_train_error: t"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml",
    "chars": 8189,
    "preview": "# Testing the gcn model\nconstants:\n  name: &name neurips2023_large_data_gcn_pcba\n  seed: &seed 42\n  raise_train_error: t"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml",
    "chars": 8384,
    "preview": "# Testing the gcn model\nconstants:\n  name: &name neurips2023_large_data_gcn_vcap\n  seed: &seed 42\n  raise_train_error: t"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml",
    "chars": 8440,
    "preview": "# Testing the gin model\nconstants:\n  name: &name neurips2023_large_data_gin_g25\n  seed: &seed 42\n  raise_train_error: tr"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml",
    "chars": 8421,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_mcf7\n  see"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml",
    "chars": 8443,
    "preview": "# Testing the gin model\nconstants:\n  name: &name neurips2023_large_data_gin_n4\n  seed: &seed 42\n  raise_train_error: tru"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml",
    "chars": 8189,
    "preview": "# Testing the gin model\nconstants:\n  name: &name neurips2023_large_data_gin_pcba\n  seed: &seed 42\n  raise_train_error: t"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml",
    "chars": 9854,
    "preview": "# Testing the gin model\nconstants:\n  name: &name neurips2023_large_data_gin_pcq\n  seed: &seed 42\n  raise_train_error: tr"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml",
    "chars": 8419,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_vcap\n  see"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml",
    "chars": 8681,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_g25\n  seed"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml",
    "chars": 8627,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_mcf7\n  see"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml",
    "chars": 8684,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_n4\n  seed:"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml",
    "chars": 8430,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_pcba\n  see"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml",
    "chars": 10095,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_pcq\n  seed"
  },
  {
    "path": "expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml",
    "chars": 8625,
    "preview": "# Testing the gine model with the PCQMv2 dataset on IPU.\nconstants:\n  name: &name neurips2023_large_data_gine_vcap\n  see"
  },
  {
    "path": "expts/run_validation_test.py",
    "chars": 3166,
    "preview": "# General imports\nimport argparse\nimport os\nfrom os.path import dirname, abspath\nimport yaml\nfrom copy import deepcopy\nf"
  },
  {
    "path": "graphium/__init__.py",
    "chars": 169,
    "preview": "from ._version import __version__\n\nfrom .config import load_config\n\nfrom . import utils\nfrom . import features\nfrom . im"
  },
  {
    "path": "graphium/_version.py",
    "chars": 219,
    "preview": "from importlib.metadata import version\nfrom importlib.metadata import PackageNotFoundError\n\ntry:\n    __version__ = versi"
  },
  {
    "path": "graphium/cli/__init__.py",
    "chars": 124,
    "preview": "from .data import data_app\nfrom .parameters import param_app\nfrom .finetune_utils import finetune_app\nfrom .main import "
  },
  {
    "path": "graphium/cli/__main__.py",
    "chars": 72,
    "preview": "from graphium.cli.main import app\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "graphium/cli/data.py",
    "chars": 1876,
    "preview": "import timeit\nfrom typing import List\nfrom omegaconf import OmegaConf\nimport typer\nimport graphium\n\nfrom loguru import l"
  },
  {
    "path": "graphium/cli/finetune_utils.py",
    "chars": 6249,
    "preview": "from typing import List, Literal, Optional\n\nimport fsspec\nimport numpy as np\nimport torch\nimport tqdm\nimport typer\nimpor"
  },
  {
    "path": "graphium/cli/fingerprints.py",
    "chars": 87,
    "preview": "from .main import app\n\n\n@app.command(name=\"fp\")\ndef get_fingerprints_from_model(): ...\n"
  },
  {
    "path": "graphium/cli/main.py",
    "chars": 94,
    "preview": "import typer\n\n\napp = typer.Typer(add_completion=False)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "graphium/cli/parameters.py",
    "chars": 1429,
    "preview": "import timeit\nfrom typing import List\nfrom omegaconf import DictConfig, OmegaConf\nimport typer\n\nimport numpy as np\n\nfrom"
  },
  {
    "path": "graphium/cli/train_finetune_test.py",
    "chars": 11712,
    "preview": "from typing import List, Literal, Union\nimport os\nimport time\nimport timeit\nfrom datetime import datetime\n\nimport fsspec"
  },
  {
    "path": "graphium/config/README.md",
    "chars": 349,
    "preview": "<div align=\"center\">\n    <img src=\"../../docs/images/logo-title.png\" height=\"80px\">\n    <h3>The Graph Of LIfe Library.</"
  },
  {
    "path": "graphium/config/__init__.py",
    "chars": 292,
    "preview": "from ._load import load_config\n\nfrom ._loader import load_architecture\nfrom ._loader import load_datamodule\nfrom ._loade"
  },
  {
    "path": "graphium/config/_load.py",
    "chars": 916,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "graphium/config/_loader.py",
    "chars": 24263,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "graphium/config/config_convert.py",
    "chars": 1535,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "graphium/config/dummy_finetuning_from_gnn.yaml",
    "chars": 3184,
    "preview": "# Here, we are finetuning a FullGraphMultitaskNetwork\n# trained on ToyMix. We finetune from the gnn on the \n# TDC datase"
  },
  {
    "path": "graphium/config/dummy_finetuning_from_task_head.yaml",
    "chars": 3377,
    "preview": "# Here, we are finetuning a FullGraphMultitaskNetwork\n# trained on ToyMix. We finetune from the zinc task-head\n# (graph-"
  },
  {
    "path": "graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml",
    "chars": 5192,
    "preview": "datamodule:\n  module_type: \"MultitaskFromSmilesDataModule\"\n  args: # Matches that in the test_multitask_datamodule.py ca"
  },
  {
    "path": "graphium/config/fake_multilevel_multitask_pyg.yaml",
    "chars": 5186,
    "preview": "datamodule:\n  module_type: \"MultitaskFromSmilesDataModule\"\n  args: # Matches that in the test_multitask_datamodule.py ca"
  },
  {
    "path": "graphium/config/zinc_default_multitask_pyg.yaml",
    "chars": 5112,
    "preview": "datamodule:\n  module_type: \"MultitaskFromSmilesDataModule\"\n  args: # Matches that in the test_multitask_datamodule.py ca"
  },
  {
    "path": "graphium/data/L1000/datasets_goli-L1000_parse_gctx_to_csv.ipynb",
    "chars": 450291,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "graphium/data/QM9/micro_qm9.csv",
    "chars": 225339,
    "preview": "mol_id,smiles,A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom\r\ngdb_1,C,1"
  },
  {
    "path": "graphium/data/QM9/norm_micro_qm9.csv",
    "chars": 396078,
    "preview": "mol_id,smiles,A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom\ngdb_1,C,0."
  },
  {
    "path": "graphium/data/README.md",
    "chars": 587,
    "preview": "<div align=\"center\">\n    <img src=\"../../docs/images/logo-title.png\" height=\"80px\">\n    <h3>The Graph Of LIfe Library.</"
  },
  {
    "path": "graphium/data/__init__.py",
    "chars": 408,
    "preview": "from .utils import load_micro_zinc\nfrom .utils import load_tiny_zinc\n\nfrom .collate import graphium_collate_fn\n\nfrom .da"
  },
  {
    "path": "graphium/data/collate.py",
    "chars": 14294,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "graphium/data/datamodule.py",
    "chars": 116339,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  },
  {
    "path": "graphium/data/dataset.py",
    "chars": 30394,
    "preview": "\"\"\"\n--------------------------------------------------------------------------------\nCopyright (c) 2023 Valence Labs, Re"
  }
]

// ... and 169 more files (download for full content)

About this extraction

This page contains the full source code of the datamol-io/graphium GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 369 files (4.8 MB), approximately 1.3M tokens, and a symbol index with 1124 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!