Repository: datamol-io/graphium
Branch: main
Commit: f1c180938718
Files: 369
Total size: 4.8 MB
Directory structure:
gitextract_455isr46/
├── .github/
│ ├── CODEOWNERS
│ ├── CODE_OF_CONDUCT.md
│ ├── PULL_REQUEST_TEMPLATE.md
│ └── workflows/
│ ├── code-check.yml
│ ├── doc.yml
│ ├── release.yml
│ ├── test.yml
│ └── test_ipu.yml
├── .gitignore
├── LICENSE
├── README.md
├── cleanup_files.sh
├── codecov.yml
├── docs/
│ ├── _assets/
│ │ ├── css/
│ │ │ ├── custom-graphium.css
│ │ │ └── custom.css
│ │ └── js/
│ │ └── google-analytics.js
│ ├── api/
│ │ ├── graphium.config.md
│ │ ├── graphium.data.md
│ │ ├── graphium.features.md
│ │ ├── graphium.finetuning.md
│ │ ├── graphium.ipu.md
│ │ ├── graphium.nn/
│ │ │ ├── architectures.md
│ │ │ ├── encoders.md
│ │ │ ├── graphium.nn.md
│ │ │ └── pyg_layers.md
│ │ ├── graphium.trainer.md
│ │ └── graphium.utils.md
│ ├── baseline.md
│ ├── cli/
│ │ ├── graphium-train.md
│ │ ├── graphium.md
│ │ └── reference.md
│ ├── contribute.md
│ ├── datasets.md
│ ├── design.md
│ ├── index.md
│ ├── license.md
│ ├── pretrained_models.md
│ └── tutorials/
│ ├── feature_processing/
│ │ ├── add_new_positional_encoding.ipynb
│ │ ├── choosing_parallelization.ipynb
│ │ ├── csv_to_parquet.ipynb
│ │ └── timing_parallel.ipynb
│ ├── gnn/
│ │ ├── add_new_gnn_layers.ipynb
│ │ ├── making_gnn_networks.ipynb
│ │ └── using_gnn_layers.ipynb
│ └── model_training/
│ └── simple-molecular-model.ipynb
├── enable_ipu.sh
├── env.yml
├── expts/
│ ├── __init__.py
│ ├── configs/
│ │ ├── config_gps_10M_pcqm4m.yaml
│ │ ├── config_gps_10M_pcqm4m_mod.yaml
│ │ ├── config_mpnn_10M_b3lyp.yaml
│ │ └── config_mpnn_pcqm4m.yaml
│ ├── data/
│ │ ├── micro_zinc_splits.csv
│ │ └── tiny_zinc_splits.csv
│ ├── dataset_benchmark.py
│ ├── debug_yaml.py
│ ├── hydra-configs/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── accelerator/
│ │ │ ├── cpu.yaml
│ │ │ ├── gpu.yaml
│ │ │ ├── ipu.yaml
│ │ │ └── ipu_pipeline.yaml
│ │ ├── architecture/
│ │ │ ├── largemix.yaml
│ │ │ ├── pcqm4m.yaml
│ │ │ └── toymix.yaml
│ │ ├── experiment/
│ │ │ └── toymix_mpnn.yaml
│ │ ├── finetuning/
│ │ │ ├── admet.yaml
│ │ │ └── admet_baseline.yaml
│ │ ├── hparam_search/
│ │ │ └── optuna.yaml
│ │ ├── main.yaml
│ │ ├── model/
│ │ │ ├── gated_gcn.yaml
│ │ │ ├── gcn.yaml
│ │ │ ├── gin.yaml
│ │ │ ├── gine.yaml
│ │ │ ├── gpspp.yaml
│ │ │ └── mpnn.yaml
│ │ ├── tasks/
│ │ │ ├── admet.yaml
│ │ │ ├── l1000_mcf7.yaml
│ │ │ ├── l1000_vcap.yaml
│ │ │ ├── largemix.yaml
│ │ │ ├── loss_metrics_datamodule/
│ │ │ │ ├── admet.yaml
│ │ │ │ ├── l1000_mcf7.yaml
│ │ │ │ ├── l1000_vcap.yaml
│ │ │ │ ├── largemix.yaml
│ │ │ │ ├── pcba_1328.yaml
│ │ │ │ ├── pcqm4m.yaml
│ │ │ │ ├── pcqm4m_g25.yaml
│ │ │ │ ├── pcqm4m_n4.yaml
│ │ │ │ └── toymix.yaml
│ │ │ ├── pcba_1328.yaml
│ │ │ ├── pcqm4m.yaml
│ │ │ ├── pcqm4m_g25.yaml
│ │ │ ├── pcqm4m_n4.yaml
│ │ │ ├── task_heads/
│ │ │ │ ├── admet.yaml
│ │ │ │ ├── l1000_mcf7.yaml
│ │ │ │ ├── l1000_vcap.yaml
│ │ │ │ ├── largemix.yaml
│ │ │ │ ├── pcba_1328.yaml
│ │ │ │ ├── pcqm4m.yaml
│ │ │ │ ├── pcqm4m_g25.yaml
│ │ │ │ ├── pcqm4m_n4.yaml
│ │ │ │ └── toymix.yaml
│ │ │ └── toymix.yaml
│ │ └── training/
│ │ ├── accelerator/
│ │ │ ├── largemix_cpu.yaml
│ │ │ ├── largemix_gpu.yaml
│ │ │ ├── largemix_ipu.yaml
│ │ │ ├── pcqm4m_ipu.yaml
│ │ │ ├── toymix_cpu.yaml
│ │ │ ├── toymix_gpu.yaml
│ │ │ └── toymix_ipu.yaml
│ │ ├── largemix.yaml
│ │ ├── model/
│ │ │ ├── largemix_gated_gcn.yaml
│ │ │ ├── largemix_gcn.yaml
│ │ │ ├── largemix_gin.yaml
│ │ │ ├── largemix_gine.yaml
│ │ │ ├── largemix_mpnn.yaml
│ │ │ ├── pcqm4m_gpspp.yaml
│ │ │ ├── pcqm4m_mpnn.yaml
│ │ │ ├── toymix_gcn.yaml
│ │ │ └── toymix_gin.yaml
│ │ ├── pcqm4m.yaml
│ │ └── toymix.yaml
│ ├── main_run_get_fingerprints.py
│ ├── main_run_multitask.py
│ ├── main_run_predict.py
│ ├── main_run_test.py
│ ├── neurips2023_configs/
│ │ ├── base_config/
│ │ │ ├── large.yaml
│ │ │ ├── large_pcba.yaml
│ │ │ ├── large_pcqm_g25.yaml
│ │ │ ├── large_pcqm_n4.yaml
│ │ │ └── small.yaml
│ │ ├── baseline/
│ │ │ ├── config_small_gcn_baseline.yaml
│ │ │ ├── config_small_gin_baseline.yaml
│ │ │ └── config_small_gine_baseline.yaml
│ │ ├── config_classifigression_l1000.yaml
│ │ ├── config_large_gcn.yaml
│ │ ├── config_large_gcn_g25.yaml
│ │ ├── config_large_gcn_gpu.yaml
│ │ ├── config_large_gcn_n4.yaml
│ │ ├── config_large_gcn_pcba.yaml
│ │ ├── config_large_gin.yaml
│ │ ├── config_large_gin_g25.yaml
│ │ ├── config_large_gin_n4.yaml
│ │ ├── config_large_gin_pcba.yaml
│ │ ├── config_large_gine.yaml
│ │ ├── config_large_gine_g25.yaml
│ │ ├── config_large_gine_n4.yaml
│ │ ├── config_large_gine_pcba.yaml
│ │ ├── config_large_mpnn.yaml
│ │ ├── config_luis_jama.yaml
│ │ ├── config_small_gated_gcn.yaml
│ │ ├── config_small_gcn.yaml
│ │ ├── config_small_gcn_gpu.yaml
│ │ ├── config_small_gin.yaml
│ │ ├── config_small_gine.yaml
│ │ ├── config_small_mpnn.yaml
│ │ ├── single_task_gcn/
│ │ │ ├── config_large_gcn_mcf7.yaml
│ │ │ ├── config_large_gcn_pcba.yaml
│ │ │ └── config_large_gcn_vcap.yaml
│ │ ├── single_task_gin/
│ │ │ ├── config_large_gin_g25.yaml
│ │ │ ├── config_large_gin_mcf7.yaml
│ │ │ ├── config_large_gin_n4.yaml
│ │ │ ├── config_large_gin_pcba.yaml
│ │ │ ├── config_large_gin_pcq.yaml
│ │ │ └── config_large_gin_vcap.yaml
│ │ └── single_task_gine/
│ │ ├── config_large_gine_g25.yaml
│ │ ├── config_large_gine_mcf7.yaml
│ │ ├── config_large_gine_n4.yaml
│ │ ├── config_large_gine_pcba.yaml
│ │ ├── config_large_gine_pcq.yaml
│ │ └── config_large_gine_vcap.yaml
│ └── run_validation_test.py
├── graphium/
│ ├── __init__.py
│ ├── _version.py
│ ├── cli/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── data.py
│ │ ├── finetune_utils.py
│ │ ├── fingerprints.py
│ │ ├── main.py
│ │ ├── parameters.py
│ │ └── train_finetune_test.py
│ ├── config/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── _load.py
│ │ ├── _loader.py
│ │ ├── config_convert.py
│ │ ├── dummy_finetuning_from_gnn.yaml
│ │ ├── dummy_finetuning_from_task_head.yaml
│ │ ├── fake_and_missing_multilevel_multitask_pyg.yaml
│ │ ├── fake_multilevel_multitask_pyg.yaml
│ │ └── zinc_default_multitask_pyg.yaml
│ ├── data/
│ │ ├── L1000/
│ │ │ └── datasets_goli-L1000_parse_gctx_to_csv.ipynb
│ │ ├── QM9/
│ │ │ ├── micro_qm9.csv
│ │ │ ├── micro_qm9.parquet
│ │ │ └── norm_micro_qm9.csv
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── collate.py
│ │ ├── datamodule.py
│ │ ├── dataset.py
│ │ ├── make_data_splits/
│ │ │ ├── make_train_val_test_splits_large.ipynb
│ │ │ └── make_train_val_test_splits_small.ipynb
│ │ ├── micro_ZINC/
│ │ │ ├── __init__.py
│ │ │ └── micro_ZINC.csv
│ │ ├── multilevel_utils.py
│ │ ├── multitask/
│ │ │ ├── __init__.py
│ │ │ ├── tiny_ZINC_SA.csv
│ │ │ ├── tiny_ZINC_logp.csv
│ │ │ └── tiny_ZINC_score.csv
│ │ ├── normalization.py
│ │ ├── sampler.py
│ │ ├── sdf2csv.py
│ │ ├── single_atom_dataset/
│ │ │ └── single_atom_dataset.csv
│ │ ├── smiles_transform.py
│ │ └── utils.py
│ ├── expts/
│ │ └── pyg_batching_sparse.ipynb
│ ├── features/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── commute.py
│ │ ├── electrostatic.py
│ │ ├── featurizer.py
│ │ ├── graphormer.py
│ │ ├── nmp.py
│ │ ├── periodic_table.csv
│ │ ├── positional_encoding.py
│ │ ├── properties.py
│ │ ├── rw.py
│ │ ├── spectral.py
│ │ └── transfer_pos_level.py
│ ├── finetuning/
│ │ ├── __init__.py
│ │ ├── finetuning.py
│ │ ├── finetuning_architecture.py
│ │ ├── fingerprinting.py
│ │ └── utils.py
│ ├── hyper_param_search/
│ │ ├── __init__.py
│ │ └── results.py
│ ├── ipu/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── ipu_dataloader.py
│ │ ├── ipu_losses.py
│ │ ├── ipu_metrics.py
│ │ ├── ipu_simple_lightning.py
│ │ ├── ipu_utils.py
│ │ ├── ipu_wrapper.py
│ │ └── to_dense_batch.py
│ ├── nn/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── architectures/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── encoder_manager.py
│ │ │ ├── global_architectures.py
│ │ │ └── pyg_architectures.py
│ │ ├── base_graph_layer.py
│ │ ├── base_layers.py
│ │ ├── encoders/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── base_encoder.py
│ │ │ ├── bessel_pos_encoder.py
│ │ │ ├── gaussian_kernel_pos_encoder.py
│ │ │ ├── laplace_pos_encoder.py
│ │ │ ├── mlp_encoder.py
│ │ │ └── signnet_pos_encoder.py
│ │ ├── ensemble_layers.py
│ │ ├── pyg_layers/
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── dimenet_pyg.py
│ │ │ ├── gated_gcn_pyg.py
│ │ │ ├── gcn_pyg.py
│ │ │ ├── gin_pyg.py
│ │ │ ├── gps_pyg.py
│ │ │ ├── mpnn_pyg.py
│ │ │ ├── pna_pyg.py
│ │ │ ├── pooling_pyg.py
│ │ │ └── utils.py
│ │ ├── residual_connections.py
│ │ └── utils.py
│ ├── trainer/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── losses.py
│ │ ├── metrics.py
│ │ ├── predictor.py
│ │ ├── predictor_options.py
│ │ └── predictor_summaries.py
│ └── utils/
│ ├── README.md
│ ├── __init__.py
│ ├── arg_checker.py
│ ├── command_line_utils.py
│ ├── custom_lr.py
│ ├── decorators.py
│ ├── fs.py
│ ├── hashing.py
│ ├── moving_average_tracker.py
│ ├── mup.py
│ ├── packing.py
│ ├── safe_run.py
│ ├── spaces.py
│ └── tensor.py
├── install_ipu.sh
├── mkdocs.yml
├── notebooks/
│ ├── compare-pretraining-finetuning-performance.ipynb
│ ├── dev-datamodule-invalidate-cache.ipynb
│ ├── dev-datamodule-ogb.ipynb
│ ├── dev-datamodule.ipynb
│ ├── dev-pretrained.ipynb
│ ├── dev-training-loop.ipynb
│ ├── dev.ipynb
│ ├── finetuning-on-tdc-admet-benchmark.ipynb
│ ├── running-fingerprints-from-pretrained-model.ipynb
│ └── running-model-from-config.ipynb
├── profiling/
│ ├── configs_profiling.yaml
│ ├── profile_mol_to_graph.py
│ ├── profile_one_of_k_encoding.py
│ └── profile_predictor.py
├── pyproject.toml
├── scripts/
│ ├── balance_params_and_train.sh
│ ├── convert_yml.py
│ ├── ipu_start.sh
│ ├── ipu_venv.sh
│ └── scale_mpnn.sh
└── tests/
├── .gitignore
├── __init__.py
├── config_test_ipu_dataloader.yaml
├── config_test_ipu_dataloader_multitask.yaml
├── conftest.py
├── converted_fake_multilevel_data.parquet
├── data/
│ ├── config_micro_ZINC.yaml
│ ├── micro_ZINC.csv
│ ├── micro_ZINC_corrupt.csv
│ ├── micro_ZINC_shard_1.csv
│ ├── micro_ZINC_shard_1.parquet
│ ├── micro_ZINC_shard_2.csv
│ ├── micro_ZINC_shard_2.parquet
│ └── pcqm4mv2-2k.csv
├── fake_and_missing_multilevel_data.parquet
├── test_architectures.py
├── test_attention.py
├── test_base_layers.py
├── test_collate.py
├── test_data_utils.py
├── test_datamodule.py
├── test_dataset.py
├── test_ensemble_layers.py
├── test_featurizer.py
├── test_finetuning.py
├── test_ipu_dataloader.py
├── test_ipu_losses.py
├── test_ipu_metrics.py
├── test_ipu_options.py
├── test_ipu_poptorch.py
├── test_ipu_to_dense_batch.py
├── test_loaders.py
├── test_losses.py
├── test_metrics.py
├── test_mtl_architecture.py
├── test_multitask_datamodule.py
├── test_mup.py
├── test_packing.py
├── test_pe_nodepair.py
├── test_pe_rw.py
├── test_pe_spectral.py
├── test_pos_transfer_funcs.py
├── test_positional_encoders.py
├── test_positional_encodings.py
├── test_predictor.py
├── test_pyg_layers.py
├── test_residual_connections.py
├── test_training.py
└── test_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/CODEOWNERS
================================================
* @DomInvivo
================================================
FILE: .github/CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
## Changelogs
- _enumerate the changes of that PR._
---
_Checklist:_
- [ ] _Was this PR discussed in an issue? It is recommended to first discuss a new feature in a GitHub issue before opening a PR._
- [ ] _Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate)._
- [ ] _Update the API documentation if a new function is added, or an existing one is deleted._
- [ ] _Write concise and explanatory changelogs above._
- [ ] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._
---
_discussion related to that PR_
================================================
FILE: .github/workflows/code-check.yml
================================================
name: code-check
on:
push:
branches: ["main"]
tags: ["*"]
pull_request:
branches:
- "*"
- "!gh-pages"
jobs:
python-format-black:
name: Python lint [black]
runs-on: ubuntu-latest
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install black
run: |
pip install "black>=23"
- name: Lint
run: black --check .
================================================
FILE: .github/workflows/doc.yml
================================================
name: doc
on:
push:
branches: ["main"]
# Prevent doc action on `main` to conflict with each others.
concurrency:
group: doc-${{ github.ref }}
cancel-in-progress: true
jobs:
doc:
runs-on: "ubuntu-latest"
timeout-minutes: 30
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Setup mamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: env.yml
environment-name: graphium
cache-environment: true
cache-downloads: true
- name: Install library
run: |
python -m pip install --no-deps .
pip install typer-cli
- name: Configure git
run: |
git config --global user.name "${GITHUB_ACTOR}"
git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
- name: Deploy the doc
run: |
echo "Auto-generating typer docs"
typer graphium.cli.__main__ utils docs --name graphium --output docs/cli/graphium.md
echo "Get the gh-pages branch"
git fetch origin gh-pages
echo "Build and deploy the doc on main"
mike deploy --push main
================================================
FILE: .github/workflows/release.yml
================================================
name: release
on:
workflow_dispatch:
inputs:
release-version:
description: "A valid Semver version string"
required: true
permissions:
contents: write
pull-requests: write
jobs:
release:
# Do not release if not triggered from the default branch
if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch)
runs-on: ubuntu-latest
timeout-minutes: 30
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Setup mamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: env.yml
environment-name: graphium
cache-environment: true
cache-downloads: true
create-args: >-
pip
semver
python-build
setuptools_scm
- name: Check the version is valid semver
run: |
RELEASE_VERSION="${{ inputs.release-version }}"
{
pysemver check $RELEASE_VERSION
} || {
echo "The version '$RELEASE_VERSION' is not a valid Semver version string."
echo "Please use a valid semver version string. More details at https://semver.org/"
echo "The release process is aborted."
exit 1
}
- name: Check the version is higher than the latest one
run: |
# Retrieve the git tags first
git fetch --prune --unshallow --tags &> /dev/null
RELEASE_VERSION="${{ inputs.release-version }}"
LATEST_VERSION=$(git describe --abbrev=0 --tags)
IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION)
if [ "$IS_HIGHER_VERSION" != "1" ]; then
echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'."
echo "The release process is aborted."
exit 1
fi
- name: Build Changelog
id: github_release
uses: mikepenz/release-changelog-builder-action@v4
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
toTag: "main"
- name: Configure git
run: |
git config --global user.name "${GITHUB_ACTOR}"
git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
- name: Create and push git tag
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Tag the release
git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}"
# Checkout the git tag
git checkout "${{ inputs.release-version }}"
# Push the modified changelogs
git push origin main
# Push the tags
git push origin "${{ inputs.release-version }}"
- name: Install library
run: python -m pip install --no-deps .
- name: Build the wheel and sdist
run: python -m build --no-isolation
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
packages-dir: dist/
- name: Deploy the doc
run: |
echo "Get the gh-pages branch"
git fetch origin gh-pages
echo "Build and deploy the doc on ${{ inputs.release-version }}"
mike deploy --push stable
mike deploy --push ${{ inputs.release-version }}
- name: Create GitHub Release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844
with:
tag_name: ${{ inputs.release-version }}
body: ${{steps.github_release.outputs.changelog}}
================================================
FILE: .github/workflows/test.yml
================================================
name: test
on:
push:
branches: ["main"]
tags: ["*"]
pull_request:
branches:
- "*"
- "!gh-pages"
schedule:
- cron: "0 4 * * *"
jobs:
test:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
pytorch-version: ["2.0"]
runs-on: "ubuntu-latest"
timeout-minutes: 30
defaults:
run:
shell: bash -l {0}
name: |
regular_env -
python=${{ matrix.python-version }} -
pytorch=${{ matrix.pytorch-version }}
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Setup mamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: env.yml
environment-name: graphium
cache-environment: true
cache-downloads: true
create-args: >-
python=${{ matrix.python-version }}
pytorch=${{ matrix.pytorch-version }}
- name: Install library
run: python -m pip install --no-deps -e . # `-e` required for correct `coverage` run.
- name: Run tests
run: pytest -m 'not ipu'
- name: Test CLI
run: graphium --help
- name: Test building the doc
run: mkdocs build
- name: Codecov Upload
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
verbose: false
env_vars: ${{ matrix.python-version }},${{ matrix.pytorch-version }}
================================================
FILE: .github/workflows/test_ipu.yml
================================================
name: test-ipu
on:
push:
branches: ["main"]
tags: ["*"]
pull_request:
branches:
- "*"
- "!gh-pages"
schedule:
- cron: "0 4 * * *"
jobs:
test-ipu:
strategy:
fail-fast: false
matrix:
python-version: ["3.8"]
pytorch-version: ["2.0"]
runs-on: "ubuntu-20.04"
timeout-minutes: 30
defaults:
run:
shell: bash -l {0}
name: |
poptorch_env -
python=${{ matrix.python-version }} -
pytorch=${{ matrix.pytorch-version }}
steps:
- name: Checkout the code
uses: actions/checkout@v3
- name: Activate SDK + Install Requirements
run: |
python3 -m pip install --upgrade pip
wget -q -O 'poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz' 'https://downloads.graphcore.ai/direct?package=poplar-poplar_sdk_ubuntu_20_04_3.3.0_208993bbb7-3.3.0&file=poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz'
tar -xzf poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz
python3 -m pip install poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/poptorch-3.3.0+113432_960e9c294b_ubuntu_20_04-cp38-cp38-linux_x86_64.whl
# Enable Poplar SDK (including Poplar and PopART)
source poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7/enable
python -c "import poptorch"
# Download the datafiles (Total ~ 10Mb - nothing compared to the libraries)
wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9.csv.gz
# Install the IPU specific and graphium requirements
pip install -r requirements_ipu.txt
# Install Graphium in dev mode
python -m pip install --no-deps -e .
python3 -m pytest -m 'not skip_ipu'
- name: Codecov Upload
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
verbose: false
env_vars: ${{ matrix.python-version }},${{ matrix.pytorch-version }}
================================================
FILE: .gitignore
================================================
############ graphium Custom GitIgnore ##############
# Workspace
*.code-workspace
.vscode/
.idea/
# Training logs and models
*.out
*.cache
*.ckpt
*.pickle
*.datacache
*.datacache.gz
lightning_logs/
logs/
multirun/
hparam-search-results/
models_checkpoints/
outputs/
out/
saved_models/
wandb/
out/
datacache/
tests/temp_cache*
predictions/
draft/
scripts-expts/
sweeps/
mup/
# Data and predictions
graphium/data/ZINC_bench_gnn/
graphium/data/BindingDB/
graphium/data/mega-pubchem/
graphium/data/cache/
graphium/data/b3lyp/
graphium/data/PCQM4Mv2/
graphium/data/PCQM4M/
graphium/data/neurips2023/small-dataset/
graphium/data/neurips2023/large-dataset/
graphium/data/neurips2023/dummy-dataset/
graphium/data/make_data_splits/*.csv*
graphium/data/make_data_splits/*.pt*
graphium/data/make_data_splits/*.parquet*
*.csv.gz
*.pt
# Others
expts_untracked/
debug/
change_commits.sh
graphium/features/test_new_pes.ipynb
# IPU related ignores and profiler outputs
*.a
*.cbor
*.capnp
*.pop
*.popart
*.pop_cache
*.popef
*.pvti*
############ END graphium Custom GitIgnore ##############
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.txt
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023 Valence Labs
Copyright 2023 Recursion Pharmaceuticals
Copyright 2023 Graphcore Limited
Various Academic groups have also contributed to this software under
the given license. These include, but are not limited, to the following
1. University of Montreal
2. McGill University
3. Mila - Institut Quebecois d'intelligence artificielle
4. HEC Montreal
5. CIFAR AI Chair
6. New Jersey Institute of Technology
7. RWTH Aachen University
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
Scaling molecular GNNs to infinity
---
[](https://pypi.org/project/graphium/)
[](https://anaconda.org/conda-forge/graphium)
[](https://pypi.org/project/graphium/)
[](https://anaconda.org/conda-forge/graphium)
[](https://github.com/datamol-io/graphium/blob/main/LICENSE)
[](https://github.com/datamol-io/graphium/stargazers)
[](https://github.com/datamol-io/graphium/network/members)
[](https://github.com/datamol-io/graphium/actions/workflows/test.yml)
[](https://github.com/datamol-io/graphium/actions/workflows/test_ipu.yml)
[](https://github.com/datamol-io/graphium/actions/workflows/release.yml)
[](https://github.com/datamol-io/graphium/actions/workflows/code-check.yml)
[](https://github.com/datamol-io/graphium/actions/workflows/doc.yml)
[](https://codecov.io/gh/datamol-io/graphium)
[](https://hydra.cc/)
A deep learning library focused on graph representation learning for real-world chemical tasks.
- ✅ State-of-the-art GNN architectures.
- 🐍 Extensible API: build your own GNN model and train it with ease.
- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
- 🧠 Pretrained models: for fast and easy inference or transfer learning.
- ⮔ Ready-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/).
- 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the name of the columns to predict, the atomic featurization, and you’re ready to play!
## Documentation
Visit https://graphium-docs.datamol.io/.
## Installation for developers
### For CPU and GPU developers
Use [`mamba`](https://github.com/mamba-org/mamba), a faster and better alternative to `conda`.
If you are using a GPU, we recommend enforcing the CUDA version that you need with `CONDA_OVERRIDE_CUDA=XX.X`.
```bash
# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium
# To force the CUDA version to 11.2, or any other version you prefer, use the following command:
# CONDA_OVERRIDE_CUDA=11.2 mamba env create -f env.yml -n graphium
# Install Graphium in dev mode
mamba activate graphium
pip install --no-deps -e .
```
### For IPU developers
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```
The above step needs to be done once. After that, enable the SDK and the environment as follows:
```bash
source enable_ipu.sh .graphium_ipu
```
## Training a model
To learn how to train a model, we invite you to look at the documentation, or the jupyter notebooks available [here](https://github.com/datamol-io/graphium/tree/main/docs/tutorials/model_training).
If you are not familiar with [PyTorch](https://pytorch.org/docs) or [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), we highly recommend going through their tutorial first.
## Running an experiment
We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn
```
To change parameters specific to this experiment like switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32
```
or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu
```
automatically selects the correct configs to run the experiment on GPU.
Finally, you can also run a fine-tuning loop:
```bash
graphium-train +finetuning=admet
```
To use a config file you built from scratch you can run
```bash
graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.
## Preparing the data in advance
The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`.
However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory.
The following command-line will prepare the data and cache it, then use it to train a model.
```bash
# First prepare the data and cache it in `path_to_cached_data`
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]
# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
```
**Note** that `datamodule.args.processed_graph_data_path` can also be specified at `expts/hydra-configs/`.
**Note** that, every time the configs of `datamodule.args.featurization` changes, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs.
## License
Under the Apache-2.0 license. See [LICENSE](LICENSE).
## Documentation
- Diagram for data processing in Graphium.
- Diagram for Multi-task network in Graphium
================================================
FILE: cleanup_files.sh
================================================
#!/bin/bash
# --------------------------------------------------------------------------------
# Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
# Use of this software is subject to the terms and conditions outlined in the LICENSE file.
# Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
# warranties of any kind.
# Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
# Refer to the LICENSE file for the full terms and conditions.
# --------------------------------------------------------------------------------
# NOTE: the header above was previously wrapped in Python-style triple quotes,
# which bash does not treat as a comment — it produced a "command not found"
# error on every invocation. Shell comments must use `#`.

# Delete saved files like wandb, checkpoints, profiling etc.
# Usage: bash cleanup_files.sh
rm -rf wandb/*
rm -rf checkpoints/*
rm -rf pro_vision/*
rm -rf outputs/*
rm -rf models_checkpoints/*
rm -rf datacache/*
================================================
FILE: codecov.yml
================================================
coverage:
range: "50...80"
status:
project:
default:
threshold: 1%
patch: false
comment:
layout: "header, diff, flags, components" # show component info in the PR comment
component_management:
default_rules: # default rules that will be inherited by all components
statuses:
- type: project # in this case every component that doesn't have a status defined will have a project type one
target: auto
branches:
- "!main"
individual_components:
- component_id: ipu # this is an identifier that should not be changed
name: ipu # this is a display name, and can be changed freely
paths:
- graphium/ipu/**
================================================
FILE: docs/_assets/css/custom-graphium.css
================================================
/* Graphium branding overrides for the mkdocs-material theme. */
:root {
--graphium-primary: #548235;
--graphium-secondary: #343a40;
/* Primary color shades */
--md-primary-fg-color: var(--graphium-primary);
--md-primary-fg-color--light: var(--graphium-primary);
--md-primary-fg-color--dark: var(--graphium-primary);
--md-primary-bg-color: var(--graphium-secondary);
--md-primary-bg-color--light: var(--graphium-secondary);
--md-text-link-color: var(--graphium-secondary);
/* Accent color shades */
--md-accent-fg-color: var(--graphium-secondary);
--md-accent-fg-color--transparent: var(--graphium-secondary);
--md-accent-bg-color: var(--graphium-secondary);
--md-accent-bg-color--light: var(--graphium-secondary);
}
:root > * {
/* Code block color shades */
--md-code-bg-color: hsla(0, 0%, 96%, 1);
--md-code-fg-color: hsla(200, 18%, 26%, 1);
/* Footer */
--md-footer-bg-color: var(--graphium-primary);
/* --md-footer-bg-color--dark: hsla(0, 0%, 0%, 0.32); */
--md-footer-fg-color: var(--graphium-secondary);
--md-footer-fg-color--light: var(--graphium-secondary);
--md-footer-fg-color--lighter: var(--graphium-secondary);
}
/* Green gradient for the header and footer bars. */
.md-header {
background-image: linear-gradient(to right, #548235, #8DB771);
}
.md-footer {
background-image: linear-gradient(to right, #548235, #8DB771);
}
/* Light gradient for the navigation tab strip. */
.md-tabs {
background-image: linear-gradient(to right, #f4f6f9, #e2cec3);
}
.md-header__topic {
color: #fff;
}
/* Force white text on the colored header/footer elements. */
.md-source__repository,
.md-source__icon,
.md-search__input,
.md-search__input::placeholder,
.md-search__input ~ .md-search__icon,
.md-footer__inner.md-grid,
.md-copyright__highlight,
.md-copyright,
.md-footer-meta.md-typeset a,
.md-version {
color: #fff !important;
}
/* Semi-transparent search box that sits on the gradient header. */
.md-search__form {
background-color: rgba(255, 255, 255, 0.2);
}
.md-search__input {
color: #222 !important;
}
.md-header__topic {
color: #fff;
font-size: 1.4em;
}
/* Increase the size of the logo */
.md-header__button.md-logo img,
.md-header__button.md-logo svg {
height: 2rem !important;
}
/* Reduce the margin around the logo */
.md-header__button.md-logo {
margin: 0.4em;
padding: 0.4em;
}
/* Remove the `In` and `Out` block in rendered Jupyter notebooks */
.md-container .jp-Cell-outputWrapper .jp-OutputPrompt.jp-OutputArea-prompt,
.md-container .jp-Cell-inputWrapper .jp-InputPrompt.jp-InputArea-prompt {
display: none !important;
}
================================================
FILE: docs/_assets/css/custom.css
================================================
/* Layout tweaks for mkdocstrings-rendered API documentation pages. */
/* Indentation. */
div.doc-contents:not(.first) {
padding-left: 25px;
border-left: 4px solid #e6e6e6;
margin-bottom: 80px;
}
/* Don't capitalize names. */
h5.doc-heading {
text-transform: none !important;
}
/* Don't use vertical space on hidden ToC entries. */
.hidden-toc::before {
margin-top: 0 !important;
padding-top: 0 !important;
}
/* Don't show permalink of hidden ToC entries. */
.hidden-toc a.headerlink {
display: none;
}
/* Avoid breaking parameters name, etc. in table cells. */
td code {
word-break: normal !important;
}
/* For pieces of Markdown rendered in table cells. */
td p {
margin-top: 0 !important;
margin-bottom: 0 !important;
}
================================================
FILE: docs/_assets/js/google-analytics.js
================================================
// Dynamically inject the Google Analytics (gtag.js) loader and initialize
// page tracking for the documentation site.
var gtag_id = "G-K76VNHBRM4"; // GA4 measurement ID for the docs site
// Create and append the async gtag.js <script> tag to <head>.
var script = document.createElement("script");
script.src = "https://www.googletagmanager.com/gtag/js?id=" + gtag_id;
document.head.appendChild(script);
// Standard gtag bootstrap: commands pushed to dataLayer are queued until
// the library finishes loading, so this is safe to run immediately.
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag("js", new Date()); // record the page-load timestamp
gtag("config", gtag_id); // start tracking for this measurement ID
================================================
FILE: docs/api/graphium.config.md
================================================
graphium.config
====================
helper functions to load and convert config files
::: graphium.config._loader
::: graphium.config.config_convert
::: graphium.config._load
================================================
FILE: docs/api/graphium.data.md
================================================
graphium.data
====================
module for loading datasets and collating batches
=== "Contents"
* [Data Module](#data-module)
* [Collate Module](#collate-module)
* [Util Functions](#util-functions)
## Data Module
------------
::: graphium.data.datamodule
## Collate Module
------------
::: graphium.data.collate
## Util Functions
------------
::: graphium.data.utils
================================================
FILE: docs/api/graphium.features.md
================================================
graphium.features
====================
Feature extraction and manipulation
=== "Contents"
* [Featurizer](#featurizer)
* [Positional Encoding](#positional-encoding)
* [Properties](#properties)
* [Spectral PE](#spectral-pe)
* [Random Walk PE](#random-walk-pe)
* [NMP](#nmp)
## Featurizer
------------
::: graphium.features.featurizer
## Positional Encoding
------------
::: graphium.features.positional_encoding
## Properties
------------
::: graphium.features.properties
## Spectral PE
------------
::: graphium.features.spectral
## Random Walk PE
------------
::: graphium.features.rw
## NMP
------------
::: graphium.features.nmp
================================================
FILE: docs/api/graphium.finetuning.md
================================================
graphium.finetuning
====================
Module for finetuning models and doing linear probing (fingerprinting).
::: graphium.finetuning.finetuning.GraphFinetuning
::: graphium.finetuning.finetuning_architecture.FullGraphFinetuningNetwork
::: graphium.finetuning.finetuning_architecture.PretrainedModel
::: graphium.finetuning.finetuning_architecture.FinetuningHead
::: graphium.finetuning.fingerprinting.Fingerprinter
================================================
FILE: docs/api/graphium.ipu.md
================================================
graphium.ipu
====================
Code for adapting to run on IPU
=== "Contents"
* [IPU Dataloader](#ipu-dataloader)
* [IPU Losses](#ipu-losses)
* [IPU Metrics](#ipu-metrics)
* [IPU Simple Lightning](#ipu-simple-lightning)
* [IPU Utils](#ipu-utils)
* [IPU Wrapper](#ipu-wrapper)
* [To Dense Batch](#to-dense-batch)
## IPU Dataloader
------------
::: graphium.ipu.ipu_dataloader
## IPU Losses
------------
::: graphium.ipu.ipu_losses
## IPU Metrics
------------
::: graphium.ipu.ipu_metrics
## IPU Simple Lightning
------------
::: graphium.ipu.ipu_simple_lightning
## IPU Utils
------------
::: graphium.ipu.ipu_utils
## IPU Wrapper
------------
::: graphium.ipu.ipu_wrapper
## To Dense Batch
------------
::: graphium.ipu.to_dense_batch
================================================
FILE: docs/api/graphium.nn/architectures.md
================================================
graphium.nn.architectures
====================
High level architectures in the library
=== "Contents"
* [Global Architectures](#global-architectures)
* [PyG Architectures](#pyg-architectures)
* [Encoder Manager](#encoder-manager)
## Global Architectures
------------
::: graphium.nn.architectures.global_architectures
## PyG Architectures
------------
::: graphium.nn.architectures.pyg_architectures
## Encoder Manager
------------
::: graphium.nn.architectures.encoder_manager
================================================
FILE: docs/api/graphium.nn/encoders.md
================================================
graphium.nn.encoders
====================
Implementations of positional encoders in the library
=== "Contents"
* [Base Encoder](#base-encoder)
* [Gaussian Kernel Positional Encoder](#gaussian-kernel-positional-encoder)
* [Laplacian Positional Encoder](#laplacian-positional-encoder)
* [MLP Encoder](#mlp-encoder)
* [Signnet Positional Encoder](#signnet-positional-encoder)
## Base Encoder
------------
::: graphium.nn.encoders.base_encoder
## Gaussian Kernel Positional Encoder
------------
::: graphium.nn.encoders.gaussian_kernel_pos_encoder
## Laplacian Positional Encoder
------------
::: graphium.nn.encoders.laplace_pos_encoder
## MLP Encoder
------------
::: graphium.nn.encoders.mlp_encoder
## Signnet Positional Encoder
------------
::: graphium.nn.encoders.signnet_pos_encoder
================================================
FILE: docs/api/graphium.nn/graphium.nn.md
================================================
graphium.nn
====================
Base structures for neural networks in the library (to be inherited by other modules)
=== "Contents"
* [Base Graph Layer](#base-graph-layer)
* [Base Layers](#base-layers)
* [Residual Connection](#residual-connections)
## Base Graph Layer
------------
::: graphium.nn.base_graph_layer
## Base Layers
------------
::: graphium.nn.base_layers
## Residual Connections
------------
::: graphium.nn.residual_connections
================================================
FILE: docs/api/graphium.nn/pyg_layers.md
================================================
graphium.nn.pyg_layers
====================
Implementations of GNN layers based on PyG in the library
=== "Contents"
* [Gated GCN Layer](#gated-gcn-layer)
* [GPS Layer](#gps-layer)
* [GIN and GINE Layers](#gin-and-gine-layers)
* [MPNN Layer](#mpnn-layer)
* [PNA Layer](#pna-layer)
* [Pooling Layer](#pooling-layers)
## Gated GCN Layer
------------
::: graphium.nn.pyg_layers.gated_gcn_pyg
## GPS Layer
------------
::: graphium.nn.pyg_layers.gps_pyg
## GIN and GINE Layers
------------
::: graphium.nn.pyg_layers.gin_pyg
## MPNN Layer
------------
::: graphium.nn.pyg_layers.mpnn_pyg
## PNA Layer
------------
::: graphium.nn.pyg_layers.pna_pyg
## Pooling Layers
------------
::: graphium.nn.pyg_layers.pooling_pyg
## Utils
------------
::: graphium.nn.pyg_layers.utils
================================================
FILE: docs/api/graphium.trainer.md
================================================
graphium.trainer
====================
Code for training models
=== "Contents"
* [Predictor](#predictor)
* [Metrics](#metrics)
* [Predictor Summaries](#predictor-summaries)
* [Predictor Options](#predictor-options)
## Predictor
------------
::: graphium.trainer.predictor
## Metrics
------------
::: graphium.trainer.metrics
## Predictor Summaries
------------
::: graphium.trainer.predictor_summaries
## Predictor Options
------------
::: graphium.trainer.predictor_options
================================================
FILE: docs/api/graphium.utils.md
================================================
graphium.utils
====================
module for utility functions
=== "Contents"
* [Argument Checker](#argument-checker)
* [Decorators](#decorators)
* [Dictionary of Tensors](#dictionary-of-tensors)
* [File System](#file-system)
* [Hashing](#hashing)
* [Moving Average Tracker](#moving-average-tracker)
* [MUP](#mup)
* [Read File](#read-file)
* [Safe Run](#safe-run)
* [Spaces](#spaces)
* [Tensor](#tensor)
## Argument Checker
----------------
::: graphium.utils.arg_checker
## Decorators
----------------
::: graphium.utils.decorators
## File System
----------------
::: graphium.utils.fs
## Hashing
----------------
::: graphium.utils.hashing
## Moving Average Tracker
----------------
::: graphium.utils.moving_average_tracker
## MUP
----------------
::: graphium.utils.mup
## Read File
----------------
::: graphium.utils.read_file
## Safe Run
----------------
::: graphium.utils.safe_run
## Spaces
----------------
::: graphium.utils.spaces
## Tensor
----------------
::: graphium.utils.tensor
================================================
FILE: docs/baseline.md
================================================
# ToyMix Baseline - Test set metrics
From the paper to be released soon. Below, you can see the baselines for the `ToyMix` dataset, a multitasking dataset comprising `QM9`, `Zinc12k` and `Tox21`. The datasets and their splits are available on [this link](https://zenodo.org/record/7998401). The following baselines are all for models with ~150k parameters.
One can observe that the smaller datasets (`Zinc12k` and `Tox21`) benefit from adding another unrelated task (`QM9`), where the labels are computed from DFT simulations.
**NEW baselines added 2023/09/18**: Multitask baselines have been added for GatedGCN and MPNN++ (sum aggregator) using 3 random seeds. They achieve the best performance by a significant margin on Zinc12k and Tox21, while sacrificing a little on QM9.
| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ | MAE ↓ | Pearson ↑ | R² ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | | Single-Task Model | | | Multi-Task Model | | |
| **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | 0.176 ± 0.001 | 0.850 ± 0.006 | 0.446 ± 0.01 |
| | GIN | 0.200 ± 0.002 | 0.789 ± 0.009 | 0.350 ± 0.01 | 0.176 ± 0.001 | 0.841 ± 0.005 | 0.454 ± 0.009 |
| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | 0.455 ± 0.008 |
| | GatedGCN | | | | 0.1733 ± 0.0015 | 0.8522 ± 0.0022 | **0.4620 ± 0.0118** |
| | MPNN++ (sum) | | | | **0.1725 ± 0.0012** | **0.8569 ± 0.0005** | 0.4598 ± 0.0044 |
# LargeMix Baseline
## LargeMix test set metrics
From the paper to be released soon. Below, you can see the baselines for the `LargeMix` dataset, a multitasking dataset comprising `PCQM4M_N4`, `PCQM4M_G25`, `PCBA_1328`, `L1000_VCAP`, and `L1000_MCF7`. The datasets and their splits are available on [this link](https://zenodo.org/record/7998401). The following baselines are all for models with 4-6M parameters.
One can observe that the smaller datasets (`L1000_VCAP` and `L1000_MCF7`) benefit tremendously from the multitasking. Indeed, the lack of molecular samples means that it is very easy for a model to overfit.
While `PCQM4M_G25` has no noticeable changes, the node predictions of `PCQM4M_N4` and assay predictions of `PCBA_1328` take a hit, but it is most likely due to underfitting since the training loss is also increased. It seems that 4-6M parameters is far from sufficient to capture all of the tasks simultaneously, which motivates the need for a larger model.
| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ | MAE ↓ | Pearson ↑ | R² ↑ |
|-----------|-------|-----------|-----------|-----------|---------|-----------|---------|
| | | Single-Task Model | | | Multi-Task Model | | |
| **Pcba\_1328** | GCN | **0.0316 ± 0.0000** | **0.7960 ± 0.0020** | **0.3368 ± 0.0027** | 0.0349 ± 0.0002 | 0.7661 ± 0.0031 | 0.2527 ± 0.0041 |
| | GIN | 0.0324 ± 0.0000 | 0.7941 ± 0.0018 | 0.3328 ± 0.0019 | 0.0342 ± 0.0001 | 0.7747 ± 0.0025 | 0.2650 ± 0.0020 |
| | GINE | 0.0320 ± 0.0001 | 0.7944 ± 0.0023 | 0.3337 ± 0.0027 | 0.0341 ± 0.0001 | 0.7737 ± 0.0007 | 0.2611 ± 0.0043 |
| **L1000\_vcap** | GCN | 0.1900 ± 0.0002 | 0.5788 ± 0.0034 | 0.3708 ± 0.0007 | 0.1872 ± 0.0020 | 0.6362 ± 0.0012 | 0.4022 ± 0.0008 |
| | GIN | 0.1909 ± 0.0005 | 0.5734 ± 0.0029 | 0.3731 ± 0.0014 | 0.1870 ± 0.0010 | 0.6351 ± 0.0014 | 0.4062 ± 0.0001 |
| | GINE | 0.1907 ± 0.0006 | 0.5708 ± 0.0079 | 0.3705 ± 0.0015 | **0.1862 ± 0.0007** | **0.6398 ± 0.0043** | **0.4068 ± 0.0023** |
| **L1000\_mcf7** | GCN | 0.1869 ± 0.0003 | 0.6123 ± 0.0051 | 0.3866 ± 0.0010 | 0.1863 ± 0.0011 | **0.6401 ± 0.0021** | 0.4194 ± 0.0004 |
| | GIN | 0.1862 ± 0.0003 | 0.6202 ± 0.0091 | 0.3876 ± 0.0017 | 0.1874 ± 0.0013 | 0.6367 ± 0.0066 | **0.4198 ± 0.0036** |
| | GINE | **0.1856 ± 0.0005** | 0.6166 ± 0.0017 | 0.3892 ± 0.0035 | 0.1873 ± 0.0009 | 0.6347 ± 0.0048 | 0.4177 ± 0.0024 |
## LargeMix training set loss
Below is the loss on the training set. One can observe that the multi-task model always underfits the single-task, except on the two `L1000` datasets.
This is not surprising as they contain two orders of magnitude more datapoints and pose a significant challenge for the relatively small models used in this analysis. This favors the Single dataset setup (which uses a model of the same size) and we conjecture larger models to bridge this gap moving forward.
| | | CE or MSE loss in single-task $\downarrow$ | CE or MSE loss in multi-task $\downarrow$ |
|------------|-------|-----------------------------------------|-----------------------------------------|
|
| **Pcqm4m\_g25** | GCN | **0.2660 ± 0.0005** | 0.2767 ± 0.0015 |
| | GIN | **0.2439 ± 0.0004** | 0.2595 ± 0.0016 |
| | GINE | **0.2424 ± 0.0007** | 0.2568 ± 0.0012 |
|
| **Pcqm4m\_n4** | GCN | **0.2515 ± 0.0002** | 0.2613 ± 0.0008 |
| | GIN | **0.2317 ± 0.0003** | 0.2512 ± 0.0008 |
| | GINE | **0.2272 ± 0.0001** | 0.2483 ± 0.0004 |
|
| **Pcba\_1328** | GCN | **0.0284 ± 0.0010** | 0.0382 ± 0.0005 |
| | GIN | **0.0249 ± 0.0017** | 0.0359 ± 0.0011 |
| | GINE | **0.0258 ± 0.0017** | 0.0361 ± 0.0008 |
|
| **L1000\_vcap** | GCN | 0.1906 ± 0.0036 | **0.1854 ± 0.0148** |
| | GIN | 0.1854 ± 0.0030 | **0.1833 ± 0.0185** |
| | GINE | **0.1860 ± 0.0025** | 0.1887 ± 0.0200 |
|
| **L1000\_mcf7** | GCN | 0.1902 ± 0.0038 | **0.1829 ± 0.0095** |
| | GIN | 0.1873 ± 0.0033 | **0.1701 ± 0.0142** |
| | GINE | 0.1883 ± 0.0039 | **0.1771 ± 0.0010** |
## NEW: LargeMix improved sweep - 2023/08/18
Unsatisfied with the prior results, we ran a bayesian search over a broader set of parameters, and including only more expressive models, namely GINE, GatedGCN and MPNN++. We further increase the number of parameters to 10M due to evidence of underfitting. We evaluate only the multitask setting.
We observe a significant improvement over all tasks, with a very notable r2-score increase of +0.53 (0.27 -> 0.80) compared to the best node-level property prediction on PCQM4M_N4.
The results are reported below over 1 seed. We are currently running more seeds of the same models.
| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ |
|---------------|----------------|--------|---------|--------|
| **PCQM4M_G25** | GINE | 0.2250 | 0.8840 | 0.7911 |
| | GatedGCN | 0.2457 | 0.8698 | 0.7688 |
| | MPNN++ (sum) | 0.2269 | 0.8802 | 0.7855 |
|
| **PCQM4M_N4** | GINE | 0.2699 | 0.8475 | 0.7182 |
| | GatedGCN | 0.3337 | 0.8102 | 0.6566 |
| | MPNN++ (sum) | 0.2114 | 0.8942 | 0.8000 |
| Dataset | Model | BCE ↓ | AUROC ↑ | AP ↑ |
|---------------|----------------|--------|---------|--------|
| **PCBA_1328** | GINE | 0.0334 | 0.7879 | 0.2808 |
| | GatedGCN | 0.0351 | 0.7788 | 0.2611 |
| | MPNN++ (sum) | 0.0344 | 0.7815 | 0.2666 |
|
| **L1000_VCAP** | GINE | 0.1907 | 0.6416 | 0.4042 |
| | GatedGCN | 0.1866 | 0.6395 | 0.4092 |
| | MPNN++ (sum) | 0.1867 | 0.6478 | 0.4131 |
|
| **L1000_MCF7** | GINE | 0.1931 | 0.6352 | 0.4235 |
| | GatedGCN | 0.1859 | 0.6547 | 0.4224 |
| | MPNN++ (sum) | 0.1870 | 0.6593 | 0.4254 |
# UltraLarge Baseline
## UltraLarge test set metrics
For `UltraLarge`, we provide results for the same GNN baselines as for
`LargeMix`. Each model is trained for 50 epochs and results are averaged over 3 seeds. The remaining
setup is the same as for TOYMIX (Section E.1), reporting metrics on the Single Dataset and Multi Dataset using the same performance metrics. We further use the same models (in terms of size) as used for `LargeMix`.
For now, we report only the results for a subset representing 5% of the total dataset due to computational constraints, but aim to provide the full results soon.
Results discussion. `UltraLarge` results can be found in Table 6. Interestingly, on both graph- and node-level tasks we observe that there is no advantage of multi-tasking in terms of performance. We
expect that for this ultra-large dataset, significantly larger models are needed to successfully leverage the multi-task setup. This could be attributed to underfitting, as already demonstrated for `LargeMix`. Nonetheless, our baselines set the stage for large-scale pre-training on `UltraLarge`.
The results presented used approximately 500 GPU hours of compute, with
more compute used for development and hyperparameter search.
We further note that the graph-level tasks results are very strong. Regarding the node-level tasks, they are expected to underperform in low-parameters regime, due to clear signs of underfitting, a very large amount of labels to learn, and susceptibility to over-smoothing from traditional GNNs.
| Dataset | Model | MAE ↓ | Pearson ↑ | R² ↑ | MAE ↓ | Pearson ↑ | R² ↑ |
|------------------|-------|-------------------|-------------------|-------------------|-------------------|-------------------|-------------------|
| | | Single-Task Model | | | Multi-Task Model | | |
| | | | | | | | |
| **Pm6_83m_g62** | GCN | .2606 ± .0011 | .9004 ± .0003 | .7997 ± .0009 | .2625 ± .0011 | .8896 ± .0001 | .7982 ± .0001 |
| | GIN | .2546 ± .0021 | .9051 ± .0019 | .8064 ± .0037 | .2562 ± .0000 | .8901 ± .0000 | .806 ± .0000 |
| | GINE | **.2538 ± .0006** | **.9059 ± .0010** | **.8082 ± .0015** | .258 ± .0011 | .904 ± .0000 | .8048 ± .0001 |
|
| **Pm6_83m_n7** | GCN | .5803 ± .0001 | .3372 ± .0004 | .1191 ± .0002 | .5971 ± .0002 | .3164 ± .0001 | .1019 ± .0011 |
| | GIN | .573 ± .0002 | .3478 ± .0001 | **.1269 ± .0002** | .5831 ± .0001 | .3315 ± .0005 | .1141 ± .0000 |
| | GINE | **.572 ± .0004** | **.3487 ± .0002** | .1266 ± .0001 | .5839 ± .0004 | .3294 ± .0002 | .1104 ± .0000 |
## UltraLarge training set loss
In the table below, we observe that the multi-task model slightly underfits the single-task model, indicating that parameters can be efficiently shared between the node-level and graph-level tasks. We further note that the training loss and the test MAE are almost equal for all tasks, indicating further benefits as we scale both the model and the data.
| | | **MAE loss in single-task ↓** | **MAE loss in multi-task ↓** |
|------------------|-------|---------------------------|--------------------------|
|
| Pm6_83m_g62 | GCN | **.2679 ± .0020** | .2713 ± .0017 |
| | GIN | **.2582 ± .0018** | .2636 ± .0014 |
| | GINE | **.2567 ± .0036** | .2603 ± .0021 |
|
| Pm6_83m_n7 | GCN | **.5818 ± .0021** | .5955 ± .0023 |
| | GIN | **.5707 ± .0019** | .5851 ± .0038 |
| | GINE | **.5724 ± .0015** | .5832 ± .0027 |
================================================
FILE: docs/cli/graphium-train.md
================================================
# `graphium-train`
To support advanced configuration, Graphium uses [`hydra`](https://hydra.cc/) to manage and write config files. A limitation of `hydra`, is that it is designed to function as the main entrypoint for a CLI application and does not easily support subcommands. For that reason, we introduced the `graphium-train` command in addition to the [`graphium`](./graphium.md) command.
!!! info "Curious about the configs?"
If you would like to learn more about the configs, please visit the docs [here](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs).
This page documents `graphium-train`.
## Running an experiment
We have setup Graphium with `hydra` for managing config files. To run an experiment go to the `expts/` folder. For example, to benchmark a GCN on the ToyMix dataset run
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn
```
To change parameters specific to this experiment like switching from `fp16` to `fp32` precision, you can either override them directly in the CLI via
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32
```
or change them permanently in the dedicated experiment config under `expts/hydra-configs/toymix_gcn.yaml`.
Integrating `hydra` also allows you to quickly switch between accelerators. E.g., running
```bash
graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu
```
automatically selects the correct configs to run the experiment on GPU.
Finally, you can also run a fine-tuning loop:
```bash
graphium-train +finetuning=admet
```
To use a config file you built from scratch you can run
```bash
graphium-train --config-path [PATH] --config-name [CONFIG]
```
Thanks to the modular nature of `hydra` you can reuse many of our config settings for your own experiments with Graphium.
### Preparing the data in advance
The data preparation including the featurization (e.g., of molecules from smiles to pyg-compatible format) is embedded in the pipeline and will be performed when executing `graphium-train [...]`.
However, when working with larger datasets, it is recommended to perform data preparation in advance using a machine with sufficient allocated memory (e.g., ~400GB in the case of `LargeMix`). Preparing data in advance is also beneficial when running lots of concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict reading/writing in the same directory.
The following command-line will prepare the data and cache it, then use it to train a model.
```bash
# First prepare the data and cache it in `path_to_cached_data`
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]
# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
```
??? note "Config vs. Override"
As with any configuration, note that `datamodule.args.processed_graph_data_path` can also be specified in the configs at `expts/hydra_configs/`.
??? note "Featurization"
Every time the configs of `datamodule.args.featurization` change, you will need to run a new data preparation, which will automatically be saved in a separate directory that uses a hash unique to the configs.
================================================
FILE: docs/cli/graphium.md
================================================
# `graphium`
**Usage**:
```console
$ graphium [OPTIONS] COMMAND [ARGS]...
```
**Options**:
* `--help`: Show this message and exit.
**Commands**:
* `data`: Graphium datasets.
* `finetune`: Utility CLI for extra fine-tuning utilities.
## `graphium data`
Graphium datasets.
**Usage**:
```console
$ graphium data [OPTIONS] COMMAND [ARGS]...
```
**Options**:
* `--help`: Show this message and exit.
**Commands**:
* `download`: Download a Graphium dataset.
* `list`: List available Graphium dataset.
* `prepare`: Prepare a Graphium dataset.
### `graphium data download`
Download a Graphium dataset.
**Usage**:
```console
$ graphium data download [OPTIONS] NAME OUTPUT
```
**Arguments**:
* `NAME`: [required]
* `OUTPUT`: [required]
**Options**:
* `--progress / --no-progress`: [default: progress]
* `--help`: Show this message and exit.
### `graphium data list`
List available Graphium dataset.
**Usage**:
```console
$ graphium data list [OPTIONS]
```
**Options**:
* `--help`: Show this message and exit.
### `graphium data prepare`
Prepare a Graphium dataset.
**Usage**:
```console
$ graphium data prepare [OPTIONS] OVERRIDES...
```
**Arguments**:
* `OVERRIDES...`: [required]
**Options**:
* `--help`: Show this message and exit.
## `graphium finetune`
Utility CLI for extra fine-tuning utilities.
**Usage**:
```console
$ graphium finetune [OPTIONS] COMMAND [ARGS]...
```
**Options**:
* `--help`: Show this message and exit.
**Commands**:
* `admet`: Utility CLI to easily fine-tune a model on...
* `fingerprint`: Endpoint for getting fingerprints from a...
### `graphium finetune admet`
Utility CLI to easily fine-tune a model on (a subset of) the benchmarks in the TDC ADMET group.
A major limitation is that we cannot use all features of the Hydra CLI, such as multiruns.
**Usage**:
```console
$ graphium finetune admet [OPTIONS] OVERRIDES...
```
**Arguments**:
* `OVERRIDES...`: [required]
**Options**:
* `--name TEXT`
* `--inclusive-filter / --no-inclusive-filter`: [default: inclusive-filter]
* `--help`: Show this message and exit.
### `graphium finetune fingerprint`
Endpoint for getting fingerprints from a pretrained model.
The pretrained model should be a `.ckpt` path or pre-specified, named model within Graphium.
The fingerprint layer specification should be of the format `module:layer`.
If specified as a list, the fingerprints from all the specified layers will be concatenated.
See the docs of the `graphium.finetuning.fingerprinting.Fingerprinter` class for more info.
**Usage**:
```console
$ graphium finetune fingerprint [OPTIONS] FINGERPRINT_LAYER_SPEC... PRETRAINED_MODEL SAVE_DESTINATION
```
**Arguments**:
* `FINGERPRINT_LAYER_SPEC...`: [required]
* `PRETRAINED_MODEL`: [required]
* `SAVE_DESTINATION`: [required]
**Options**:
* `--output-type TEXT`: Either numpy (.npy) or torch (.pt) output [default: torch]
* `-o, --override TEXT`: Hydra overrides
* `--help`: Show this message and exit.
================================================
FILE: docs/cli/reference.md
================================================
# CLI Reference
Installing the Graphium library makes two CLI tools available.
- [`graphium-train`](./graphium-train.md) is the hydra endpoint - specifically meant for training, finetuning and testing. Since this uses `@hydra.main`, it has access to all advanced hydra functionality such as tab completion, multirun, working directory management, logging management.
- [`graphium`](./graphium.md) is the more general CLI endpoint, organized with various sub commands.
Ideally, we would've integrated both in a single CLI endpoint, but the hydra CLI cannot be a subcommand of another CLI, nor does it support easily adding subcommands, which is why we provide two separate CLI tools with different purposes.
!!! note "Interactive, embedded CLI docs with `--help`"
In addition to these pages, you can also use `graphium --help` and `graphium-train --help` to interactively navigate the documentation of these tools directly in the CLI.
================================================
FILE: docs/contribute.md
================================================
# Contribute
We are happy to see that you want to contribute 🤗.
Feel free to open an issue or pull request at any time. But first, follow this page to install Graphium in dev mode.
## Installation for developers
### For CPU and GPU developers
Use [`mamba`](https://github.com/mamba-org/mamba), a preferred alternative to conda, to create your environment:
```bash
# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium
# Install Graphium in dev mode
mamba activate graphium
pip install --no-deps -e .
```
### For IPU developers
Download the SDK and use pypi to create your environment:
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```
The above step needs to be done once. After that, enable the SDK and the environment as follows:
```bash
source enable_ipu.sh .graphium_ipu
```
## Build the documentation
You can build and serve the documentation locally with:
```bash
# Build and serve the doc
mkdocs serve
```
================================================
FILE: docs/datasets.md
================================================
# Graphium Datasets
Graphium datasets are hosted on Zenodo
- ***ToyMix*** and ***LargeMix*** datasets are hosted on [this link](https://doi.org/10.5281/zenodo.7998401)
- ***UltraLarge*** dataset is hosted on [this link](https://doi.org/10.5281/zenodo.8370547)
Instead of providing datasets as a single entity, our aim is to provide dataset mixes containing a variety of datasets that are meant to be predicted simultaneously using multi-tasking.
They are visually described in this image, with detailed description below.

## ToyMix (QM9 + Tox21 + Zinc12K)
The ***ToyMix*** dataset combines the ***QM9***, ***Tox21***, and ***Zinc12K*** datasets. These datasets are well-known in the literature and used as toy datasets, or very simple datasets, in various contexts to enable fast iterations of models. By regrouping toy datasets from quantum ML, drug discovery, and GNN expressivity, we hope that the learned model will be representative of the model performance we can expect on the larger datasets.
### Train/Validation/Test Splits
All the datasets in ***ToyMix*** are split randomly with a ratio of 0.8/0.1/0.1. Random splitting is used since it is the simplest and fits the idea of having a toy dataset well.
### QM9
is a well-known dataset in the field of 3D GNNs. It consists of 19 graph-level quantum properties associated to an energy-minimized 3D conformation of the molecules [1]. It is considered a simple dataset since all the molecules have at most 9 heavy atoms. We chose QM9 in our ***ToyMix*** since it is very similar to the larger proposed quantum datasets, PCQM4M\_multitask and PM6\_83M, but with smaller molecules.
### Tox21
is a well-known dataset for researchers in machine learning for drug discovery [2]. It consists of a multi-label classification task with 12 labels, with most labels missing and a strong imbalance towards the negative class. We chose ***Tox21*** in our ***ToyMix*** since it is very similar to the larger proposed bioassay dataset, ***PCBA\_1328\_1564k*** both in terms of sparsity and imbalance and to the ***L1000*** datasets in terms of imbalance.
### ZINC12k
is a well-known dataset for researchers in GNN expressivity [3]. We include it in our ***ToyMix*** since GNN expressivity is very important for performance on large-scale data. Hence, we hope that the performance on this task will correlate well with the performance when scaling.
## LargeMix (PCQM4M + PCBA1328 + L1000)
In this section, we present the ***LargeMix*** dataset, comprised of four different datasets with tasks taken from quantum chemistry (***PCQM4M***), bio-assays (***PCBA***) and transcriptomics.
### Train/validation/test/test\_seen splits
For the ***PCQM4M\_G25\_N4***, we create a 0.92/0.04/0.04 split. Then, for all the other datasets in ***LargeMix***, we first create a "test\_seen" split by taking the set of molecules from ***L1000*** and ***PCBA1328*** that are also present in the training set of ***PCQM4M\_G25\_N4***, such that we can evaluate whether having the quantum properties of a molecule helps generalize for biological properties. For the remaining parts, we split randomly with a ratio of 0.92/0.04/0.04.
### L1000 VCAP and MCF7
The ***LINCS L1000*** is a database of high-throughput transcriptomics that screened more than 30,000 perturbations on a set of 978 landmark genes [4] from multiple cell lines. ***VCAP*** and ***MCF7*** are, respectively, prostate cancer and human breast cancer cell lines. In ***L1000***, most of the perturbagens are chemical, meaning that small drug-like molecules are added to the cell lines to observe how the gene expressions change. This allows us to generate biological signatures of the molecules, which are known to correlate with drug activity and side effects.
To process the data into our two datasets comprising the ***VCAP*** and ***MCF7*** cell lines, we used their "level 5" data composed of the cleanup data converted to z-scores, and filtered to keep only chemical perturbagens. However, we were left with multiple data points per molecule since some variables could change (e.g., incubation time) and generate a new measure. Given our objective of generating a single signature per molecule, we decided to take the measurements with the strongest global activity such that the variance over the 978 genes is maximal. Then, since these signatures are generally noisy, we binned them into five classes corresponding to z-scores based on the thresholds $\{-4, -2, 2, 4\}$.
The cell lines ***VCAP*** and ***MCF7*** were selected since they have a higher number of unique molecule perturbagens than other cell lines. They also have a relatively lower data imbalance, with ~92% falling in the "neutral class" when the z-score was between -2 and 2.
### PCBA1328
This dataset is very similar to the ***OGBG-PCBA*** dataset [5], but instead of being limited to 128 assays and 437k molecules, it comprises 1,328 assays and 1.56M molecules. This dataset is very interesting for pre-training molecular models since it contains information about a molecule's behavior in various settings relevant to biochemists, with evidence that it improves binding predictions. Analogous to the gene expression, we obtain a bio-assay-expression of each molecule.
To gather the data, we have looped over the PubChem index of bioassays [6] and collected every dataset such that it contains more than 6,000 molecules annotated with either "Active" or "Inactive" and at least 10 of each. Then, we converted all the molecular IDs to canonical SMILES and used them to merge all of the bioassays into a single dataset.
### PCQM4M\_G25\_N4
This dataset comes from the same data source as the ***OGBG-PCQM4M*** dataset, famously known for being part of the OGB large-scale challenge [7] and being one of the only graph datasets where pure Transformers have proven successful. The data source is the PubChemQC project [8] that computed DFT properties on the energy-minimized conformation of 3.8M small molecules from PubChem.
Contrarily to the OGB challenge, we aim to provide enough data for pre-training GNNs, so we do not limit ourselves to the HOMO-LUMO gap prediction [7]. Instead, we gather properties directly given by the DFT (e.g., energies) and compute other 3D descriptors from the conformation (e.g., inertia, the plane of best fit). We also gather node-level properties, the Mulliken and Lowdin charges at each atom. Furthermore, about half of the molecules have time-dependent DFT to help inform about the molecule's excited state. Looking forward, we plan on adding edge-level tasks to enable the prediction of bond properties, such as their lengths and the gradient of the charges.
## UltraLarge Dataset
### PM6\_83M
This dataset is similar to the ***PCQM4M*** and comes from the same PubChemQC project. However, it uses the PM6 semi-empirical computation of the quantum properties, which is orders of magnitude faster than DFT computation at the expense of less accuracy [8, 9].
This dataset covers 83M unique molecules, 62 graph-level tasks, and 7 node-level tasks. To our knowledge, this is the largest dataset available for training 2D-GNNs regarding the number of unique molecules. The various tasks come from four different molecular states, namely "S0" for the ground state, "T0" for the lowest energy triplet excited state, "cation" for the positively charged state, and "anion" for the negatively charged state. In total, there are 221M PM6 computations.
## References
[1] https://www.nature.com/articles/sdata201422/
[2] https://europepmc.org/article/MED/23603828
[3] https://arxiv.org/abs/2003.00982v3
[4] https://pubmed.ncbi.nlm.nih.gov/29195078/
[5] https://arxiv.org/abs/2005.00687
[6] https://pubmed.ncbi.nlm.nih.gov/26400175/
[7] https://arxiv.org/abs/2103.09430
[8] https://pubs.acs.org/doi/10.1021/acs.jcim.7b00083
[9] https://arxiv.org/abs/1904.06046
================================================
FILE: docs/design.md
================================================
# Graphium Library Design
---
The library is designed with 3 things in mind:
- High modularity and configurability with *YAML* files
- Contain the state-of-the art GNNs, including positional encodings and graph Transformers
- Massively multitasking across diverse and sparse datasets
The current page will walk you through the different aspects of the design that enable that.
### Diagram for data processing in Graphium.
First, when working with molecules, there are tons of options regarding atomic and bond featurisation that can be extracted from the periodic table, from empirical results, or from simulated 3D structures.
Second, when working with graph Transformers, there are plenty of options regarding the positional and structural encodings (PSE) that are fundamental in driving the accuracy and the generalization of the models.
With this in mind, we propose a very versatile chemical and PSE encoding, alongside an encoder manager, that can be fully configured from the yaml files. The idea is to assign matching *input keys* to both the features and the encoders, then pool the outputs according to the *output keys*. It is better summarized in the image below.
### Diagram for Muti-task network in Graphium
As mentioned, we want to be able to perform massive multi-tasking to enable us to work across a huge diversity of datasets. The idea is to use a combination of multiple task-heads, where a different MLP is applied to each task's predictions. However, it is also designed such that each task can have as many labels as desired, thus enabling labels to be grouped together according to whether they should share weights/losses.
The design is better explained in the image below. Notice how the *keys* output by GraphDict are used differently across the different GNN layers.
## Structure of the code
The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with Pytorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs.
Below are a list of directory and their respective documentations:
- cli
- [config](https://github.com/datamol-io/graphium/blob/main/graphium/config/README.md)
- [data](https://github.com/datamol-io/graphium/blob/main/graphium/data/README.md)
- [features](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md)
- finetuning
- [ipu](https://github.com/datamol-io/graphium/tree/main/graphium/ipu/README.md)
- [nn](https://github.com/datamol-io/graphium/tree/main/graphium/nn/README.md)
- [trainer](https://github.com/datamol-io/graphium/tree/main/graphium/trainer/README.md)
- [utils](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md)
## Structure of the configs
Making the library very modular requires configuration files with >200 lines, which becomes intractable, especially when we only want minor changes between configurations.
Hence, we use [hydra](https://hydra.cc/docs/intro/) to enable splitting the configurations into smaller and composable configuration files.
Examples of possibilities include:
- Switching between accelerators (CPU, GPU and IPU)
- Benchmarking different models on the same dataset
- Fine-tuning a pre-trained model on a new dataset
[In this document](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs#readme), we describe in detail how each of the above functionalities is achieved and how users can benefit from this design to achieve the most with Graphium with as little configuration as possible.
================================================
FILE: docs/index.md
================================================
# Overview
A deep learning library focused on graph representation learning for real-world chemical tasks.
- ✅ State-of-the-art GNN architectures.
- 🐍 Extensible API: build your own GNN model and train it with ease.
- ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
- 🧠 Pretrained models: for fast and easy inference or transfer learning.
- ⮔ Ready-to-use training loop based on [Pytorch Lightning](https://www.pytorchlightning.ai/).
- 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the name of the columns to predict, the atomic featurization, and you’re ready to play!
## Installation
### For CPU or GPU
Use [`mamba`](https://github.com/mamba-org/mamba):
```bash
# Install Graphium
mamba install -c conda-forge graphium
```
or pip:
```bash
pip install graphium
```
### For IPU
```bash
# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu
```
The above step needs to be done once. After that, enable the SDK and the environment as follows:
```bash
source enable_ipu.sh .graphium_ipu
```
Finally, you will need to install graphium with pip
```bash
pip install graphium
```
================================================
FILE: docs/license.md
================================================
```
{!LICENSE!}
```
================================================
FILE: docs/pretrained_models.md
================================================
# Graphium pretrained models
Graphium aims to provide a set of pretrained models that you can use for inference or transfer learning. The models will be made available once trained and validated.
## Listing all available models
To know which models are available, you can run the following command
```python
import graphium
print(graphium.trainer.PredictorModule.list_pretrained_models())
```
## Dummy pre-trained model
At the moment, only `tests/dummy-pretrained-model.ckpt` is provided, which is mostly useful for development and debugging of the checkpointing and finetuning pipelines.
You can load a pretrained model using the Graphium API:
```python
import graphium
predictor = graphium.trainer.PredictorModule.load_pretrained_models("dummy-pretrained-model")
```
================================================
FILE: docs/tutorials/feature_processing/add_new_positional_encoding.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Add new positional encoding\n",
"\n",
"One of the main advantages of this library is the ability to easily incorporate novel positional encodings on the node, edge and graph level. The positional encodings are computed and fed into respective encoders, and then the hidden embeddings from all pe encoders are pooled (according to whether they are node, edge, or graph level) and fed into the GNN layers as features. The design allows any combination of positional encodings to be used by modifying the configuration file. For more details on the data processing part, please visit the [design page of the doc](https://graphium-docs.datamol.io/stable/design.html).\n",
"\n",
"Here is the workflow for computing and processing positional encoding in the library:\n",
"1. edit related parts in the yaml configuration file\n",
"\n",
"2. compute the raw positional encoding from the graph in [`graphium/features/positional_encoding.py`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding) (from the [`graph positional encoder`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding.graph_positional_encoder))\n",
"\n",
"3. feed the raw positional encoding into the respective (specialized) encoders in [`graphium/nn/encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html). For example, a simple [`MLP positional encoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder) can be found. \n",
"\n",
"4. Output the hidden embeddings of pe from the encoders in their respective output keys: `feat`(node feature), `edge_feat`(edge feature), `graph_feat`(graph feature) and potentially other keys if needed such as `nodepair_feat` \n",
"\n",
"5. pool the hidden embeddings with same keys together: for example, all output with `feat` key will be pooled together\n",
"\n",
"6. Construct the [`PyG Batch`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Batch.html#torch_geometric.data.Batch), batch of graphs, each contain the output keys seen above, ready for use in the [GNN layers](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html) \n",
"\n",
"Since this library is built using PyG, we recommend looking at their [Docs](https://pytorch-geometric.readthedocs.io/en/latest/) and [Tutorials](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html) for more info. \n",
"\n",
"We start by editing the configuration file first.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Table of content\n",
"1. [edit the config file](#Edit-the-yaml-Configuration-File)\n",
"2. [compute the pe from mol](#Compute-the-Positional-Encoding)\n",
"3. [add existing encoder](#Add-Existing-Encoder)\n",
"4. [add specialized encoder](#Add-Specialized-Encoder)\n",
"5. [add the keys to spaces](#Add-the-Keys-to-Spaces)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Edit the yaml Configuration File\n",
"\n",
"### Computing Raw PE\n",
"We will use the degree of each node as a positional encoding in this tutorial. \n",
"First start with an existing yaml configuration file, you can find them in `expts/configs`\n",
"\n",
"We first look at where in the yaml file is the raw positional encodings computed. `deg_pos` is added as an example below. You can add relevant arguments for computing the positional encoding here as well such as `normalize` in the example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"pos_encoding_as_features:\n",
" pos_types:\n",
" deg_pos: #example, degree centrality\n",
" pos_type: degree\n",
" normalize: False\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Specifying Encoders for the PE\n",
"Now we want to specify arguments for the encoders associated with the pe"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"pe_encoders:\n",
" out_dim: 64\n",
" pool: \"sum\" #choice of pooling across multiple pe encoders\n",
" last_norm: None #\"batch_norm\", \"layer_norm\"\n",
" encoders: \n",
" deg_pos: #same name from the previous cell\n",
" encoder_type: \"mlp\" #or you can specify your own specialized encoder\n",
" input_keys: [\"degree\"] #same as the pos_type configured before\n",
" output_keys: [\"feat\"] #node feature\n",
" hidden_dim: 64\n",
" num_layers: 1\n",
" dropout: 0.1\n",
" normalization: \"none\" #\"batch_norm\" or \"layer_norm\"\n",
" first_normalization: \"layer_norm\" #\"batch_norm\" or \"layer_norm\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute the Positional Encoding\n",
"Next, we want to compute the raw degree of each node from the molecule graph.\n",
"\n",
"### add function to compute the pe\n",
"Go to [graphium/features](https://graphium-docs.datamol.io/stable/api/graphium.features.html) and add a new file `deg.py` to add the function to compute the pe. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple, Union, Optional\n",
"\n",
"from scipy import sparse\n",
"from scipy.sparse import spmatrix\n",
"import numpy as np\n",
"\n",
"def compute_deg(adj: Union[np.ndarray, spmatrix], normalize: bool) -> np.ndarray:\n",
" \"\"\"\n",
" Compute the node degree positional encoding \n",
"\n",
" Parameters:\n",
" adj: Adjacency matrix\n",
" normalize: indicate if the degree across all nodes are normalized to [0,1] or not\n",
" Returns:\n",
" 2D array with shape (num_nodes, 1) specifying (outgoing) degree for each node\n",
" \"\"\"\n",
" \n",
" #first adj convert to scipy sparse matrix if not already\n",
" if type(adj) is np.ndarray:\n",
" adj = sparse.csr_matrix(adj)\n",
" \n",
" #https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.sum.html\n",
" degs = adj.sum(axis=0) #sum over each row\n",
" \n",
" if (normalize): #normalize the degree sequence to [0,1]\n",
" degs = degs / np.max(degs)\n",
" return degs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test with toy matrix\n",
"\n",
"here we will test if our code compute the degrees of each node correctly"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[1., 1., 1., 1., 1.]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adj = np.identity(5) #make an identity matrix\n",
"normalize = True\n",
"\n",
"degs = compute_deg(adj, normalize=normalize)\n",
"\n",
"degs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### add to positional_encoding.py\n",
"\n",
"To compute the new pe along with all existing pe, we need to add the function we wrote to [`graphium/features/positional_encoding.py`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding). Modify the [`graph_positional_encoder`](https://graphium-docs.datamol.io/stable/api/graphium.features.html#graphium.features.positional_encoding.graph_positional_encoder) function by adding `pos_type == \"degree\"` logic."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add Existing Encoder\n",
"\n",
"In order to pool over all the positional encodings, we need to add an encoder to process the raw computed positional encoding and ensure the output dimensions from all pe encoders are the same. When designing the encoder, you can either use an existing encoder or write your own specialized encoder.\n",
"\n",
"here we can simply specify [`MLPEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder.MLPEncoder) in the yaml file and the library will automatically feed the raw positional encoding to a mlp encoder based on the input arguments. Note that in this example, the encoder takes in the pe stored at the input key `degree` and then outputs to the output key `feat`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"encoders: \n",
" deg_pos: \n",
" encoder_type: \"mlp\" \n",
" input_keys: [\"degree\"] \n",
" output_keys: [\"feat\"] # node feature\n",
" hidden_dim: 64\n",
" num_layers: 1\n",
" dropout: 0.1\n",
" normalization: \"none\" #\"batch_norm\" or \"layer_norm\"\n",
" first_normalization: \"layer_norm\" #\"batch_norm\" or \"layer_norm\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add Specialized Encoder\n",
"\n",
"You can also add a specialized encoder, such as `laplacian_pe` for the laplacian eigenvectors and eigenvalues. Here, we can add a new `deg_pos_encoder.py` in [`graphium/nn/encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html). As an example and template, please see the [`MLPEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.mlp_encoder.MLPEncoder)\n",
"\n",
"Note that all new encoders must inherit from the [`BaseEncoder`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder) class and implement the following abstract methods:\n",
"\n",
"- [`forward`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder.forward): the forward function of the encoder, how to process the input\n",
"\n",
"- [`parse_input_keys`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.base_encoder.BaseEncoder.parse_input_keys): how to parse the input keys\n",
"\n",
"- [`parse_output_keys`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html#graphium.nn.encoders.laplace_pos_encoder.LapPENodeEncoder.parse_output_keys): how to parse the output keys\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add the Keys to Spaces\n",
"\n",
"In order to directly find the correct encoders from the yaml file, we need to specify which key corresponds to which class.\n",
"\n",
"- add our new `deg_pos_encoder` to `graphium/utils/spaces.py` in the `PE_ENCODERS_DICT`\n",
"- add our new `deg_pos_encoder` to [`graphium/nn/architectures/encoder_manager.py`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.encoder_manager.EncoderManager) in the `PE_ENCODERS_DICT`\n",
"- add the import of our encoder to `graphium/nn/encoders/__init__.py`\n",
"\n",
"Now we can modify the yaml file to use our new encoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"encoders: \n",
" deg_pos: \n",
" encoder_type: \"deg_pos_encoder\" \n",
" input_keys: [\"degree\"] \n",
" output_keys: [\"feat\"] # node feature\n",
" hidden_dim: 64\n",
" #any other keys that might be used for initialization\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: docs/tutorials/feature_processing/choosing_parallelization.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "0693c5c9",
"metadata": {},
"source": [
"# Choosing the right parallelization\n",
"\n",
"This tutorial is meant to help you benchmark different parallel processing methods for the processing of molecules into graphs. This will allow you to choose the one most suitable for your machine, since the benchmarks vary per machine.\n",
"\n",
"In general, we find that using `joblib` with the `loky` parallel processing and a batch size of `1000` is most beneficial. The logic is abstracted into `datamol.parallelized_with_batches`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b5df2ac6-2ded-4597-a445-f2b5fb106330",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import joblib\n",
"\n",
"import numpy as np\n",
"import datamol as dm\n",
"import pandas as pd\n",
"\n",
"# from pandarallel import pandarallel\n",
"\n",
"# pandarallel.initialize(progress_bar=True, nb_workers=joblib.cpu_count())"
]
},
{
"cell_type": "markdown",
"id": "f81fe0f5-055f-436e-9ef4-e58a89fd50ec",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0f31e18d-bdd9-4d9b-8ba5-81e5887b857e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# download from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv\n",
"# data = pd.read_csv(\"/home/hadim/250k_rndm_zinc_drugs_clean_3.csv\", usecols=[\"smiles\"])\n",
"\n",
"# download from https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\n",
"data = pd.read_csv(\"https://storage.valencelabs.com/graphium/datasets/QM9/norm_qm9.csv\", usecols=[\"smiles\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a1197c31-7dbc-4fd7-a69a-5215e1a96b8e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"rows_number_list = [250_000]\n",
"batch_size_list = [10, 100, 1_000, 10_000]\n",
"\n",
"\n",
"def smiles_to_unique_mol_id(smiles):\n",
" try:\n",
" mol = dm.to_mol(mol=smiles)\n",
" mol_id = dm.unique_id(mol)\n",
" except:\n",
" mol_id = \"\"\n",
" if mol_id is None:\n",
" mol_id = \"\"\n",
" return mol_id\n",
"\n",
"\n",
"def smiles_to_unique_mol_id_batch(smiles_list):\n",
" mol_id_list = []\n",
" for smiles in smiles_list:\n",
" mol_id_list.append(smiles_to_unique_mol_id(smiles))\n",
" return mol_id_list"
]
},
{
"cell_type": "markdown",
"id": "5d8eea8f-b983-46e7-bfb4-b8012b3ede1a",
"metadata": {},
"source": [
"## Benchmarks"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2f8ce5c3-4232-4279-8ea3-7a74832303be",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"benchmark = []"
]
},
{
"cell_type": "markdown",
"id": "691f8963-6cac-4d58-b8a0-8049f1185486",
"metadata": {},
"source": [
"### No batch"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a246cdcf-b5ea-4c9e-9ccc-dd3c544587bb",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc396220c7144c8d8b195fb87694bbfe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/133885 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for n in rows_number_list:\n",
" df = data.iloc[:n]\n",
"\n",
" with dm.utils.perf.watch_duration(log=False) as d:\n",
" out = dm.parallelized(\n",
" smiles_to_unique_mol_id,\n",
" df[\"smiles\"].values,\n",
" progress=True,\n",
" n_jobs=-1,\n",
" scheduler=\"processes\",\n",
" )\n",
"\n",
" datum = {\n",
" \"batch\": False,\n",
" \"batch_size\": None,\n",
" \"scheduler\": \"loky_processes\",\n",
" \"duration_minutes\": d.duration_minutes,\n",
" \"duration_seconds\": d.duration,\n",
" \"n_rows\": len(df),\n",
" }\n",
" benchmark.append(datum)"
]
},
{
"cell_type": "markdown",
"id": "41139854-892f-4481-8dd6-a3972bb6ece0",
"metadata": {},
"source": [
"### Batch"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "845a72d5-0b3f-4a21-8d7d-8312671e9924",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0187ede3dbd8429995511883319e246f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/13388 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fdd5ac6375b444359f6e8cef7b755b9b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1338 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c11c5894843a46848725d0c03fef9e02",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/133 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bd34e1a806604192bbd6f95ba6435b44",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/13 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for batch_size in batch_size_list:\n",
" for n in rows_number_list:\n",
" df = data.iloc[:n]\n",
"\n",
" with dm.utils.perf.watch_duration(log=False) as d:\n",
" out = dm.parallelized_with_batches(\n",
" smiles_to_unique_mol_id_batch,\n",
" df[\"smiles\"].values,\n",
" batch_size=batch_size,\n",
" progress=True,\n",
" n_jobs=-1,\n",
" scheduler=\"processes\",\n",
" )\n",
" assert len(out) == len(df), f\"{len(out)} != {len(df)}\"\n",
"\n",
" datum = {\n",
" \"batch\": True,\n",
" \"batch_size\": batch_size,\n",
" \"scheduler\": \"loky_processes\",\n",
" \"duration_minutes\": d.duration_minutes,\n",
" \"duration_seconds\": d.duration,\n",
" \"n_rows\": len(df),\n",
" }\n",
" benchmark.append(datum)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ec7529d1-2ca3-42ec-a811-3dc010a03350",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "928c2236948d4a109c79a6bc79bd4649",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=558), Label(value='0 / 558'))), HB…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for n in rows_number_list:\n",
" df = data.iloc[:n]\n",
"\n",
" with dm.utils.perf.watch_duration(log=False) as d:\n",
" _ = df[\"smiles\"].parallel_apply(smiles_to_unique_mol_id)\n",
"\n",
" datum = {\n",
" \"batch\": False,\n",
" \"batch_size\": None,\n",
" \"scheduler\": \"pandarallel\",\n",
" \"duration_minutes\": d.duration_minutes,\n",
" \"duration_seconds\": d.duration,\n",
" \"n_rows\": len(df),\n",
" }\n",
" benchmark.append(datum)"
]
},
{
"cell_type": "markdown",
"id": "1ddf56ef-6349-4d43-9720-9e53699e9a8e",
"metadata": {},
"source": [
"## Results"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "989ae5dd-9826-4adb-af0a-64d9e2c19e2c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"
"
],
"text/plain": [
" batch batch_size scheduler duration_minutes duration_seconds \n",
"3 True 1000.0 loky_processes 0.014199 0.851930 \\\n",
"2 True 100.0 loky_processes 0.037132 2.227947 \n",
"4 True 10000.0 loky_processes 0.047438 2.846266 \n",
"5 False NaN pandarallel 0.118230 7.093791 \n",
"1 True 10.0 loky_processes 0.222177 13.330603 \n",
"0 False NaN loky_processes 4.002346 240.140754 \n",
"\n",
" n_rows duration_seconds_per_mol \n",
"3 133885 0.000006 \n",
"2 133885 0.000017 \n",
"4 133885 0.000021 \n",
"5 133885 0.000053 \n",
"1 133885 0.000100 \n",
"0 133885 0.001794 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b = pd.DataFrame(benchmark)\n",
"b[\"duration_seconds_per_mol\"] = b[\"duration_seconds\"] / b[\"n_rows\"]\n",
"\n",
"b.sort_values(\"duration_seconds_per_mol\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04c0b10c-bf99-43bf-979a-9d1f0c1ff1fd",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "39b3d77e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "graphium_ipu",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: docs/tutorials/gnn/add_new_gnn_layers.ipynb
================================================
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating GNN layers\n",
"\n",
"One of the primary advantages of this library is the fact that GNN layers are independent from model architecture, thus allowing more flexibility with the code by easily swapping different layer types as hyper-parameters. However, this requires that the layers be implemented using the PyG library, and must be inherited from the class [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure), which standardizes the inputs, outputs and properties of the layers. Thus, the architecture can be handled independently using the class [`FeedForwardPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.pyg_architectures.FeedForwardPyg) or other custom classes.\n",
"\n",
"We will first start by a simple layer that does not use edges, to a more complex layer that uses edges.\n",
"\n",
"Since these examples are built on top of PyG, we recommend looking at their [Docs](https://pytorch-geometric.readthedocs.io/en/latest/) and [Tutorials](https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html) for more info. "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Table of content\n",
"1. [Define test graph](#Define-synthetic-test-graph)\n",
"2. [Create simple layer](#Creating-a-simple-layer)\n",
"3. [Test simple layer](#Test-our-simple-layer)\n",
"4. [Create complex layer](#Creating-a-complex-layer-with-edges)\n",
"5. [Test complex layer](#Test-our-complex-layer) "
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"from torch import Tensor\n",
"from torch_geometric.data import Data, Batch\n",
"from torch_geometric.nn.conv import MessagePassing\n",
"from torch_scatter import scatter\n",
"from torch_geometric.typing import (\n",
" OptPairTensor,\n",
" OptTensor\n",
")\n",
"\n",
"from copy import deepcopy\n",
"from typing import Callable, Union, Optional, List\n",
"from graphium.nn.base_graph_layer import BaseGraphStructure\n",
"from graphium.nn.base_layers import FCLayer\n",
"from graphium.utils.decorators import classproperty\n",
"\n",
"_ = torch.manual_seed(42)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define synthetic test graph\n",
"\n",
"We define below a small batched graph on which we can test the created layers. You can also use synthetic graphs to define unit tests "
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
]
}
],
"source": [
"in_dim = 5 # Input node-feature dimensions\n",
"out_dim = 11 # Desired output node-feature dimensions\n",
"in_dim_edges = 13 # Input edge-feature dimensions\n",
"\n",
"\n",
"# Let's create 2 simple pyg graphs. \n",
"# start by specifying the edges with edge index\n",
"edge_idx1 = torch.tensor([[0, 1, 2],\n",
" [1, 2, 3]])\n",
"edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
" [0, 1, 2, 0]])\n",
"\n",
"# specify the node features, convention with variable x\n",
"x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
"x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
"\n",
"# specify the edge features in e\n",
"e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"\n",
"# make the pyg graph objects with our constructed features\n",
"g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
"g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
"\n",
"# put the two graphs into a Batch graph\n",
"bg = Batch.from_data_list([g1, g2])\n",
"\n",
"# The batched graph will show as a single graph with 7 nodes\n",
"print(bg)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a simple layer\n",
"\n",
"Here, we will show how to create a GNN layer that simply does a mean aggregation on the neighbouring features.\n",
"\n",
"First, for the layer to be fully compatible with the flexible architecture provided by [`FeedForwardPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.pyg_architectures.FeedForwardPyg), it needs to inherit from the class [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure). This base-layer has multiple virtual methods that must be implemented in any class that inherits from it.\n",
"\n",
"The virtual methods are below\n",
"\n",
"- [`layer_supports_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure): We want to return `False` since our layer doesn't support edges\n",
"- [`layer_inputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_inputs_edges): We want to return `False` since our layer doesn't input edges\n",
"- [`layer_outputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_outputs_edges): We want to return `False` since our layer doesn't output edges\n",
"- [`out_dim_factor`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.out_dim_factor): We want to return `1` since the output dimension does not depend on internal parameters.\n",
"\n",
"The example is given below"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"# inherit from message passing class from pyg \n",
"# inherit from BaseGraphStructure\n",
"# this example is also based of : https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/simple_conv.html#SimpleConv.forward\n",
"\n",
"class SimpleMeanLayer(MessagePassing, BaseGraphStructure):\n",
" def __init__(self, \n",
" in_dim: int, \n",
" out_dim: int, \n",
" activation: Union[Callable, str] = \"relu\", \n",
" dropout: float = 0.0, \n",
" normalization: Union[str, Callable] = \"none\",\n",
" aggr: str = \"mean\",\n",
" ):\n",
" r\"\"\"\n",
" add documentation in the format shown here for this layer to be automatically added to the graphium api reference\n",
" the type information will also be automatically shown in the doc\n",
" \n",
" Parameters:\n",
"\n",
" in_dim:\n",
" Input feature dimensions of the layer\n",
"\n",
" out_dim:\n",
" Output feature dimensions of the layer\n",
"\n",
" activation:\n",
" activation function to use in the layer\n",
"\n",
" dropout:\n",
" The ratio of units to dropout. Must be between 0 and 1\n",
"\n",
" normalization:\n",
" Normalization to use. Choices:\n",
"\n",
" - \"none\" or `None`: No normalization\n",
" - \"batch_norm\": Batch normalization\n",
" - \"layer_norm\": Layer normalization\n",
" - `Callable`: Any callable function\n",
" aggr:\n",
" what aggregation to use (\"add\", \"mean\" or \"max\")\n",
" \"\"\"\n",
" # Initialize the parent class\n",
" MessagePassing.__init__(self, node_dim=0, aggr=aggr)\n",
" BaseGraphStructure.__init__(self,\n",
" in_dim=in_dim, \n",
" out_dim=out_dim, \n",
" activation=activation,\n",
" dropout=dropout, \n",
" normalization=normalization)\n",
" \n",
" self.aggr = aggr\n",
" self._initialize_activation_dropout_norm()\n",
" # Create the mlp layer \n",
" # https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.MLP\n",
" self.transform = FCLayer(in_dim=in_dim, out_dim=out_dim)\n",
" \n",
" # define the forward function \n",
" def forward(self, \n",
" batch: Union[Data, Batch],\n",
" ) -> Union[Data, Batch]:\n",
" r\"\"\"\n",
" similarly add documentation to the functions in the class\n",
" \n",
" Parameters:\n",
" batch: pyg Batch graphs to pass through the layer\n",
" Returns:\n",
" batch: pyg Batch graphs\n",
" \"\"\"\n",
" \n",
" x = batch.feat\n",
" edge_index = batch.edge_index\n",
" \n",
" if isinstance(x, Tensor):\n",
" x: OptPairTensor = (x, x)\n",
"\n",
" # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n",
" out = self.propagate(edge_index, x=x)\n",
" out = self.transform(out)\n",
" out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)\n",
" batch.feat = out\n",
" return batch\n",
" \n",
" def message(self, x_j):\n",
"\n",
" return x_j\n",
"\n",
" # Finally, we define all the virtual properties according to how the class works with proper documentation of course\n",
" @classproperty\n",
" def layer_supports_edges(cls) -> bool:\n",
" r\"\"\"\n",
" Return a boolean specifying if the layer type supports edges or not.\n",
"\n",
" Returns:\n",
"\n",
" bool:\n",
" False for the current class\n",
" \"\"\"\n",
" return False\n",
"\n",
" @property\n",
" def layer_inputs_edges(self) -> bool:\n",
" r\"\"\"\n",
" Returns:\n",
"\n",
" bool:\n",
" Returns False\n",
" \"\"\"\n",
" return False\n",
" \n",
" @property\n",
" def layer_outputs_edges(self) -> bool:\n",
" r\"\"\"\n",
" Returns:\n",
"\n",
" bool:\n",
" Always ``False`` for the current class\n",
" \"\"\"\n",
" return False\n",
"\n",
" @property\n",
" def out_dim_factor(self) -> int:\n",
" r\"\"\"\n",
" Get the factor by which the output dimension is multiplied for\n",
" the next layer.\n",
"\n",
" For standard layers, this will return ``1``.\n",
" Returns:\n",
"\n",
" int:\n",
" Always ``1`` for the current class\n",
" \"\"\"\n",
" return 1"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test our simple layer\n",
"\n",
"Now, we are ready to test the `SimpleMeanLayer` on our constructed graphs. Note that in this example, we **ignore** the edge features since they are not supported."
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"\n",
"layer = SimpleMeanLayer(\n",
" in_dim=in_dim, \n",
" out_dim=out_dim, \n",
" activation=\"relu\", \n",
" dropout=.3, \n",
" normalization=\"batch_norm\")\n",
"\n",
"graph = layer(graph)\n",
"print(graph.feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a complex layer with edges\n",
"\n",
"Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring features, concatenated to the edge features with their neighbours. In that case, only the node features will change, and the network will not update the edge features.\n",
"\n",
"The virtual methods will have different outputs\n",
"\n",
"- [`layer_supports_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_supports_edges): We want to return `True` since our layer does support edges\n",
"- [`layer_inputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_inputs_edges): We want to return `True` since our layer does input edges\n",
"- [`layer_outputs_edges`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.layer_outputs_edges): We want to return `False` since our layer will not output new edges\n",
"- [`out_dim_factor`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure.out_dim_factor): We want to return `1` since the output dimension does not depend on internal parameters.\n",
"\n",
"The example is given below"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"# inherit from the MessagePassing class from pyg \n",
"# inherit from BaseGraphStructure\n",
"# adapting example from graphium/nn/pyg_layers/pna_pyg.py\n",
"\n",
"class ComplexMeanLayer(MessagePassing, BaseGraphStructure):\n",
" def __init__(self, \n",
" in_dim: int, \n",
" out_dim: int, \n",
" in_dim_edges: int, \n",
" activation: Union[Callable, str] = \"relu\", \n",
" dropout: float = 0.0, \n",
" normalization: Union[str, Callable] = \"none\",\n",
" aggr: str = \"mean\",\n",
" ):\n",
" r\"\"\"\n",
" add documentation in the format shown here for this layer to be automatically added to the graphium api reference\n",
" the type information will also be automatically shown in the doc\n",
" \n",
" Parameters:\n",
"\n",
" in_dim:\n",
" Input feature dimensions of the layer\n",
"\n",
" out_dim:\n",
" Output feature dimensions of the layer\n",
" \n",
" in_dim_edges:\n",
" Input edge feature dimension of the layer\n",
"\n",
" activation:\n",
" activation function to use in the layer\n",
"\n",
" dropout:\n",
" The ratio of units to dropout. Must be between 0 and 1\n",
"\n",
" normalization:\n",
" Normalization to use. Choices:\n",
"\n",
" - \"none\" or `None`: No normalization\n",
" - \"batch_norm\": Batch normalization\n",
" - \"layer_norm\": Layer normalization\n",
" - `Callable`: Any callable function\n",
" aggr:\n",
" what aggregation to use (\"add\", \"mean\" or \"max\")\n",
" \"\"\"\n",
" # Initialize the parent class\n",
" MessagePassing.__init__(self, node_dim=0, aggr=aggr)\n",
" BaseGraphStructure.__init__(self,\n",
" in_dim=in_dim, \n",
" out_dim=out_dim, \n",
" activation=activation,\n",
" dropout=dropout, \n",
" normalization=normalization)\n",
" \n",
" self.aggr = aggr\n",
" self._initialize_activation_dropout_norm()\n",
"        # Create the FC layer \n",
"        # https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.FCLayer\n",
" self.transform = FCLayer(in_dim=(in_dim + in_dim_edges), out_dim=out_dim)\n",
" \n",
" # define the forward function \n",
" def forward(self, \n",
" batch: Union[Data, Batch],\n",
" ) -> Union[Data, Batch]:\n",
" r\"\"\"\n",
" similarly add documentation to the functions in the class\n",
" \n",
" Parameters:\n",
" batch: pyg Batch graphs to pass through the layer\n",
" Returns:\n",
" batch: pyg Batch graphs\n",
" \"\"\" \n",
" x = batch.feat\n",
" edge_index = batch.edge_index\n",
" edge_feat = batch.edge_feat\n",
" \n",
" if isinstance(x, Tensor):\n",
" x: OptPairTensor = (x, x)\n",
"\n",
" # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n",
" out = self.propagate(edge_index, x=x, edge_feat=edge_feat, size=None)\n",
" out = self.transform(out)\n",
" out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)\n",
" batch.feat = out\n",
" return batch\n",
" \n",
" def message(self, x_j: Tensor, edge_feat: OptTensor) -> Tensor:\n",
" r\"\"\"\n",
" message function\n",
"\n",
" Parameters:\n",
"            x_j: neighbour node features\n",
" edge_feat: edge features\n",
" Returns:\n",
" feat: the message\n",
" \"\"\"\n",
" feat = torch.cat([x_j, edge_feat], dim=-1)\n",
" return feat\n",
" \n",
" def aggregate(\n",
" self,\n",
" inputs: Tensor,\n",
" index: Tensor,\n",
" edge_index: Tensor,\n",
" dim_size: Optional[int] = None,\n",
" ) -> Tensor:\n",
" r\"\"\"\n",
" aggregate function\n",
" Parameters:\n",
" inputs: input features\n",
" index: index of the nodes\n",
" edge_index: edge index\n",
" dim_size: dimension size\n",
" Returns:\n",
" out: aggregated features\n",
" \"\"\"\n",
" out = scatter(inputs, index, 0, None, dim_size, reduce=self.aggr)\n",
" return out\n",
" \n",
"\n",
" # Finally, we define all the virtual properties according to how the class works with proper documentation of course\n",
" @classproperty\n",
" def layer_supports_edges(cls) -> bool:\n",
" r\"\"\"\n",
" Return a boolean specifying if the layer type supports edges or not.\n",
"\n",
" Returns:\n",
"\n",
" bool:\n",
"                True for the current class\n",
" \"\"\"\n",
" return True\n",
"\n",
" @property\n",
" def layer_inputs_edges(self) -> bool:\n",
" r\"\"\"\n",
" Returns:\n",
"\n",
" bool:\n",
" Returns True\n",
" \"\"\"\n",
" return True\n",
" \n",
" @property\n",
" def layer_outputs_edges(self) -> bool:\n",
" r\"\"\"\n",
" Returns:\n",
"\n",
" bool:\n",
"                Always ``False`` for the current class\n",
"        \"\"\"\n",
"        return False\n",
"\n",
" @property\n",
" def out_dim_factor(self) -> int:\n",
" r\"\"\"\n",
" Get the factor by which the output dimension is multiplied for\n",
" the next layer.\n",
"\n",
" For standard layers, this will return ``1``.\n",
" Returns:\n",
"\n",
" int:\n",
" Always ``1`` for the current class\n",
" \"\"\"\n",
" return 1"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test our complex layer \n",
"\n",
"Now, we are ready to test the `ComplexMeanLayer` on our constructed graphs. Note that in this example, we utilize the edge features and node features"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"\n",
"layer = ComplexMeanLayer(\n",
" in_dim=in_dim, \n",
" out_dim=out_dim, \n",
" in_dim_edges=in_dim_edges,\n",
" activation=\"relu\", \n",
" dropout=.3, \n",
" normalization=\"batch_norm\")\n",
"\n",
"graph = layer(graph)\n",
"print(graph.feat.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: docs/tutorials/gnn/making_gnn_networks.ipynb
================================================
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Making GNN Networks\n",
"\n",
"In this example, you will learn how to easily build a full multi-task GNN using any kind of GNN layer. This tutorial uses the architecture defined by the class `FullGraphMultiTaskNetwork`.\n",
"\n",
"`FullGraphMultiTaskNetwork` is an architecture that takes as input node features and (optionally) edge features. It applies a pre-MLP on both sets of features, then passes them into a main GNN network, and finally applies graph output NNs to produce the final outputs on possibly several task levels (graph, node, or edge-level).\n",
"\n",
"The network is very easy to build via a dictionary of parameters that allows you to customize each part of the network."
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"from copy import deepcopy\n",
"\n",
"from torch_geometric.data import Data, Batch\n",
"\n",
"from graphium.nn.architectures import FullGraphMultiTaskNetwork\n",
"\n",
"_ = torch.manual_seed(42)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We will first create some simple batched graphs that will be used across the examples. Here, `bg` is a batch containing 2 graphs with random node features."
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
]
}
],
"source": [
"in_dim = 5 # Input node-feature dimensions\n",
"in_dim_edges = 13 # Input edge-feature dimensions\n",
"out_dim = 11 # Desired output node-feature dimensions\n",
"\n",
"\n",
"# Let's create 2 simple pyg graphs. \n",
"# start by specifying the edges with edge index\n",
"edge_idx1 = torch.tensor([[0, 1, 2],\n",
" [1, 2, 3]])\n",
"edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
" [0, 1, 2, 0]])\n",
"\n",
"# specify the node features, convention with variable x\n",
"x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
"x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
"\n",
"# specify the edge features in e\n",
"e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"\n",
"# make the pyg graph objects with our constructed features\n",
"g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
"g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
"\n",
"# put the two graphs into a Batch graph\n",
"bg = Batch.from_data_list([g1, g2])\n",
"\n",
"# The batched graph will show as a single graph with 7 nodes\n",
"print(bg)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a network\n",
"\n",
"To build the network, we must define the arguments to pass at the different steps:\n",
"\n",
"- `pre_nn_kwargs`: The parameters used by a feed-forward neural network on the input node-features, before passing to the convolutional layers. See class `FeedForwardNN` for details on the required parameters. Will be ignored if set to `None`.\n",
"\n",
"- `gnn_kwargs`: The parameters used by a feed-forward **graph** neural network on the features after it has passed through the pre-processing NN. See class `FeedForwardGraph` for details on the required parameters.\n",
"\n",
"- `task_heads_kwargs`: The parameters used to specify the different task heads for possibly multiple tasks, each with a specified `task_level` (graph, node or edge).\n",
"\n",
"- `graph_output_nn_kwargs`: The parameters used by the graph output NNs to process the features after the GNN layers. We need to specify a NN for each `task_level` occurring in the `task_heads_kwargs`."
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"temp_dim_1 = 23\n",
"temp_dim_2 = 17\n",
"\n",
"pre_nn_kwargs = {\n",
" \"in_dim\": in_dim,\n",
" \"out_dim\": temp_dim_1,\n",
" \"hidden_dims\": 4,\n",
" \"depth\": 2,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": \"none\",\n",
" \"dropout\": 0.2\n",
"}\n",
"\n",
"gnn_kwargs = {\n",
" \"in_dim\": temp_dim_1,\n",
" \"out_dim\": temp_dim_2,\n",
" \"hidden_dims\": 5,\n",
" \"depth\": 1,\n",
" \"activation\": 'gelu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": 'layer_norm',\n",
" \"last_normalization\": 'layer_norm',\n",
" \"residual_type\": 'simple',\n",
" \"virtual_node\": None,\n",
" \"layer_type\": 'pyg:gcn',\n",
" \"layer_kwargs\": None\n",
"}\n",
"\n",
"task_heads_kwargs = {\n",
" \"graph-task-1\": {\n",
" \"task_level\": 'graph',\n",
" \"out_dim\": 3,\n",
" \"hidden_dims\": 32,\n",
" \"depth\": 2,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": None,\n",
" \"last_normalization\": None,\n",
" \"residual_type\": \"none\"\n",
" },\n",
" \"graph-task-2\": {\n",
" \"task_level\": 'graph',\n",
" \"out_dim\": 4,\n",
" \"hidden_dims\": 32,\n",
" \"depth\": 2,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": None,\n",
" \"last_normalization\": None,\n",
" \"residual_type\": \"none\"\n",
" },\n",
" \"node-task-1\": {\n",
" \"task_level\": 'node',\n",
" \"out_dim\": 2,\n",
" \"hidden_dims\": 32,\n",
" \"depth\": 2,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": None,\n",
" \"last_normalization\": None,\n",
" \"residual_type\": \"none\"\n",
" }\n",
"}\n",
"\n",
"graph_output_nn_kwargs = {\n",
" \"graph\": {\n",
" \"pooling\": ['sum'],\n",
" \"out_dim\": temp_dim_2,\n",
" \"hidden_dims\": temp_dim_2,\n",
" \"depth\": 1,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": None,\n",
" \"last_normalization\": None,\n",
" \"residual_type\": \"none\"\n",
" },\n",
" \"node\": {\n",
" \"pooling\": None,\n",
" \"out_dim\": temp_dim_2,\n",
" \"hidden_dims\": temp_dim_2,\n",
" \"depth\": 1,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": None,\n",
" \"last_normalization\": None,\n",
" \"residual_type\": \"none\"\n",
" }\n",
"}\n",
" \n",
"\n",
"gnn_net = FullGraphMultiTaskNetwork(\n",
" gnn_kwargs=gnn_kwargs,\n",
" pre_nn_kwargs=pre_nn_kwargs, \n",
" task_heads_kwargs=task_heads_kwargs,\n",
" graph_output_nn_kwargs = graph_output_nn_kwargs\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Applying the network\n",
"\n",
"Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.\n",
"\n",
"The model outputs a dictionary of outputs, one for each task specified in `task_heads_kwargs`."
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"\n",
"\n",
"FullGNN\n",
"---------\n",
" pre-NN(depth=2, ResidualConnectionNone)\n",
" [FCLayer[5 -> 4 -> 23]\n",
" \n",
" GNN(depth=1, ResidualConnectionSimple(skip_steps=1))\n",
" GCNConvPyg[23 -> 17]\n",
" \n",
" \n",
" Task heads:\n",
" graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 3]\n",
" graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 4]\n",
" node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 2]\n",
"\n",
"\n",
"graph-task-1 torch.Size([2, 3])\n",
"graph-task-2 torch.Size([2, 4])\n",
"node-task-1 torch.Size([7, 2])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(\"\\n\")\n",
"\n",
"print(gnn_net)\n",
"print(\"\\n\")\n",
"\n",
"out = gnn_net(graph)\n",
"for task in out.keys():\n",
" print(task, out[task].shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a network (using additional edge features)\n",
"\n",
"We can further use a GNN that uses additional edge features as inputs. For that, we only need to slightly modify the `gnn_kwargs` from above. Below, we define the new `gnn_edge_kwargs`, where we can set the `layer_type` to one of the GNN layers that support edge features (e.g., `pyg:gine` or `pyg:gps`) and add the additional parameter `in_dim_edges`. Optionally, we can configure a preprocessing NN for the edge features via a dictionary `pre_nn_edges_kwargs` analogous to `pre_nn_kwargs` above, or skip this step by setting it to `None`."
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"temp_dim_edges = 8\n",
"\n",
"gnn_edge_kwargs = {\n",
" \"in_dim\": temp_dim_1,\n",
" \"in_dim_edges\": temp_dim_edges,\n",
" \"out_dim\": temp_dim_2,\n",
" \"hidden_dims\": 5,\n",
" \"depth\": 1,\n",
" \"activation\": 'gelu',\n",
" \"last_activation\": None,\n",
" \"dropout\": 0.1,\n",
" \"normalization\": 'layer_norm',\n",
" \"last_normalization\": 'layer_norm',\n",
" \"residual_type\": 'simple',\n",
" \"virtual_node\": None,\n",
" \"layer_type\": 'pyg:gine',\n",
" \"layer_kwargs\": None\n",
"}\n",
"\n",
"pre_nn_edges_kwargs = {\n",
" \"in_dim\": in_dim_edges,\n",
" \"out_dim\": temp_dim_edges,\n",
" \"hidden_dims\": 4,\n",
" \"depth\": 2,\n",
" \"activation\": 'relu',\n",
" \"last_activation\": \"none\",\n",
" \"dropout\": 0.2\n",
"}\n",
"\n",
"\n",
"gnn_net_edges = FullGraphMultiTaskNetwork(\n",
" gnn_kwargs=gnn_edge_kwargs,\n",
" pre_nn_kwargs=pre_nn_kwargs,\n",
" pre_nn_edges_kwargs=pre_nn_edges_kwargs,\n",
" task_heads_kwargs=task_heads_kwargs,\n",
" graph_output_nn_kwargs = graph_output_nn_kwargs\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Applying the network (using additional edge features)\n",
"\n",
"Again, we only need to run the forward pass on the input graphs to get a prediction."
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"\n",
"\n",
"FullGNN\n",
"---------\n",
" pre-NN(depth=2, ResidualConnectionNone)\n",
" [FCLayer[5 -> 4 -> 23]\n",
" \n",
" GNN(depth=1, ResidualConnectionSimple(skip_steps=1))\n",
" GCNConvPyg[23 -> 17]\n",
" \n",
" \n",
" Task heads:\n",
" graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 3]\n",
" graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 4]\n",
" node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)\n",
" [FCLayer[17 -> 32 -> 2]\n",
"\n",
"\n",
"graph-task-1 torch.Size([2, 3])\n",
"graph-task-2 torch.Size([2, 4])\n",
"node-task-1 torch.Size([7, 2])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"print(\"\\n\")\n",
"\n",
"print(gnn_net_edges)\n",
"print(\"\\n\")\n",
"\n",
"out = gnn_net_edges(graph)\n",
"for task in out.keys():\n",
" print(task, out[task].shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
},
"kernelspec": {
"display_name": "Python 3.8.8 64-bit ('goli': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: docs/tutorials/gnn/using_gnn_layers.ipynb
================================================
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using GNN layers\n",
"\n",
"The current library implements multiple state-of-the-art graph neural networks. In this tutorial, you will learn how to use the **GCN**, **GIN**, **GINE**, **GPS**, **Gated-GCN** and **PNA** layers in a simple `forward` context."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"\n",
"import torch_geometric as pyg\n",
"from torch_geometric.data import Data, Batch\n",
"\n",
"from copy import deepcopy\n",
"\n",
"from graphium.nn.pyg_layers import (\n",
" GCNConvPyg,\n",
" GINConvPyg,\n",
" GatedGCNPyg,\n",
" GINEConvPyg,\n",
" GPSLayerPyg,\n",
" PNAMessagePassingPyg\n",
")\n",
"\n",
"_ = torch.manual_seed(42)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We will first create some simple batched graphs that will be used across the examples. Here, `bg` is a batch containing 2 graphs with random node features."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])\n"
]
}
],
"source": [
"in_dim = 5 # Input node-feature dimensions\n",
"in_dim_edges = 13 # Input edge-feature dimensions\n",
"out_dim = 11 # Desired output node-feature dimensions\n",
"out_dim_edges = 15 # Desired output edge-feature dimensions\n",
"\n",
"\n",
"# Let's create 2 simple pyg graphs. \n",
"# start by specifying the edges with edge index\n",
"edge_idx1 = torch.tensor([[0, 1, 2],\n",
" [1, 2, 3]])\n",
"edge_idx2 = torch.tensor([[2, 0, 0, 1],\n",
" [0, 1, 2, 0]])\n",
"\n",
"# specify the node features, convention with variable x\n",
"x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)\n",
"x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)\n",
"\n",
"# specify the edge features in e\n",
"e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)\n",
"\n",
"# make the pyg graph objects with our constructed features\n",
"g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)\n",
"g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)\n",
"\n",
"# put the two graphs into a Batch graph\n",
"bg = Batch.from_data_list([g1, g2])\n",
"\n",
"# The batched graph will show as a single graph with 7 nodes\n",
"print(bg)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## GCN Layer\n",
"\n",
"To use the GCN layer from the *Kipf et al.* paper, the steps are very simple. We create the layer with the desired attributes, and apply it to the graph.\n",
"\n",
"Kipf, Thomas N., and Max Welling. \"Semi-supervised classification with graph convolutional networks.\" arXiv preprint arXiv:1609.02907 (2016)."
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"GCNConvPyg(5 -> 11, activation=relu)\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"# The GCN method doesn't support edge features, so we ignore them\n",
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"\n",
"# We create the layer\n",
"layer = GCNConvPyg(\n",
" in_dim=in_dim, out_dim=out_dim, \n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"# We apply the forward loop on the node features\n",
"graph = layer(graph)\n",
"\n",
"# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
"print(layer)\n",
"print(graph.feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## GIN Layer\n",
"\n",
"To use the GIN layer from the *Xu et al.* paper, the steps are identical to GCN.\n",
"\n",
"Xu, Keyulu, et al. \"How powerful are graph neural networks?.\" arXiv preprint arXiv:1810.00826 (2018)."
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"GINConvPyg(5 -> 11, activation=relu)\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"\n",
"# We create the layer\n",
"layer = GINConvPyg(\n",
" in_dim=in_dim, out_dim=out_dim, \n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"# We apply the forward loop on the node features\n",
"graph = layer(graph)\n",
"\n",
"# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
"print(layer)\n",
"print(graph.feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## GINE Layer\n",
"\n",
"To use the GINE layer from the *Hu et al.* paper, we also need to provide additional edge features as inputs.\n",
"\n",
"Hu, Weihua, et al. \"Strategies for Pre-training Graph Neural Networks.\" arXiv preprint arXiv:1905.12265 (2019)."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"GINEConvPyg(5 -> 11, activation=relu)\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"# The GINE method uses edge features, so we have to pass the input dimension\n",
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"\n",
"# We create the layer\n",
"layer = GINEConvPyg(\n",
" in_dim=in_dim, out_dim=out_dim,\n",
" in_dim_edges=in_dim_edges,\n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"# We apply the forward loop on the node features\n",
"graph = layer(graph)\n",
"\n",
"# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
"# 7 is the number of edges, 13 number of input edge features\n",
"print(layer)\n",
"print(graph.feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## GPS Layer\n",
"\n",
"To use the GPS layer from the *Rampášek et al.* paper, we also need to provide additional edge features as inputs. It is a hybrid approach using both a GNN and transformer in conjunction. Therefore, we further need to specify the GNN type and attention type used in the layer.\n",
"\n",
"Rampášek, Ladislav, et al. \"Recipe for a General, Powerful, Scalable Graph Transformer.\" arXiv preprint arXiv:2205.12454 (2022)."
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"GPSLayerPyg(5 -> 11, activation=relu)\n",
"torch.Size([7, 11])\n",
"torch.Size([7, 13])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"\n",
"# We create the layer\n",
"layer = GPSLayerPyg(\n",
" in_dim=in_dim, out_dim=out_dim,\n",
" in_dim_edges=in_dim_edges,\n",
" mpnn_type = \"pyg:gine\", attn_type = \"full-attention\",\n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"# We apply the forward loop on the node features\n",
"graph = layer(graph)\n",
"\n",
"# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
"# 7 is the number of edges, 13 number of input edge features\n",
"print(layer)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gated-GCN Layer\n",
"\n",
"To use the Gated-GCN layer from the *Bresson et al.* paper, the steps are different since the layer not only requires edge features as inputs, but also outputs new edge features. Therefore, we have to further specify the number of output edge features\n",
"\n",
"Bresson, Xavier, and Thomas Laurent. \"Residual gated graph convnets.\" arXiv preprint arXiv:1711.07553 (2017)."
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"GatedGCNPyg()\n",
"torch.Size([7, 11])\n",
"torch.Size([7, 15])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"\n",
"# We create the layer\n",
"layer = GatedGCNPyg(\n",
" in_dim=in_dim, out_dim=out_dim,\n",
" in_dim_edges=in_dim_edges, out_dim_edges=out_dim_edges,\n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"# We apply the forward loop on the node features\n",
"graph = layer(graph)\n",
"\n",
"# 7 is the number of nodes, 5 number of input features and 11 number of output features\n",
"# 7 is the number of edges, 13 number of input edge features and 15 the number of output edge features\n",
"print(layer)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## PNA\n",
"\n",
"PNA is a multi-aggregator method proposed by *Corso et al.*. It supports 2 types of aggregations, convolutional *PNA-conv* or message passing *PNA-msgpass*. Here, we provide the typically more powerful *PNA-msgpass*. It supports edges as inputs, but doesn't output edges. Here, we need to further specify the aggregators and scalers specific to this layer.\n",
"\n",
"Corso, Gabriele, et al. \"Principal Neighbourhood Aggregation for Graph Nets.\"\n",
"arXiv preprint arXiv:2004.05718 (2020)."
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([7, 5])\n",
"torch.Size([7, 13])\n",
"PNAMessagePassingPyg()\n",
"torch.Size([7, 11])\n"
]
}
],
"source": [
"graph = deepcopy(bg)\n",
"print(graph.feat.shape)\n",
"print(graph.edge_feat.shape)\n",
"\n",
"# We create the layer, and need to specify the aggregators and scalers\n",
"layer = PNAMessagePassingPyg(\n",
" in_dim=in_dim, out_dim=out_dim,\n",
" in_dim_edges=in_dim_edges,\n",
" aggregators=[\"mean\", \"max\", \"min\", \"std\"],\n",
" scalers=[\"identity\", \"amplification\", \"attenuation\"],\n",
" activation=\"relu\", dropout=.3, normalization=\"batch_norm\")\n",
"\n",
"graph = layer(graph)\n",
"\n",
"print(layer)\n",
"print(graph.feat.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
},
"kernelspec": {
"display_name": "Python 3.8.8 64-bit ('goli': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: docs/tutorials/model_training/simple-molecular-model.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Building and training a simple model from configurations\n",
"\n",
"This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.\n",
"\n",
"The work flow of testing your code on the entire pipeline is as follows:\n",
"\n",
"1. Select a subset of the [available configs](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs) as a starting point.\n",
"2. Create additional configs or modify the existing configs to suit your needs.\n",
"3. Train or fine-tune a model with the `graphium-train` CLI.\n",
"\n",
"## Creating the yaml file\n",
"\n",
"The first step is to create a YAML file containing all the required configurations, with an example given at `graphium/expts/hydra-configs/main.yaml`. We will go through each part of the configurations. See also the README [here](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs)."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"import omegaconf\n",
"\n",
"from hydra import compose, initialize"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def print_config_with_key(config, key):\n",
" new_config = {key: config[key]}\n",
" print(omegaconf.OmegaConf.to_yaml(new_config))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Yaml file loaded\n"
]
}
],
"source": [
"# First, let's read the yaml configuration file\n",
"with initialize(version_base=None, config_path=\"../../../expts/hydra-configs\"):\n",
" yaml_config = compose(config_name=\"main\")\n",
"\n",
"print(\"Yaml file loaded\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constants\n",
"\n",
"First, we define the constants such as the random seed and whether the model should raise or ignore an error."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"constants:\n",
" name: neurips2023_small_data_gcn\n",
" seed: 42\n",
" max_epochs: 100\n",
" data_dir: expts/data/neurips2023/small-dataset\n",
" raise_train_error: true\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"constants\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Datamodule\n",
"\n",
"Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.\n",
"\n",
"For more details, see class [`MultitaskFromSmilesDataModule`](https://graphium-docs.datamol.io/stable/api/graphium.data.html#graphium.data.datamodule.MultitaskFromSmilesDataModule)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"datamodule:\n",
" module_type: MultitaskFromSmilesDataModule\n",
" args:\n",
" prepare_dict_or_graph: pyg:graph\n",
" featurization_n_jobs: 4\n",
" featurization_progress: true\n",
" featurization_backend: loky\n",
" processed_graph_data_path: ../datacache/neurips2023-small/\n",
" num_workers: 4\n",
" persistent_workers: false\n",
" featurization:\n",
" atom_property_list_onehot:\n",
" - atomic-number\n",
" - group\n",
" - period\n",
" - total-valence\n",
" atom_property_list_float:\n",
" - degree\n",
" - formal-charge\n",
" - radical-electron\n",
" - aromatic\n",
" - in-ring\n",
" edge_property_list:\n",
" - bond-type-onehot\n",
" - stereo\n",
" - in-ring\n",
" add_self_loop: false\n",
" explicit_H: false\n",
" use_bonds_weights: false\n",
" pos_encoding_as_features:\n",
" pos_types:\n",
" lap_eigvec:\n",
" pos_level: node\n",
" pos_type: laplacian_eigvec\n",
" num_pos: 8\n",
" normalization: none\n",
" disconnected_comp: true\n",
" lap_eigval:\n",
" pos_level: node\n",
" pos_type: laplacian_eigval\n",
" num_pos: 8\n",
" normalization: none\n",
" disconnected_comp: true\n",
" rw_pos:\n",
" pos_level: node\n",
" pos_type: rw_return_probs\n",
" ksteps: 16\n",
" task_specific_args:\n",
" qm9:\n",
" df: null\n",
" df_path: ${constants.data_dir}/qm9.csv.gz\n",
" smiles_col: smiles\n",
" label_cols:\n",
" - A\n",
" - B\n",
" - C\n",
" - mu\n",
" - alpha\n",
" - homo\n",
" - lumo\n",
" - gap\n",
" - r2\n",
" - zpve\n",
" - u0\n",
" - u298\n",
" - h298\n",
" - g298\n",
" - cv\n",
" - u0_atom\n",
" - u298_atom\n",
" - h298_atom\n",
" - g298_atom\n",
" splits_path: ${constants.data_dir}/qm9_random_splits.pt\n",
" seed: ${constants.seed}\n",
" task_level: graph\n",
" label_normalization:\n",
" normalize_val_test: true\n",
" method: normal\n",
" tox21:\n",
" df: null\n",
" df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz\n",
" smiles_col: smiles\n",
" label_cols:\n",
" - NR-AR\n",
" - NR-AR-LBD\n",
" - NR-AhR\n",
" - NR-Aromatase\n",
" - NR-ER\n",
" - NR-ER-LBD\n",
" - NR-PPAR-gamma\n",
" - SR-ARE\n",
" - SR-ATAD5\n",
" - SR-HSE\n",
" - SR-MMP\n",
" - SR-p53\n",
" splits_path: ${constants.data_dir}/Tox21_random_splits.pt\n",
" seed: ${constants.seed}\n",
" task_level: graph\n",
" zinc:\n",
" df: null\n",
" df_path: ${constants.data_dir}/ZINC12k.csv.gz\n",
" smiles_col: smiles\n",
" label_cols:\n",
" - SA\n",
" - logp\n",
" - score\n",
" splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt\n",
" seed: ${constants.seed}\n",
" task_level: graph\n",
" label_normalization:\n",
" normalize_val_test: true\n",
" method: normal\n",
" batch_size_training: 200\n",
" batch_size_inference: 200\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"datamodule\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Architecture\n",
"\n",
"The architecture is based on [`FullGraphMultiTaskNetwork`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.global_architectures.FullGraphMultiTaskNetwork).\n",
"Here, we define all the layers for the model, including the layers for the pre-processing MLP (input layers `pre-nn` and `pre_nn_edges`), the positional encoder (`pe_encoders`), the post-processing MLP (output layers `post-nn`), and the main GNN (graph neural network `gnn`).\n",
"\n",
"You can find details in the following: \n",
"- info about the positional encoder in [`graphium.nn.encoders`](https://graphium-docs.datamol.io/stable/api/graphium.nn/encoders.html)\n",
"- info about the gnn layers in [`graphium.nn.pyg_layers`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html)\n",
"- info about the architecture [`FullGraphMultiTaskNetwork`](https://graphium-docs.datamol.io/stable/api/graphium.nn/architectures.html#graphium.nn.architectures.global_architectures.FullGraphMultiTaskNetwork)\n",
"- Main class for the GNN layers in [`BaseGraphStructure`](https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_graph_layer.BaseGraphStructure)\n",
"\n",
"The parameters allow you to choose the feature size, the depth, the skip connections, the pooling and the virtual node. It also supports different GNN layers such as [`GatedGCNPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gated_gcn_pyg), [`GINConvPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gin_pyg), [`GINEConvPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gin_pyg.GINEConvPyg), [`GPSLayerPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.gps_pyg.GPSLayerPyg), [`MPNNPlusPyg`](https://graphium-docs.datamol.io/stable/api/graphium.nn/pyg_layers.html#graphium.nn.pyg_layers.mpnn_pyg.MPNNPlusPyg).\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"architecture:\n",
" model_type: FullGraphMultiTaskNetwork\n",
" mup_base_path: null\n",
" pre_nn:\n",
" out_dim: 64\n",
" hidden_dims: 256\n",
" depth: 2\n",
" activation: relu\n",
" last_activation: none\n",
" dropout: 0.18\n",
" normalization: layer_norm\n",
" last_normalization: ${architecture.pre_nn.normalization}\n",
" residual_type: none\n",
" pre_nn_edges: null\n",
" pe_encoders:\n",
" out_dim: 32\n",
" pool: sum\n",
" last_norm: None\n",
" encoders:\n",
" la_pos:\n",
" encoder_type: laplacian_pe\n",
" input_keys:\n",
" - laplacian_eigvec\n",
" - laplacian_eigval\n",
" output_keys:\n",
" - feat\n",
" hidden_dim: 64\n",
" out_dim: 32\n",
" model_type: DeepSet\n",
" num_layers: 2\n",
" num_layers_post: 1\n",
" dropout: 0.1\n",
" first_normalization: none\n",
" rw_pos:\n",
" encoder_type: mlp\n",
" input_keys:\n",
" - rw_return_probs\n",
" output_keys:\n",
" - feat\n",
" hidden_dim: 64\n",
" out_dim: 32\n",
" num_layers: 2\n",
" dropout: 0.1\n",
" normalization: layer_norm\n",
" first_normalization: layer_norm\n",
" gnn:\n",
" in_dim: 64\n",
" out_dim: 96\n",
" hidden_dims: 96\n",
" depth: 4\n",
" activation: gelu\n",
" last_activation: none\n",
" dropout: 0.1\n",
" normalization: layer_norm\n",
" last_normalization: ${architecture.pre_nn.normalization}\n",
" residual_type: simple\n",
" virtual_node: none\n",
" layer_type: pyg:gcn\n",
" layer_kwargs: null\n",
" graph_output_nn:\n",
" graph:\n",
" pooling:\n",
" - sum\n",
" out_dim: 96\n",
" hidden_dims: 96\n",
" depth: 1\n",
" activation: relu\n",
" last_activation: none\n",
" dropout: ${architecture.pre_nn.dropout}\n",
" normalization: ${architecture.pre_nn.normalization}\n",
" last_normalization: none\n",
" residual_type: none\n",
" task_heads:\n",
" qm9:\n",
" task_level: graph\n",
" out_dim: 19\n",
" hidden_dims: 128\n",
" depth: 2\n",
" activation: relu\n",
" last_activation: none\n",
" dropout: ${architecture.pre_nn.dropout}\n",
" normalization: ${architecture.pre_nn.normalization}\n",
" last_normalization: none\n",
" residual_type: none\n",
" tox21:\n",
" task_level: graph\n",
" out_dim: 12\n",
" hidden_dims: 64\n",
" depth: 2\n",
" activation: relu\n",
" last_activation: none\n",
" dropout: ${architecture.pre_nn.dropout}\n",
" normalization: ${architecture.pre_nn.normalization}\n",
" last_normalization: none\n",
" residual_type: none\n",
" zinc:\n",
" task_level: graph\n",
" out_dim: 3\n",
" hidden_dims: 32\n",
" depth: 2\n",
" activation: relu\n",
" last_activation: none\n",
" dropout: ${architecture.pre_nn.dropout}\n",
" normalization: ${architecture.pre_nn.normalization}\n",
" last_normalization: none\n",
" residual_type: none\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"architecture\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predictor\n",
"\n",
"In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predictor:\n",
" metrics_on_progress_bar:\n",
" qm9:\n",
" - mae\n",
" tox21:\n",
" - auroc\n",
" zinc:\n",
" - mae\n",
" loss_fun:\n",
" qm9: mae_ipu\n",
" tox21: bce_logits_ipu\n",
" zinc: mae_ipu\n",
" random_seed: ${constants.seed}\n",
" optim_kwargs:\n",
" lr: 4.0e-05\n",
" torch_scheduler_kwargs:\n",
" module_type: WarmUpLinearLR\n",
" max_num_epochs: ${constants.max_epochs}\n",
" warmup_epochs: 10\n",
" verbose: false\n",
" scheduler_kwargs: null\n",
" target_nan_mask: null\n",
" multitask_handling: flatten\n",
" metrics_every_n_train_steps: 300\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"predictor\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metrics\n",
"\n",
"All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.\n",
"\n",
"See class [`graphium.trainer.metrics.MetricWrapper`](https://graphium-docs.datamol.io/stable/api/graphium.trainer.html#graphium.trainer.metrics.MetricWrapper) for more details."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"metrics:\n",
" qm9:\n",
" - name: mae\n",
" metric: mae_ipu\n",
" target_nan_mask: null\n",
" multitask_handling: flatten\n",
" threshold_kwargs: null\n",
" - name: pearsonr\n",
" metric: pearsonr_ipu\n",
" threshold_kwargs: null\n",
" target_nan_mask: null\n",
" multitask_handling: mean-per-label\n",
" - name: r2_score\n",
" metric: r2_score_ipu\n",
" target_nan_mask: null\n",
" multitask_handling: mean-per-label\n",
" threshold_kwargs: null\n",
" tox21:\n",
" - name: auroc\n",
" metric: auroc_ipu\n",
" task: binary\n",
" multitask_handling: mean-per-label\n",
" threshold_kwargs: null\n",
" - name: avpr\n",
" metric: average_precision_ipu\n",
" task: binary\n",
" multitask_handling: mean-per-label\n",
" threshold_kwargs: null\n",
" - name: f1 > 0.5\n",
" metric: f1\n",
" multitask_handling: mean-per-label\n",
" target_to_int: true\n",
" num_classes: 2\n",
" average: micro\n",
" threshold_kwargs:\n",
" operator: greater\n",
" threshold: 0.5\n",
" th_on_preds: true\n",
" th_on_target: true\n",
" - name: precision > 0.5\n",
" metric: precision\n",
" multitask_handling: mean-per-label\n",
" average: micro\n",
" threshold_kwargs:\n",
" operator: greater\n",
" threshold: 0.5\n",
" th_on_preds: true\n",
" th_on_target: true\n",
" zinc:\n",
" - name: mae\n",
" metric: mae_ipu\n",
" target_nan_mask: null\n",
" multitask_handling: flatten\n",
" threshold_kwargs: null\n",
" - name: pearsonr\n",
" metric: pearsonr_ipu\n",
" threshold_kwargs: null\n",
" target_nan_mask: null\n",
" multitask_handling: mean-per-label\n",
" - name: r2_score\n",
" metric: r2_score_ipu\n",
" target_nan_mask: null\n",
" multitask_handling: mean-per-label\n",
" threshold_kwargs: null\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"metrics\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Trainer\n",
"\n",
"Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainer:\n",
" seed: ${constants.seed}\n",
" model_checkpoint:\n",
" filename: ${constants.name}\n",
" save_last: true\n",
" dirpath: models_checkpoints/neurips2023-small-gcn/\n",
" trainer:\n",
" precision: 32\n",
" max_epochs: ${constants.max_epochs}\n",
" min_epochs: 1\n",
" check_val_every_n_epoch: 20\n",
" accumulate_grad_batches: 1\n",
"\n"
]
}
],
"source": [
"print_config_with_key(yaml_config, \"trainer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training the model\n",
"\n",
"Now that we defined all the configuration files, we want to train the model. The steps are fairly easy using the config loaders, and are given below.\n",
"\n",
"First make sure the dataset file is downloaded. Using `config_gps_10M_pcqm4m.yaml` as an example, make sure the file specified by `df_path` in the config is available.\n",
"In this case, we need to download `pcqm4mv2-20k.csv` into the specified directory `graphium/data/PCQM4M/pcqm4mv2-20k.csv`.\n",
"\n",
"After that, we can simply run a training through the CLI:\n",
"```bash\n",
"graphium-train\n",
"```"
]
}
],
"metadata": {
"interpreter": {
"hash": "f4a99d018a205fcbcc0480c84566beaebcb91b08d0414b39a842df533e2a1d25"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: enable_ipu.sh
================================================
#!/bin/bash
# --------------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Limited.
# Use of this software is subject to the terms and conditions outlined in the LICENSE file.
# Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
# warranties of any kind.
# Graphcore Limited is not liable
# for any damages arising from its use.
# Refer to the LICENSE file for the full terms and conditions.
# --------------------------------------------------------------------------------
# NOTE(review): the previous header used a Python-style """ docstring, which is not
# valid shell — bash treated the triple quote as an unterminated quoted word and
# attempted to execute the whole header as a command. It is now a comment block,
# and the shebang has been moved to line 1 where the kernel requires it.

# Enable the Graphcore IPU environment: activate the Python virtual environment
# and source the Poplar SDK's `enable` script.

# Default location for the virtual environment
default_venv_name=".graphium_ipu"
# Allow the user to specify the location of their virtual environment
# If not specified, use the default location
venv_name=${1:-$default_venv_name}
# Constants
sdk_path="${venv_name}/poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7"
# Source the virtual environment (paths quoted so directories with spaces work)
source "${venv_name}/bin/activate"
source "${sdk_path}/enable"
================================================
FILE: env.yml
================================================
channels:
- conda-forge
# - pyg # Add for Windows
dependencies:
- python >=3.8
- pip
- typer
- loguru
- omegaconf >=2.0.0
- tqdm
- platformdirs
# scientific
- numpy
- scipy >=1.4
- pandas >=1.0
- scikit-learn
- fastparquet
# viz
- matplotlib >=3.0.1
- seaborn
# cloud IO
- fsspec >=2021.6
- s3fs >=2021.6
- gcsfs >=2021.6
# ML packages
- cuda-version # works also with CPU-only system.
- pytorch >=1.12
- lightning >=2.0
- torchmetrics >=0.7.0,<0.11
- ogb
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
- wandb
- mup
- pytorch_sparse >=0.6
- pytorch_cluster >=1.5
- pytorch_scatter >=2.0
# chemistry
- rdkit
- datamol >=0.10
# Optional deps
- sympy
- tensorboard
- pydantic <2 # because of lightning. See https://github.com/Lightning-AI/lightning/issues/18026 and https://github.com/Lightning-AI/lightning/pull/18022
# Dev
- pytest >=6.0
- pytest-xdist
- pytest-cov
- pytest-forked
- nbconvert
- black >=23
- jupyterlab
- ipywidgets
# Doc
- mkdocs
- mkdocs-material
- mkdocs-material-extensions
- mkdocstrings
- mkdocstrings-python
- mkdocs-jupyter
- markdown-include
- mike >=1.0.0
- pip:
- lightning-graphcore # optional, for using IPUs only
- hydra-core>=1.3.2
- hydra-optuna-sweeper
================================================
FILE: expts/__init__.py
================================================
================================================
FILE: expts/configs/config_gps_10M_pcqm4m.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
constants:
name: &name pcqm4mv2_mpnn_4layer
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 60
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 16 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 120
# Data handling-related
batch_size_training: 64
batch_size_inference: 16
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16
accumulate_grad_batches: 4
ipu_config:
- deviceIterations(20) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
ipu_inference_config: # Optional. If not provided, same as `ipu_config`
- deviceIterations(80) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 256
# batch_size_inference: 64
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: ~/scratch/data/graphium/data/PCQM4M/pcqm4mv2-20k.csv
# wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv
# or set path as https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 30000 # use sample_size for test
# splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
split_val: 0.1
split_test: 0.1
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
mask_nan: 0
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
conformer_property_list: [positions_3d]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
          normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
          normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 0 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: gelu
last_activation: none
dropout: 0.0
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
pooling: [sum]
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 32
out_dim: 32
in_dim_edges: 16
out_dim_edges: 16
attn_type: "full-attention" # "full-attention", "none"
# biased_attention: false
attn_kwargs:
num_heads: *num_heads
biased_attention_key: nodepair_gaussian_bias_3d
post_nn: null
task_heads:
homolumo:
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: ["mae", "pearsonr"]
loss_fun:
homolumo: mse_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 5
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01
# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
logger:
save_dir: logs/PCQMv2
name: *name
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQv2/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/configs/config_gps_10M_pcqm4m_mod.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
constants:
name: &name pcqm4mv2_mpnn_4layer
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: gpu # cpu or ipu or gpu
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: ~/scratch/data/graphium/data/PCQM4M/pcqm4mv2-20k.csv
# wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv
# or set path as https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 30000 # use sample_size for test
# splits_path: graphium/data/PCQM4Mv2/split_dict.pt # Download with `wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/split_dict.pt`
split_val: 0.1
split_test: 0.1
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
mask_nan: 0
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
conformer_property_list: [positions_3d]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
node_laplacian_eigvec:
pos_type: laplacian_eigvec
pos_level: node
num_pos: 8
normalization: "none"
disconnected_comp: True
node_laplacian_eigval:
pos_type: laplacian_eigval
pos_level: node
num_pos: 8
normalization: "none"
disconnected_comp: True
rw_return_probs:
pos_type: rw_return_probs
pos_level: node
ksteps: [4, 8]
nodepair_rw_transition_probs:
pos_type: rw_transition_probs
pos_level: edge
ksteps: [2, 4]
nodepair_rw_return_probs:
pos_type: rw_return_probs
pos_level: nodepair
ksteps: [4]
electrostatic:
pos_type: electrostatic
pos_level: node
edge_commute:
pos_type: commute
pos_level: edge
nodepair_graphormer:
pos_type: graphormer
pos_level: nodepair
# Data handling-related
batch_size_training: 64
batch_size_inference: 16
num_workers: 0 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
# ipu_dataloader_training_opts:
# mode: async
# max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54
# max_num_edges_per_graph: 60
# ipu_dataloader_inference_opts:
# mode: async
# max_num_nodes_per_graph: 20 # valid max nodes: 51, max_edges: 118
# max_num_edges_per_graph: 120
# # test-dev max nodes: 50, max_edges: 116
# # test-challenge max nodes: 51, max_edges: 106
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: &pe_out_dim 32
edge_out_dim: &edge_pe_out_dim 16
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders:
emb_la_pos:
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
emb_rwse:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
emb_electrostatic:
encoder_type: "mlp"
input_keys: ["electrostatic"]
output_keys: ["feat"]
hidden_dim: 32
num_layers: 1
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
emb_edge_rwse:
encoder_type: "mlp"
input_keys: ["edge_rw_transition_probs"]
output_keys: ["edge_feat"]
hidden_dim: 32
num_layers: 1
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
emb_edge_pes:
encoder_type: "cat_mlp"
input_keys: ["edge_rw_transition_probs", "edge_commute"]
output_keys: ["edge_feat"]
hidden_dim: 32
num_layers: 1
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
gaussian_pos:
encoder_type: "gaussian_kernel"
input_keys: ["positions_3d"]
output_keys: ["feat", "nodepair_gaussian_bias_3d"]
num_heads: &num_heads 2
num_layers: 2
embed_dim: *pe_out_dim
use_input_keys_prefix: False
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: gelu
last_activation: none
dropout: 0.0
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
pooling: [sum]
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 32
out_dim: 32
in_dim_edges: 16
out_dim_edges: 16
attn_type: "full-attention" # "full-attention", "none"
# biased_attention: false
attn_kwargs:
num_heads: *num_heads
biased_attention_key: nodepair_gaussian_bias_3d
post_nn: null
task_heads:
homolumo:
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: ["mae", "pearsonr"]
loss_fun:
homolumo: mse_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 5
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01
# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
logger:
save_dir: logs/PCQMv2
name: *name
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQv2/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
precision: 32
max_epochs: *max_epochs
min_epochs: 1
accumulate_grad_batches: 2
check_val_every_n_epoch: 20
================================================
FILE: expts/configs/config_mpnn_10M_b3lyp.yaml
================================================
# Testing the mpnn only model with the b3lyp dataset on IPU.
constants:
name: &name b3lyp_mpnn_4layer
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 60
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 16 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 120
# Data handling-related
batch_size_training: 64
batch_size_inference: 16
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16
accumulate_grad_batches: 4
ipu_config:
- deviceIterations(20) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
ipu_inference_config: # Optional. If not provided, same as `ipu_config`
- deviceIterations(80) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 256
# batch_size_inference: 64
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
betagap:
df: null
task_level: "graph"
df_path: graphium/data/b3lyp/b3lyp_mini.parquet #graphium/data/b3lyp/b3lyp_mini.parquet
# wget https://storage.valencelabs.com/graphium/datasets/b3lyp/b3lyp_mini.parquet
# or set path as https://storage.valencelabs.com/graphium/datasets/b3lyp/b3lyp_mini.parquet directly
smiles_col: "smiles"
label_cols: ["beta_gap"]
# sample_size: 30000 # use sample_size for test
split_val: 0.1
split_test: 0.1
alphagap:
df: null
task_level: "graph"
df_path: graphium/data/b3lyp/b3lyp_mini.parquet #graphium/data/b3lyp/b3lyp_mini.parquet
# wget https://storage.valencelabs.com/graphium/datasets/b3lyp/b3lyp_mini.parquet
# or set path as https://storage.valencelabs.com/graphium/datasets/b3lyp/b3lyp_mini.parquet directly
smiles_col: "smiles"
label_cols: ["alpha_gap"]
# sample_size: 30000 # use sample_size for test
# splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
split_val: 0.1
split_test: 0.1
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/b3lyp/"
dataloading_from: ram
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 0 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: gelu
last_activation: none
dropout: 0.0
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 32
out_dim: 32
in_dim_edges: 16
out_dim_edges: 16
attn_type: "none" # "full-attention", "none"
# biased_attention: false
attn_kwargs: null
graph_output_nn:
graph:
pooling: [sum]
out_dim: 256
hidden_dims: 256
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
alphagap:
task_level: graph
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
betagap:
task_level: graph
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
alphagap: ["mae", "pearsonr"]
betagap: ["mae", "pearsonr"]
loss_fun:
alphagap: mse_ipu
betagap: mse_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01
# Task-specific
metrics:
alphagap: &alpha_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
betagap: *alpha_metrics
trainer:
logger:
save_dir: logs/b3lyp
name: *name
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/b3lyp/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/configs/config_mpnn_pcqm4m.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
constants:
name: &name pcqm4mv2_mpnn_4layer
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: cpu # cpu or ipu or gpu
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: graphium/data/PCQM4M/pcqm4mv2-20k.csv
# wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv
# or set path as https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2-20k.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 6000 # use sample_size for test
splits_path: graphium/data/PCQM4Mv2/split_dict_v2.pt # Download with `wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
# graphium/data/PCQM4Mv2/split_dict.pt
# graphium/data/PCQM4Mv2/pcqm4m_split.csv
split_names: ["train", "valid", "test-dev"]
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 20
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "graphium/data/PCQM4Mv2/"
dataloading_from: ram
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
la_pos: &pos_enc
pos_type: laplacian_eigvec_eigval #laplacian_eigvec
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_type: rwse
ksteps: 16
# pos_encoding_as_directions: *pos_enc # Only for DGN or directional pooling
# Data handling-related
batch_size_training: 64
batch_size_inference: 16
num_workers: 40 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
# ipu_dataloader_training_opts:
# mode: async
# max_num_nodes_per_graph: 20 # train max nodes: 20, max_edges: 54
# max_num_edges_per_graph: 60
# ipu_dataloader_inference_opts:
# mode: async
# max_num_nodes_per_graph: 20 # valid max nodes: 51, max_edges: 118
# max_num_edges_per_graph: 120
# # test-dev max nodes: 50, max_edges: 116
# # test-challenge max nodes: 51, max_edges: 106
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: gelu
last_activation: none
dropout: 0.0
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
pooling: [sum]
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 32
out_dim: 32
in_dim_edges: 16
out_dim_edges: 16
attn_type: "none" # "full-attention", "none"
# biased_attention: false
attn_kwargs: null
post_nn: null
task_heads:
homolumo:
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: ["mae", "pearsonr"]
loss_fun:
homolumo: mse_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01
# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
logger:
save_dir: logs/PCQMv2
name: *name
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQv2/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
precision: 32
max_epochs: *max_epochs
min_epochs: 1
accumulate_grad_batches: 2
check_val_every_n_epoch: 20
================================================
FILE: expts/data/micro_zinc_splits.csv
================================================
train,val,test
957,392.0,654.0
223,588.0,430.0
916,538.0,317.0
410,903.0,629.0
287,923.0,927.0
496,677.0,641.0
456,144.0,518.0
589,83.0,645.0
151,385.0,649.0
10,996.0,504.0
7,506.0,716.0
180,205.0,557.0
769,161.0,931.0
389,746.0,804.0
304,99.0,30.0
867,444.0,836.0
177,814.0,469.0
79,984.0,448.0
937,292.0,863.0
494,377.0,412.0
771,986.0,61.0
587,864.0,805.0
552,966.0,326.0
881,974.0,520.0
808,106.0,263.0
186,670.0,874.0
925,238.0,920.0
268,995.0,226.0
108,387.0,749.0
717,696.0,724.0
37,871.0,583.0
47,515.0,928.0
258,596.0,117.0
173,801.0,459.0
979,301.0,633.0
846,420.0,239.0
578,688.0,421.0
320,572.0,484.0
69,451.0,399.0
221,318.0,886.0
96,87.0,133.0
732,672.0,759.0
303,944.0,774.0
48,891.0,523.0
455,311.0,103.0
140,297.0,610.0
734,776.0,743.0
138,972.0,104.0
699,548.0,574.0
706,152.0,204.0
713,418.0,822.0
705,214.0,989.0
230,254.0,620.0
134,709.0,648.0
551,164.0,66.0
92,835.0,75.0
890,395.0,929.0
185,539.0,184.0
691,445.0,712.0
31,573.0,98.0
195,894.0,997.0
630,852.0,94.0
447,893.0,959.0
967,673.0,728.0
280,819.0,689.0
813,829.0,323.0
788,711.0,658.0
432,750.0,722.0
876,153.0,305.0
766,524.0,600.0
764,800.0,23.0
135,279.0,76.0
12,700.0,73.0
136,71.0,740.0
406,942.0,482.0
267,935.0,773.0
234,854.0,208.0
531,59.0,985.0
274,124.0,812.0
938,913.0,122.0
512,591.0,398.0
353,693.0,726.0
127,790.0,380.0
737,54.0,293.0
401,283.0,429.0
121,492.0,747.0
563,702.0,906.0
694,954.0,316.0
176,360.0,842.0
25,624.0,145.0
187,898.0,302.0
434,719.0,781.0
828,628.0,65.0
113,232.0,388.0
199,537.0,314.0
227,60.0,825.0
411,206.0,536.0
590,567.0,787.0
581,878.0,495.0
697,383.0,89.0
147,683.0,331.0
798,806.0,338.0
918,49.0,656.0
751,993.0,350.0
415,85.0,485.0
692,150.0,840.0
120,64.0,760.0
686,980.0,41.0
466,356.0,261.0
517,735.0,855.0
101,486.0,439.0
528,883.0,503.0
845,18.0,38.0
644,879.0,202.0
698,478.0,519.0
501,336.0,625.0
86,40.0,105.0
224,472.0,763.0
422,522.0,493.0
15,276.0,405.0
767,397.0,950.0
193,192.0,156.0
437,889.0,351.0
376,638.0,603.0
910,549.0,461.0
288,680.0,919.0
171,381.0,765.0
955,655.0,940.0
932,490.0,352.0
623,132.0,917.0
848,17.0,142.0
592,643.0,870.0
155,24.0,508.0
457,157.0,561.0
216,229.0,174.0
250,453.0,951.0
483,857.0,550.0
404,463.0,601.0
282,626.0,904.0
529,646.0,831.0
513,477.0,668.0
189,211.0,880.0
441,307.0,824.0
576,408.0,613.0
2,428.0,657.0
744,615.0,241.0
111,378.0,402.0
718,100.0,873.0
391,499.0,16.0
322,785.0,755.0
566,675.0,584.0
452,794.0,650.0
860,129.0,669.0
952,729.0,598.0
921,290.0,468.0
875,337.0,571.0
42,602.0,660.0
128,553.0,443.0
172,431.0,219.0
908,26.0,577.0
21,994.0,489.0
895,95.0,58.0
414,479.0,373.0
505,197.0,119.0
809,278.0,652.0
371,249.0,137.0
899,442.0,868.0
6,594.0,826.0
960,269.0,666.0
964,324.0,341.0
84,887.0,252.0
329,502.0,470.0
344,662.0,858.0
102,27.0,262.0
933,123.0,953.0
608,877.0,667.0
681,862.0,359.0
542,924.0,595.0
977,945.0,419.0
582,514.0,902.0
568,256.0,1.0
464,321.0,939.0
355,343.0,833.0
454,527.0,163.0
115,400.0,67.0
36,14.0,168.0
190,295.0,851.0
799,973.0,721.0
642,965.0,74.0
285,62.0,756.0
328,110.0,245.0
141,213.0,970.0
237,555.0,786.0
130,333.0,754.0
81,704.0,362.0
8,481.0,348.0
210,792.0,827.0
118,546.0,762.0
963,196.0,375.0
139,260.0,334.0
992,,
13,,
417,,
632,,
427,,
28,,
560,,
525,,
509,,
723,,
264,,
653,,
181,,
823,,
884,,
91,,
471,,
535,,
782,,
914,,
257,,
143,,
179,,
222,,
255,,
897,,
20,,
146,,
349,,
922,,
565,,
170,,
200,,
497,,
365,,
820,,
607,,
661,,
242,,
160,,
235,,
236,,
68,,
701,,
43,,
948,,
90,,
983,,
240,,
310,,
46,,
166,,
57,,
838,,
346,,
892,,
720,,
690,,
271,,
856,,
97,,
299,,
423,,
209,,
384,,
357,,
810,,
265,,
943,,
541,,
165,,
768,,
659,,
368,,
225,,
821,,
604,,
631,,
22,,
253,,
714,,
544,,
159,,
148,,
956,,
796,,
476,,
33,,
684,,
178,,
9,,
207,,
545,,
532,,
635,,
149,,
296,,
533,,
715,,
473,,
850,,
559,,
647,,
298,,
990,,
982,,
640,,
926,,
498,,
270,,
738,,
912,,
736,,
424,,
272,,
56,,
651,,
366,,
770,,
784,,
637,,
911,,
599,,
802,,
885,,
78,,
586,,
968,,
830,,
843,,
866,,
731,,
999,,
319,,
50,,
361,,
175,,
861,,
987,,
289,,
450,,
394,,
832,,
564,,
687,,
909,,
363,,
11,,
45,,
248,,
543,,
679,,
34,,
480,,
158,,
674,,
547,,
347,,
791,,
569,,
300,,
425,,
981,,
975,,
386,,
507,,
847,,
218,,
627,,
154,,
436,,
286,,
930,,
888,,
35,,
685,,
72,,
570,,
109,,
217,,
562,,
795,,
962,,
82,,
775,,
882,,
580,,
460,,
554,,
703,,
407,,
526,,
511,,
93,,
621,,
4,,
752,,
131,,
315,,
370,,
978,,
70,,
593,,
682,,
947,,
818,,
339,,
676,,
971,,
345,,
797,,
3,,
616,,
663,,
390,,
725,,
462,,
281,,
664,,
745,,
0,,
844,,
617,,
946,,
294,,
859,,
748,,
915,,
998,,
742,,
530,,
116,,
309,,
741,,
988,,
275,,
606,,
900,,
374,,
367,,
107,,
312,,
639,,
182,,
739,,
934,,
39,,
811,,
5,,
63,,
487,,
80,,
753,,
51,,
678,,
335,,
841,,
426,,
284,,
29,,
251,,
611,,
585,,
396,,
491,,
340,,
815,,
488,,
793,,
358,,
77,,
259,,
44,,
228,,
243,,
125,,
761,,
516,,
958,,
558,,
665,,
540,,
244,,
215,,
277,,
521,,
403,,
198,,
730,,
413,,
332,,
710,,
458,,
500,,
579,,
112,,
865,,
247,,
52,,
695,,
758,,
614,,
382,,
619,,
372,,
393,,
126,,
961,,
449,,
839,,
634,,
816,,
88,,
169,,
907,,
597,,
707,,
789,,
474,,
556,,
618,,
727,,
778,,
575,,
32,,
991,,
636,,
191,,
114,,
803,,
233,,
896,,
246,,
446,,
379,,
612,,
733,,
779,,
905,,
183,,
475,,
435,,
409,,
837,,
949,,
273,,
212,,
941,,
433,,
467,,
780,,
465,,
510,,
220,,
438,,
708,,
817,,
416,,
325,,
969,,
188,,
872,,
55,,
313,,
291,,
53,,
194,,
369,,
849,,
869,,
853,,
772,,
201,,
203,,
777,,
327,,
807,,
976,,
231,,
783,,
167,,
364,,
306,,
342,,
266,,
609,,
901,,
162,,
534,,
440,,
834,,
671,,
330,,
308,,
19,,
936,,
354,,
757,,
622,,
605,,
================================================
FILE: expts/data/tiny_zinc_splits.csv
================================================
train,val,test
3,61.0,65.0
21,64.0,24.0
26,89.0,90.0
42,28.0,37.0
23,95.0,83.0
87,46.0,63.0
68,30.0,58.0
99,91.0,85.0
77,48.0,70.0
41,43.0,18.0
0,7.0,81.0
36,72.0,94.0
29,38.0,13.0
11,47.0,1.0
79,16.0,33.0
82,62.0,14.0
27,8.0,97.0
51,76.0,2.0
25,78.0,59.0
88,6.0,44.0
96,,
17,,
20,,
45,,
67,,
12,,
54,,
49,,
32,,
69,,
60,,
55,,
57,,
35,,
53,,
92,,
4,,
73,,
75,,
9,,
71,,
84,,
80,,
15,,
39,,
50,,
86,,
10,,
5,,
34,,
22,,
56,,
66,,
31,,
74,,
52,,
19,,
40,,
98,,
93,,
================================================
FILE: expts/dataset_benchmark.py
================================================
import os
from os.path import dirname, abspath
import yaml
from omegaconf import DictConfig
from datetime import datetime
from copy import deepcopy
import graphium
from graphium.config._loader import load_datamodule, load_accelerator
import time
from typing import Optional, List, Sequence
import wandb
import statistics
import tqdm
import torch
import numpy as np
# Set up the working directory: resolve the repository root from the installed
# `graphium` package location, and chdir there so the relative config/data
# paths below resolve from the repo root regardless of where the script is run.
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
# Config to benchmark; the commented alternatives are kept for quick switching.
# CONFIG_FILE = "expts/neurips2023_configs/debug/config_large_gcn_debug.yaml"
CONFIG_FILE = "expts/neurips2023_configs/config_large_gcn.yaml"
# CONFIG_FILE = "expts/configs/config_pcqmv2_mpnn.yaml"
# CONFIG_FILE = "expts/configs/config_ipu_qm9.yaml"
def benchmark(fn, *args, message="", log2wandb=False, **kwargs):
    """Call ``fn(*args, **kwargs)``, report how long it took, and return its result.

    The elapsed time is printed, prefixed with ``message``, and optionally
    logged to Weights & Biases under the same key.

    Parameters:
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``.
        message: Label used for the printed line and the wandb key.
        log2wandb: Whether to also log the duration via ``wandb.log``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.
    """
    # perf_counter is monotonic and the documented choice for measuring
    # durations; time.time() is wall-clock and can jump on clock adjustments.
    start = time.perf_counter()
    value = fn(*args, **kwargs)
    duration = time.perf_counter() - start
    print(f"{message} {duration:.3f} secs")
    if log2wandb:
        wandb.log({message: duration})
    return value
def benchmark_dataloader(dataloader, name, n_epochs=5, log2wandb=False):
    """Iterate over ``dataloader`` for ``n_epochs`` epochs and report throughput.

    For each epoch this counts batches, graphs, nodes and edges seen, measures
    the epoch wall time, and derives a graphs-per-second throughput. All
    figures are printed and, when ``log2wandb`` is True, logged to wandb.

    Parameters:
        dataloader: Iterable of batches. Each batch is assumed to be a mapping
            with ``batch["features"]["batch"]`` (node-to-graph index tensor)
            and ``batch["features"]["edge_weight"]`` — TODO confirm against
            the datamodule's collate output.
        name: Label used in printed/logged messages (e.g. "train").
        n_epochs: Number of full passes over the dataloader.
        log2wandb: Whether to also log the statistics to Weights & Biases.
    """
    print(f"length of {name} dataloader: {len(dataloader)}")
    # Per-epoch accumulators.
    epoch_times = [0] * n_epochs
    tputs = [0] * n_epochs
    n_batches = [0] * n_epochs
    n_graphs = [0] * n_epochs
    n_nodes = [0] * n_epochs
    n_edges = [0] * n_epochs
    for i in range(n_epochs):
        start = time.time()
        for data in tqdm.tqdm(dataloader):
            n_batches[i] += 1
            # NOTE(review): the max of the node-to-graph index is
            # num_graphs - 1 per batch, so this looks like it undercounts
            # graphs by one per batch — confirm intent.
            n_graphs[i] += torch.sum(torch.max(data["features"]["batch"], dim=-1).values).item()
            # Node/edge counts are derived from the tensor shapes.
            n_nodes[i] += np.prod(data["features"]["batch"].shape)
            n_edges[i] += np.prod(data["features"]["edge_weight"].shape)
        epoch_times[i] = time.time() - start
        # Throughput in graphs per second for this epoch.
        tputs[i] = n_graphs[i] / epoch_times[i]
    average_tput = statistics.mean(tputs)
    print(f"{name} dataloader average tput {average_tput}")
    print(f"{name} dataloader tputs per epoch {tputs}")
    print(f"{name} dataloader epoch times {epoch_times}")
    print(f"{name} dataloader total batches per epoch {n_batches}")
    print(f"{name} dataloader total graphs per epoch {n_graphs}")
    print(f"{name} dataloader total nodes per epoch {n_nodes}")
    print(f"{name} dataloader total edges per epoch {n_edges}")
    if log2wandb:
        wandb.log({"average tput": average_tput})
        # One wandb entry per epoch with the detailed counters.
        for i in range(n_epochs):
            d = {
                "epoch": i,
                "tput per epoch": tputs[i],
                "epoch times ": epoch_times[i],
                "batches per epoch": n_batches[i],
                "graphs per epoch": n_graphs[i],
                "nodes per epoch": n_nodes[i],
                "edges per epoch": n_edges[i],
            }
            print(d)
            wandb.log(d)
def main(
    cfg: DictConfig,
    stages: Optional[Sequence[str]] = None,
    run_name: str = "dataset_benchmark",
    add_date_time: bool = True,
    log2wandb: bool = False,
) -> None:
    """Benchmark datamodule loading, preparation, setup and dataloader iteration.

    Parameters:
        cfg: Full experiment configuration (as loaded from a YAML config file).
        stages: Which dataloaders to benchmark; any of "train"/"fit",
            "val"/"valid"/"validation", "test"/"testing". ``None`` runs all.
        run_name: Name of the wandb run.
        add_date_time: Whether to append a timestamp to ``run_name``.
        log2wandb: Whether to log timings and throughputs to Weights & Biases.
    """
    if add_date_time:
        run_name += "_" + datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
    if log2wandb:
        wandb.init(project="multitask-gnn", name=run_name, config=cfg)
    # Work on a copy so accelerator overrides don't mutate the caller's config.
    cfg = deepcopy(cfg)
    cfg, accelerator_type = load_accelerator(cfg)
    # Load and initialize the dataset
    datamodule = benchmark(
        load_datamodule, cfg, accelerator_type, message="Load duration", log2wandb=log2wandb
    )
    benchmark(datamodule.prepare_data, message="Prepare duration", log2wandb=log2wandb)
    # NOTE(review): per-stage setup is deliberately disabled (`if False`); all
    # stages are set up at once below — confirm whether this is still wanted.
    if False:  # stages is not None:
        for stage in stages:
            benchmark(datamodule.setup, stage, message=f"Setup {stage} duration", log2wandb=log2wandb)
    else:
        benchmark(datamodule.setup, message=f"Setup duration", log2wandb=log2wandb)
    # Benchmark each requested dataloader for as many epochs as the trainer would run.
    if stages is None or {"train", "fit"}.intersection(stages):
        dataloader = datamodule.train_dataloader()
        benchmark_dataloader(
            dataloader, name="train", n_epochs=cfg["trainer"]["trainer"]["max_epochs"], log2wandb=log2wandb
        )
    if stages is None or {"val", "valid", "validation"}.intersection(stages):
        dataloader = datamodule.val_dataloader()
        benchmark_dataloader(
            dataloader,
            name="validation",
            n_epochs=cfg["trainer"]["trainer"]["max_epochs"],
            log2wandb=log2wandb,
        )
    if stages is None or {"test", "testing"}.intersection(stages):
        dataloader = datamodule.test_dataloader()
        benchmark_dataloader(
            dataloader, name="testing", n_epochs=cfg["trainer"]["trainer"]["max_epochs"], log2wandb=log2wandb
        )
if __name__ == "__main__":
    # Load the experiment config from the repo root and benchmark only the
    # training dataloader, logging results to wandb.
    with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f:
        cfg = yaml.safe_load(f)
    main(cfg, stages=["train"], log2wandb=True)
================================================
FILE: expts/debug_yaml.py
================================================
# test_yaml.py
import yaml
CONFIG_FILE = "neurips2023_configs/config_small_mpnn.yaml"
import re
from collections import defaultdict
# def get_anchors_and_aliases(filepath):
# anchors = defaultdict(list)
# current_level = {}
# with open(filepath, "r") as file:
# for line in file:
# indent = len(line) - len(line.lstrip(' '))
# key_match = re.search(r'(\w+):', line)
# anchor_match = re.search(r'&(\w+)', line)
# alias_match = re.search(r'\*(\w+)', line)
# if key_match:
# key = key_match.group(1)
# # Compute the full path of the current key.
# full_path = '.'.join([current_level[i] for i in sorted(current_level.keys()) if i < indent] + [key])
# current_level[indent] = key
# # Remove any keys that are indented more than the current line.
# keys_to_remove = [i for i in current_level if i > indent]
# for i in keys_to_remove:
# del current_level[i]
# else:
# full_path = '.'.join([current_level[i] for i in sorted(current_level.keys())])
# if anchor_match:
# anchor = anchor_match.group(1)
# anchors[anchor].append(full_path)
# if alias_match:
# alias = alias_match.group(1)
# anchors[alias].append(full_path)
# return {k: v for k, v in anchors.items() if len(v) > 1}
def get_anchors_and_aliases(filepath):
    """Map each YAML anchor's dotted key path to the key paths that alias it.

    Scans the file line by line (no YAML parser), tracking the current key at
    each indentation depth to build dotted paths like ``a.b.c``. A line
    containing ``&name`` records that anchor's path; a line containing
    ``*name`` appends its own path to the recording anchor's entry.

    Parameters:
        filepath: Path to the YAML file to scan.

    Returns:
        ``defaultdict(list)`` mapping an anchor's dotted path to the list of
        dotted paths where it is referenced via an alias.
    """
    alias_targets = defaultdict(list)
    # Maps indentation depth -> key name currently open at that depth.
    path_by_indent = {}
    # Maps anchor name -> dotted path of the line that declared it.
    anchor_paths = {}
    key_re = re.compile(r"(\w+):")
    anchor_re = re.compile(r"&(\w+)")
    alias_re = re.compile(r"\*(\w+)")
    with open(filepath, "r") as handle:
        for raw in handle:
            depth = len(raw) - len(raw.lstrip(" "))
            found_key = key_re.search(raw)
            if found_key:
                name = found_key.group(1)
                # Dotted path = all shallower open keys plus this line's key.
                parents = [path_by_indent[d] for d in sorted(path_by_indent) if d < depth]
                full_path = ".".join(parents + [name])
                path_by_indent[depth] = name
                # Any deeper keys are no longer open once we see this line.
                for stale in [d for d in path_by_indent if d > depth]:
                    del path_by_indent[stale]
            else:
                # No key on this line: reuse the full currently-open path.
                full_path = ".".join(path_by_indent[d] for d in sorted(path_by_indent))
            found_anchor = anchor_re.search(raw)
            if found_anchor:
                anchor_paths[found_anchor.group(1)] = full_path
            found_alias = alias_re.search(raw)
            if found_alias and found_alias.group(1) in anchor_paths:
                alias_targets[anchor_paths[found_alias.group(1)]].append(full_path)
    return alias_targets
# Collect which config paths share YAML anchors/aliases, then drop into the
# ipdb debugger so `refs` can be inspected interactively (debug script).
refs = get_anchors_and_aliases(CONFIG_FILE)
import ipdb

ipdb.set_trace()
pass
================================================
FILE: expts/hydra-configs/README.md
================================================
# Configuring Graphium with Hydra
This document provides users with a point of entry to composing configs in Graphium. As a flexible library with many features, configuration is an important part of Graphium. To make configurations as reusable as possible while providing maximum flexibility, we integrated Graphium with `hydra`. Our config structure is designed to make the following functionality as accessible as possible:
- Switching between **accelerators** (CPU, GPU and IPU)
- **Benchmarking** different models on the same dataset
- **Fine-tuning** a pre-trained model on a new dataset
In what follows, we describe how each of the above functionality is achieved and how users can benefit from this design to achieve the most with Graphium with as little configuration as possible.
## Accelerators
With Graphium supporting CPU, GPU and IPU hardware, easily switching between these accelerators is pre-configured. General, accelerator-specific configs are specified under `accelerator/`, whereas experiment-specific differences between the accelerators are specialized under `training/accelerator`.
## Benchmarking
Benchmarking multiple models on the same datasets and tasks requires us to easily switch between model configurations without redefining major parts of the architecture, task heads, featurization, metrics, predictor, etc. For example, when changing from a GCN to a GIN model, a simple switch of `architecture.gnn.layer_type: 'pyg:gin'` might suffice. Hence, we abstract the `model` configs under `model/` where such model configurations can be specified.
In addition, switching models may have implications on configs specific to your current experiment, such as the name of the run or the directory to which model checkpoints are written. To enable such overrides, we can utilize `hydra` [specializations](https://hydra.cc/docs/patterns/specializing_config/). For example, for our ToyMix dataset, we specify the layer type under `model/[model_name].yaml`, e.g., for the GCN layer,
```yaml
# @package _global_
architecture:
gnn:
layer_type: 'pyg:gcn'
```
and set experiment-related parameters in `training/model/toymix_[model_name].yaml` as a specialization, e.g., for the GIN layer,
```yaml
# @package _global_
constants:
name: neurips2023_small_data_gin
...
trainer:
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-gin/${now:%Y-%m-%d_%H-%M-%S}/
```
We can now utilize `hydra` to e.g., run a sweep over our models on the ToyMix dataset via
```bash
graphium-train -m model=gcn,gin
```
where the ToyMix dataset is pre-configured in `main.yaml`. Read on to find out how to define new datasets and architectures for pre-training and fine-tuning.
## Pre-training / Fine-tuning
Say you trained a model with the following command:
```bash
graphium-train --config-name "main"
```
Fine-tuning this model on downstream tasks is then as simple as:
```bash
graphium-train --config-name "main" +finetuning=...
```
From a configuration point-of-view, fine-tuning requires us to load a pre-trained model and override part of the training configuration to fine-tune it on downstream tasks. To allow a quick switch between pre-training and fine-tuning, by default, we configure models and the corresponding tasks in a separate manner. More specifically,
- under `architecture/` we store architecture related configurations such as the definition of the GNN/Transformer layers or positional/structural encoders
- under `tasks/` we store configurations specific to one task set, such as the multi-task dataset ToyMix
- under `tasks/task_heads` we specify the task-specific heads to add on top of the base architecture.
- under `tasks/loss_metrics_datamodule` we specify the data-module to use and the task-specific loss functions and metrics
- under `training/` we store configurations specific to training models which could be different for each combination of `architecture` and `tasks`
- under `finetuning/` we store configurations with overrides
Since architecture and tasks are logically separated it now becomes very easy to e.g., use an existing architecture backbone on a new set of tasks or a new dataset altogether. Additionally, separating training allows us to specify different training parameters for e.g., pre-training and fine-tuning of the same architecture and task set.
We will now detail how you can add new architectures, tasks and training configurations.
### Adding an architecture
The architecture config consists of specifications of the neural network components, including encoders, under the config key `architecture` and the featurization, containing the positional/structural information that is to be extracted from the data.
To add a new architecture, create a file `architecture/my_architecture.yaml` with the following information specified:
```yaml
# @package _global_
architecture:
model_type: FullGraphMultiTaskNetwork # for example
pre_nn:
...
pre_nn_edges:
...
pe_encoders:
encoders: # your encoders
...
gnn: # your GNN definition
...
graph_output_nn: # output NNs for different levels such as graph, node, etc.
graph:
...
node:
...
...
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args: # Make sure to not specify anything task-specific here
...
featurization:
...
```
You can then select your new architecture during training, e.g., by running
```bash
graphium-train architecture=my_architecture
```
### Adding tasks
The task set config consists of specifications for the task head neural nets under the config key `architecture.task_heads`; if required, any task-specific arguments to the datamodule you use, e.g., `datamodule.args.task_specific_args` when using the `MultitaskFromSmilesDataModule` datamodule; the per-task metrics under the config key `metrics.[task]` where `[task]` matches the tasks specified under `architecture.task_heads`; the per-task configs of the `predictor` module, as well as the loss functions of the task set under the config key `predictor.loss_fun`.
To add a new task set, create a file `tasks/my_tasks.yaml` with the following information specified:
```yaml
# @package _global_
architecture:
task_heads:
task1:
...
task2:
...
datamodule: # optional, depends on your concrete datamodule class. Here: "MultitaskFromSmilesDataModule"
args:
task_specific_args:
task1:
...
task2:
...
metrics:
task1:
...
task2:
...
predictor:
metrics_on_progress_bar:
task1:
task2:
loss_fun: ... # your loss functions for the multi-tasking
```
You can then select your new dataset during training, e.g., by running
```bash
graphium-train tasks=my_tasks
```
### Adding training configs
The training configs consist of specifications to the `predictor` and `trainer` modules.
To add new training configs, create a file `training/my_training.yaml` with the following information specified:
```yaml
# @package _global_
predictor:
optim_kwargs:
lr: 4.e-5
torch_scheduler_kwargs: # example
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
...
trainer:
...
trainer: # example
precision: 16
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
```
================================================
FILE: expts/hydra-configs/__init__.py
================================================
================================================
FILE: expts/hydra-configs/accelerator/cpu.yaml
================================================
type: cpu
================================================
FILE: expts/hydra-configs/accelerator/gpu.yaml
================================================
type: gpu
================================================
FILE: expts/hydra-configs/accelerator/ipu.yaml
================================================
type: ipu
ipu_config:
- deviceIterations(60) # IPU would require large batches to be ready for the model.
# 60 for PCQM4mv2
# 30 for largemix
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(1)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
================================================
FILE: expts/hydra-configs/accelerator/ipu_pipeline.yaml
================================================
type: ipu
ipu_config:
- deviceIterations(60) # IPU would require large batches to be ready for the model.
# 60 for PCQM4mv2
# 30 for largemix
- replicationFactor(4)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(60)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
accelerator_kwargs:
_accelerator: "ipu"
gnn_layers_per_ipu: [4, 4, 4, 4]
================================================
FILE: expts/hydra-configs/architecture/largemix.yaml
================================================
# @package _global_
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: ${constants.norm}
last_normalization: ${constants.norm}
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: ${constants.norm} #"batch_norm" or "layer_norm"
first_normalization: ${constants.norm} #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: ${constants.norm}
last_normalization: ${constants.norm}
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: 1024
hidden_dims: *gnn_dim
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: ${constants.norm}
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: 64
hidden_dims: *gnn_dim
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: ${constants.norm}
last_normalization: "none"
residual_type: none
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args:
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 20
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: ${constants.datacache_path}
dataloading_from: "disk"
num_workers: 20 # -1 to use all
persistent_workers: True
featurization:
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
================================================
FILE: expts/hydra-configs/architecture/pcqm4m.yaml
================================================
# @package _global_
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 256
hidden_dims: 1024
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 128
hidden_dims: 512
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: 256
hidden_dims: 256
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: 256
hidden_dims: 256
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: ${constants.datacache_path}
num_workers: 40 # -1 to use all
    persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
================================================
FILE: expts/hydra-configs/architecture/toymix.yaml
================================================
# @package _global_
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn:
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: 0.18
normalization: layer_norm
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 96
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: ${architecture.pre_nn.normalization}
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args:
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: ${constants.datacache_path}
dataloading_from: ram
num_workers: 30 # -1 to use all
persistent_workers: False
featurization:
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features:
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
================================================
FILE: expts/hydra-configs/experiment/toymix_mpnn.yaml
================================================
# @package _global_
constants:
name: neurips2023_small_data_mpnn
entity: "multitask-gnn"
seed: 42
max_epochs: 100
data_dir: expts/data/neurips2023/small-dataset
raise_train_error: true
trainer:
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-mpnn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/finetuning/admet.yaml
================================================
# @package _global_
# == Fine-tuning configs in Graphium ==
#
# A fine-tuning config is an addendum to a (pre-)training config.
# Since many things (e.g. the architecture), will stay constant between (pre-)training and fine-tuning,
# this config should be as minimal as possible to avoid unnecessary duplication. It only specifies
# what to override with regards to the config used for (pre-)training.
#
# Given the following training command:
# >>> graphium-train --cfg /path/to/train.yaml
#
# Fine-tuning now is as easy as:
# >>> graphium-train --cfg /path/to/train.yaml +finetune=admet
#
# NOTE: This config can be used for each of the benchmarks in the TDC ADMET benchmark suite.
# The only thing that needs to be changed is the `constants.task` key.
## == Overrides ==
defaults:
# This file contains all metrics and loss function info for all ADMET tasks.
# This config is filtered at runtime based on the `constants.task` key.
- override /tasks/loss_metrics_datamodule: admet
constants:
# For now, we assume a model is always fine-tuned on a single task at a time.
# You can override this value with any of the benchmark names in the TDC benchmark suite.
# See also https://tdcommons.ai/benchmark/admet_group/overview/
task: lipophilicity_astrazeneca
name: finetuning_${constants.task}_gcn
wandb:
name: ${constants.name}
project: ${constants.task}
entity: multitask-gnn
save_dir: logs/${constants.task}
seed: 42
max_epochs: 100
data_dir: expts/data/admet/${constants.task}
raise_train_error: true
predictor:
optim_kwargs:
lr: 4.e-5
# == Fine-tuning config ==
finetuning:
# For now, we assume a model is always fine-tuned on a single task at a time.
# You can override this value with any of the benchmark names in the TDC benchmark suite.
# See also https://tdcommons.ai/benchmark/admet_group/overview/
task: ${constants.task}
level: graph
# Pretrained model
pretrained_model: dummy-pretrained-model
finetuning_module: task_heads # gnn
sub_module_from_pretrained: zinc # optional
new_sub_module: ${constants.task} # optional
# keep_modules_after_finetuning_module: # optional
# graph_output_nn/graph: {}
# task_heads/zinc:
# new_sub_module: lipophilicity_astrazeneca
# out_dim: 1
# Changes to finetuning_module
drop_depth: 1
new_out_dim: 8
added_depth: 2
# Training
unfreeze_pretrained_depth: 0
epoch_unfreeze_all: none
# Optional finetuning head appended to model after finetuning_module
finetuning_head:
task: ${constants.task}
previous_module: task_heads
incoming_level: graph
model_type: mlp
in_dim: 8
out_dim: 1
hidden_dims: 8
depth: 2
last_layer_is_readout: true
================================================
FILE: expts/hydra-configs/finetuning/admet_baseline.yaml
================================================
# @package _global_
defaults:
- override /tasks/loss_metrics_datamodule: admet
constants:
task: tbd
name: finetune_${constants.task}
wandb:
name: ${constants.name}
project: finetuning
entity: recursion
seed: 42
max_epochs: 100
data_dir: ../data/graphium/admet/${constants.task}
datacache_path: ../datacache/admet/${constants.task}
raise_train_error: true
metric: ${get_metric_name:${constants.task}}
datamodule:
args:
batch_size_training: 32
dataloading_from: ram
persistent_workers: true
num_workers: 4
trainer:
model_checkpoint:
# save_top_k: 1
# monitor: graph_${constants.task}/${constants.metric}/val
# mode: ${get_metric_mode:${constants.task}}
# save_last: true
# filename: best
dirpath: model_checkpoints/finetuning/${constants.task}/${now:%Y-%m-%d_%H-%M-%S.%f}/
every_n_epochs: 200
trainer:
precision: 32
check_val_every_n_epoch: 1
# early_stopping:
# monitor: graph_${constants.task}/${constants.metric}/val
# mode: ${get_metric_mode:${constants.task}}
# min_delta: 0.001
# patience: 10
accumulate_grad_batches: none
# test_from_checkpoint: best.ckpt
# test_from_checkpoint: ${trainer.model_checkpoint.dirpath}/best.ckpt
predictor:
optim_kwargs:
lr: 0.000005
# == Fine-tuning config ==
finetuning:
task: ${constants.task}
level: graph
pretrained_model: tbd
finetuning_module: graph_output_nn
sub_module_from_pretrained: graph
new_sub_module: graph
keep_modules_after_finetuning_module: # optional
task_heads-pcqm4m_g25:
new_sub_module: ${constants.task}
hidden_dims: 256
depth: 2
last_activation: ${get_last_activation:${constants.task}}
out_dim: 1
epoch_unfreeze_all: tbd
================================================
FILE: expts/hydra-configs/hparam_search/optuna.yaml
================================================
# @package _global_
#
# For running a hyper-parameter search, we use the Optuna plugin for hydra.
# This makes optuna available as a sweeper in hydra and integrates easily with the rest of the codebase.
# For more info, see https://hydra.cc/docs/plugins/optuna_sweeper/
#
# To run a hyper-param search,
# (1) Update this config, specifically the hyper-param search space;
# (2) Run `graphium-train +hparam_search=optuna` from the command line.
defaults:
- override /hydra/sweeper: optuna
# Optuna supports various sweepers (e.g. grid search, random search, TPE sampler)
- override /hydra/sweeper/sampler: tpe
hyper_param_search:
# For the sweeper to work, the main process needs to return
# the objective value(s) (as a float) we are trying to optimize.
# Assuming this is a metric, the `objective` key specifies which metric.
# Optuna supports multi-parameter optimization as well.
# If configured correctly, you can specify multiple keys.
objective: loss/test
# Where to save results to
# NOTE (cwognum): Ideally, we would use the `hydra.sweep.dir` key, but they don't support remote paths.
# save_destination: gs://path/to/bucket
# overwrite_destination: false
hydra:
# Run in multirun mode by default (i.e. actually use the sweeper)
mode: MULTIRUN
# Changes the working directory
sweep:
dir: hparam-search-results/${constants.name}
subdir: ${hydra.job.num}
# Sweeper config
sweeper:
sampler:
seed: ${constants.seed}
direction: minimize
study_name: ${constants.name}
storage: null
n_trials: 100
n_jobs: 1
# The hyper-parameter search space definition
# See https://hydra.cc/docs/plugins/optuna_sweeper/#search-space-configuration for the options
params:
predictor.optim_kwargs.lr: tag(log, interval(0.00001, 0.001))
================================================
FILE: expts/hydra-configs/main.yaml
================================================
defaults:
# Accelerators
- accelerator: cpu
# Pre-training/fine-tuning
- architecture: toymix
- tasks: toymix
- training: toymix
# Benchmarking
- model: gcn
# Specializations
- training/accelerator: ${training}_${accelerator}
- training/model: ${training}_${model}
================================================
FILE: expts/hydra-configs/model/gated_gcn.yaml
================================================
# @package _global_
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: ${constants.gnn_dim} # &gnn_dim 704
hidden_dims: ${architecture.gnn.out_dim} # *gnn_dim
hidden_dims_edges: ${constants.gnn_edge_dim}
layer_type: 'pyg:gated-gcn'
graph_output_nn:
graph:
hidden_dims: ${architecture.gnn.out_dim}
node:
hidden_dims: ${architecture.gnn.out_dim}
================================================
FILE: expts/hydra-configs/model/gcn.yaml
================================================
# @package _global_
architecture:
gnn:
layer_type: 'pyg:gcn'
================================================
FILE: expts/hydra-configs/model/gin.yaml
================================================
# @package _global_
architecture:
gnn:
layer_type: 'pyg:gin'
================================================
FILE: expts/hydra-configs/model/gine.yaml
================================================
# @package _global_
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: ${constants.gnn_edge_dim}
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: ${constants.gnn_dim}
hidden_dims: ${architecture.gnn.out_dim}
layer_type: 'pyg:gine'
graph_output_nn:
graph:
hidden_dims: ${architecture.gnn.out_dim}
node:
hidden_dims: ${architecture.gnn.out_dim}
================================================
FILE: expts/hydra-configs/model/gpspp.yaml
================================================
# @package _global_
datamodule:
args:
batch_size_training: 32
featurization:
conformer_property_list: [positions_3d]
trainer:
trainer:
accumulate_grad_batches: 2
architecture:
pe_encoders:
encoders:
gaussian_pos:
encoder_type: "gaussian_kernel"
input_keys: ["positions_3d"]
output_keys: ["feat", "nodepair_gaussian_bias_3d"]
num_heads: 32
num_layers: 1 #2
embed_dim: 32
out_dim: 32 # need num of gaussian kernels 128
# but currently it checks pe_out_dim == pe_out_dim in encoder_manager.py, line 128
use_input_keys_prefix: False
gnn:
layer_type: 'pyg:gps'
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 256
out_dim: 256
in_dim_edges: 128
out_dim_edges: 128
attn_type: "full-attention" # "full-attention", "none"
precision: &precision 16-true
biased_attention_key: "nodepair_gaussian_bias_3d" # 3D_bias
attn_kwargs:
num_heads: 32
droppath_rate_attn: 0.0
droppath_rate_ffn: 0.0
================================================
FILE: expts/hydra-configs/model/mpnn.yaml
================================================
# @package _global_
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: ${constants.gnn_dim}
hidden_dims: ${architecture.gnn.out_dim}
hidden_dims_edges: ${constants.gnn_edge_dim}
layer_type: 'pyg:mpnnplus'
graph_output_nn:
graph:
hidden_dims: ${architecture.gnn.out_dim}
node:
hidden_dims: ${architecture.gnn.out_dim}
================================================
FILE: expts/hydra-configs/tasks/admet.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: admet
- loss_metrics_datamodule: admet
================================================
FILE: expts/hydra-configs/tasks/l1000_mcf7.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: l1000_mcf7
- loss_metrics_datamodule: l1000_mcf7
================================================
FILE: expts/hydra-configs/tasks/l1000_vcap.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: l1000_vcap
- loss_metrics_datamodule: l1000_vcap
================================================
FILE: expts/hydra-configs/tasks/largemix.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: largemix
- loss_metrics_datamodule: largemix
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml
================================================
# @package _global_
#Task-specific
predictor:
metrics_on_progress_bar:
# All below metrics are directly copied from the TDC website.
# For more information, see https://tdcommons.ai/benchmark/admet_group/overview/
caco2_wang: ["mae"]
hia_hou: ["auroc"]
pgp_broccatelli: ["auroc"]
bioavailability_ma: ["auroc"]
lipophilicity_astrazeneca: ["mae"]
solubility_aqsoldb: ["mae"]
bbb_martins: ["auroc"]
ppbr_az: ["mae"]
vdss_lombardo: ["spearman"]
cyp2d6_veith: ["auprc"]
cyp3a4_veith: ["auprc"]
cyp2c9_veith: ["auprc"]
cyp2d6_substrate_carbonmangels: ["auprc"]
cyp3a4_substrate_carbonmangels: ["auprc"]
cyp2c9_substrate_carbonmangels: ["auprc"]
half_life_obach: ["spearman"]
clearance_microsome_az: ["spearman"]
clearance_hepatocyte_az: ["spearman"]
herg: ["auroc"]
ames: ["auroc"]
dili: ["auroc"]
    ld50_zhu: ["mae"]
loss_fun:
caco2_wang: mae
hia_hou: bce
pgp_broccatelli: bce
bioavailability_ma: bce
lipophilicity_astrazeneca: mae
solubility_aqsoldb: mae
bbb_martins: bce
ppbr_az: mae
vdss_lombardo: mae
cyp2d6_veith: bce
cyp3a4_veith: bce
cyp2c9_veith: bce
cyp2d6_substrate_carbonmangels: bce
cyp3a4_substrate_carbonmangels: bce
cyp2c9_substrate_carbonmangels: bce
half_life_obach: mae
clearance_microsome_az: mae
clearance_hepatocyte_az: mae
herg: bce
ames: bce
dili: bce
ld50_zhu: mae
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 10
warmup_epochs: 10
verbose: False
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
  caco2_wang: &regression_metrics
- name: mae
metric: mae
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: spearman
metric: spearmanr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: pearson
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
hia_hou: &classification_metrics
- name: auroc
metric: auroc
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: auprc
metric: averageprecision
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: accuracy
metric: accuracy
multitask_handling: mean-per-label
target_to_int: True
average: micro
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
th_on_preds: True
th_on_target: True
- name: mcc
metric: mcc
num_classes: 2
multitask_handling: mean-per-label
target_to_int: True
average: micro
threshold_kwargs: *threshold_05
pgp_broccatelli: *classification_metrics
bioavailability_ma: *classification_metrics
lipophilicity_astrazeneca: *regression_metrics
solubility_aqsoldb: *regression_metrics
bbb_martins: *classification_metrics
ppbr_az: *regression_metrics
vdss_lombardo: *regression_metrics
cyp2d6_veith: *classification_metrics
cyp3a4_veith: *classification_metrics
cyp2c9_veith: *classification_metrics
cyp2d6_substrate_carbonmangels: *classification_metrics
cyp3a4_substrate_carbonmangels: *classification_metrics
cyp2c9_substrate_carbonmangels: *classification_metrics
half_life_obach: *regression_metrics
clearance_microsome_az: *regression_metrics
clearance_hepatocyte_az: *regression_metrics
herg: *classification_metrics
ames: *classification_metrics
dili: *classification_metrics
ld50_zhu: *regression_metrics
datamodule:
module_type: "ADMETBenchmarkDataModule"
args:
# TDC specific
tdc_benchmark_names: null
tdc_train_val_seed: ${constants.seed}
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_mcf7.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
l1000_mcf7: []
metrics_on_training_set:
l1000_mcf7: []
loss_fun:
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 3
alpha: 0.5
metrics:
l1000_mcf7:
- name: auroc
metric: auroc
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: averageprecision
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_mcf7:
df: null
df_path: ../data/graphium/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/l1000_vcap.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
l1000_vcap: []
metrics_on_training_set:
l1000_vcap: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 3
alpha: 0.5
metrics:
l1000_vcap:
- name: auroc
metric: auroc
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: averageprecision
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: ../data/graphium/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/largemix.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
l1000_vcap: []
l1000_mcf7: []
pcba_1328: []
pcqm4m_g25: []
pcqm4m_n4: []
metrics_on_training_set:
l1000_vcap: []
l1000_mcf7: []
pcba_1328: []
pcqm4m_g25: []
pcqm4m_n4: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 3
alpha: 0.5
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 3
alpha: ${predictor.loss_fun.l1000_vcap.alpha}
pcba_1328: bce_logits_ipu
pcqm4m_g25: mae_ipu
pcqm4m_n4: mae_ipu
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: averageprecision
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
l1000_mcf7: *classif_metrics
pcba_1328:
    # use auroc and averageprecision (non_ipu version) so that nans are handled correctly
- name: auroc
metric: auroc
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
- name: avpr
metric: averageprecision
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
pcqm4m_n4: *pcqm_metrics
datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: ../data/graphium/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
l1000_mcf7:
df: null
df_path: ../data/graphium/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
pcba_1328:
df: null
df_path: ../data/graphium/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
pcqm4m_g25:
df: null
df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# split_names: [train, val, test_seen]
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
pcqm4m_n4:
df: null
df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# split_names: [train, val, test_seen]
seed: 42
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/pcba_1328.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
pcba_1328: []
metrics_on_training_set:
pcba_1328: []
loss_fun:
pcba_1328: bce_logits_ipu
metrics:
pcba_1328:
    # use auroc and averageprecision (non_ipu version) so that nans are handled correctly
- name: auroc
metric: auroc
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
- name: avpr
metric: averageprecision
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcba_1328:
df: null
df_path: ../data/graphium/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# split_names: [train, val, test_seen]
epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m.yaml
================================================
# @package _global_
#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: []
metrics_on_training_set:
homolumo: ["pearsonr"]
loss_fun:
homolumo: mae_ipu
# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: graphium/data/PCQM4M/pcqm4mv2.csv
# wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv
# or set path as https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 8000 # use sample_size for test
splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.valencelabs.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
split_names: ["train", "valid", "test-dev"]
# graphium/data/PCQM4Mv2/split_dict.pt
# graphium/data/PCQM4Mv2/pcqm4m_split.csv
# split_val: 0.1
# split_test: 0.1
seed: ${constants.seed}
label_normalization:
method: "normal"
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_g25.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
pcqm4m_g25: []
metrics_on_training_set:
pcqm4m_g25: []
loss_fun:
pcqm4m_g25: mae_ipu
metrics:
pcqm4m_g25:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_g25:
df: null
df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# split_names: [train, val, test_seen]
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/pcqm4m_n4.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
pcqm4m_n4: []
metrics_on_training_set:
pcqm4m_n4: []
loss_fun:
pcqm4m_n4: mae_ipu
metrics:
pcqm4m_n4:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
datamodule:
  args: # Matches that in the test_multitask_datamodule.py case.
    task_specific_args: # To be replaced by a new class "DatasetParams"
      pcqm4m_n4:
        df: null
        df_path: ../data/graphium/large-dataset/PCQM4M_G25_N4.parquet
        # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
        # or set path as the URL directly
        smiles_col: "ordered_smiles"
        label_cols: node_* # node_* means all columns starting with "node_"
        # sample_size: 2000 # use sample_size for test
        task_level: node
        splits_path: ../data/graphium/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
        # split_names: [train, val, test_seen]
        seed: 42
        label_normalization:
          normalize_val_test: True
          method: "normal"
        epoch_sampling_fraction: 1.0
================================================
FILE: expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml
================================================
# @package _global_
predictor:
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
zinc: ["mae"]
loss_fun:
qm9: mae_ipu
tox21: bce_logits_ipu
zinc: mae_ipu
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
metric: f1
multitask_handling: mean-per-label
target_to_int: True
num_classes: 2
average: micro
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
th_on_preds: True
th_on_target: True
- name: precision > 0.5
metric: precision
multitask_handling: mean-per-label
average: micro
threshold_kwargs: *threshold_05
zinc: *qm9_metrics
datamodule:
args:
task_specific_args:
qm9:
df: null
df_path: ${constants.data_dir}/qm9.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
# sample_size: 2000 # use sample_size for test
splits_path: ${constants.data_dir}/qm9_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
seed: ${constants.seed} #*seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
tox21:
df: null
df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
# sample_size: 2000 # use sample_size for test
splits_path: ${constants.data_dir}/Tox21_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
seed: ${constants.seed}
task_level: graph
zinc:
df: null
df_path: ${constants.data_dir}/ZINC12k.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["SA", "logp", "score"]
# sample_size: 2000 # use sample_size for test
splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
seed: ${constants.seed}
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
================================================
FILE: expts/hydra-configs/tasks/pcba_1328.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: pcba_1328
- loss_metrics_datamodule: pcba_1328
================================================
FILE: expts/hydra-configs/tasks/pcqm4m.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: pcqm4m
- loss_metrics_datamodule: pcqm4m
================================================
FILE: expts/hydra-configs/tasks/pcqm4m_g25.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: pcqm4m_g25
- loss_metrics_datamodule: pcqm4m_g25
================================================
FILE: expts/hydra-configs/tasks/pcqm4m_n4.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: pcqm4m_n4
- loss_metrics_datamodule: pcqm4m_n4
================================================
FILE: expts/hydra-configs/tasks/task_heads/admet.yaml
================================================
# @package _global_
architecture:
task_heads:
    caco2_wang: &regression_head
task_level: graph
out_dim: 1
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.5
normalization: &normalization "layer_norm"
last_normalization: "none"
residual_type: none
hia_hou: &classification_head
task_level: graph
out_dim: 1
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pgp_broccatelli: *classification_head
bioavailability_ma: *classification_head
lipophilicity_astrazeneca: *regression_head
solubility_aqsoldb: *regression_head
bbb_martins: *classification_head
ppbr_az: *regression_head
vdss_lombardo: *regression_head
cyp2d6_veith: *classification_head
cyp3a4_veith: *classification_head
cyp2c9_veith: *classification_head
cyp2d6_substrate_carbonmangels: *classification_head
cyp3a4_substrate_carbonmangels: *classification_head
cyp2c9_substrate_carbonmangels: *classification_head
half_life_obach: *regression_head
clearance_microsome_az: *regression_head
clearance_hepatocyte_az: *regression_head
herg: *classification_head
ames: *classification_head
dili: *classification_head
ld50_zhu: *regression_head
================================================
FILE: expts/hydra-configs/tasks/task_heads/l1000_mcf7.yaml
================================================
# @package _global_
architecture:
task_heads:
l1000_mcf7:
task_level: graph
out_dim: 2934
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/l1000_vcap.yaml
================================================
# @package _global_
architecture:
task_heads:
l1000_vcap:
task_level: graph
out_dim: 2934
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/largemix.yaml
================================================
# @package _global_
architecture:
task_heads:
l1000_vcap:
task_level: graph
out_dim: 2934
hidden_dims: 256
depth: 2
activation: none
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
l1000_mcf7:
task_level: graph
out_dim: 2934
hidden_dims: 256
depth: 2
activation: none
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/pcba_1328.yaml
================================================
# @package _global_
architecture:
task_heads:
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/pcqm4m.yaml
================================================
# @package _global_
architecture:
task_heads:
homolumo:
task_level: graph
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: 0.18
normalization: layer_norm
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/pcqm4m_g25.yaml
================================================
# @package _global_
architecture:
task_heads:
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/pcqm4m_n4.yaml
================================================
# @package _global_
architecture:
task_heads:
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/task_heads/toymix.yaml
================================================
# @package _global_
architecture:
task_heads:
qm9:
task_level: graph
out_dim: 19
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
tox21:
task_level: graph
out_dim: 12
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
zinc:
task_level: graph
out_dim: 3
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: "none"
residual_type: none
================================================
FILE: expts/hydra-configs/tasks/toymix.yaml
================================================
# NOTE: We cannot have a single config, since for fine-tuning we will
# only want to override the loss_metrics_datamodule, whereas for training we will
# want to override both.
defaults:
- task_heads: toymix
- loss_metrics_datamodule: toymix
================================================
FILE: expts/hydra-configs/training/accelerator/largemix_cpu.yaml
================================================
# @package _global_
datamodule:
args:
batch_size_training: 200
batch_size_inference: 200
featurization_n_jobs: 20
num_workers: 20
predictor:
metrics_every_n_train_steps: 1000
torch_scheduler_kwargs:
max_num_epochs: ${constants.max_epochs}
trainer:
trainer:
precision: 32
accumulate_grad_batches: 2
max_epochs: ${constants.max_epochs}
================================================
FILE: expts/hydra-configs/training/accelerator/largemix_gpu.yaml
================================================
# @package _global_
accelerator:
float32_matmul_precision: medium
datamodule:
args:
batch_size_training: 2048
batch_size_inference: 2048
featurization_n_jobs: 6
num_workers: 6
predictor:
metrics_every_n_train_steps: 1000
torch_scheduler_kwargs:
max_num_epochs: ${constants.max_epochs}
trainer:
trainer:
precision: 16-mixed
max_epochs: ${constants.max_epochs}
================================================
FILE: expts/hydra-configs/training/accelerator/largemix_ipu.yaml
================================================
# @package _global_
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 30
batch_size_inference: 30
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
================================================
FILE: expts/hydra-configs/training/accelerator/pcqm4m_ipu.yaml
================================================
# @package _global_
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 16 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 60
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 120
# Data handling-related
batch_size_inference: 16
predictor:
metrics_every_n_train_steps: 100
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
================================================
FILE: expts/hydra-configs/training/accelerator/toymix_cpu.yaml
================================================
# @package _global_
datamodule:
args:
batch_size_training: 200
batch_size_inference: 200
featurization_n_jobs: 4
num_workers: 4
predictor:
optim_kwargs: {}
metrics_every_n_train_steps: 300
torch_scheduler_kwargs:
max_num_epochs: ${constants.max_epochs}
trainer:
trainer:
precision: 32
accumulate_grad_batches: 1
max_epochs: ${constants.max_epochs}
================================================
FILE: expts/hydra-configs/training/accelerator/toymix_gpu.yaml
================================================
# @package _global_
accelerator:
float32_matmul_precision: medium
datamodule:
args:
batch_size_training: 200
batch_size_inference: 200
featurization_n_jobs: 4
num_workers: 4
predictor:
optim_kwargs: {}
metrics_every_n_train_steps: 300
torch_scheduler_kwargs:
max_num_epochs: ${constants.max_epochs}
trainer:
trainer:
accumulate_grad_batches: 1
max_epochs: ${constants.max_epochs}
================================================
FILE: expts/hydra-configs/training/accelerator/toymix_ipu.yaml
================================================
# @package _global_
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 80
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 80
# Data handling-related
batch_size_training: 50
batch_size_inference: 50
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
accumulate_grad_batches: 4
================================================
FILE: expts/hydra-configs/training/largemix.yaml
================================================
# @package _global_
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 5
verbose: False
scheduler_kwargs:
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
trainer:
seed: ${constants.seed}
logger:
save_dir: logs/neurips2023-large/
name: ${constants.name}
project: ${constants.name}
model_checkpoint:
dirpath: model_checkpoints/large-dataset/${now:%Y-%m-%d_%H-%M-%S}/
filename: ${constants.name}
save_last: True # saving last model
# save_top_k: 1 # and best model
# monitor: loss/val # wrt validation loss
trainer:
precision: 16-mixed
max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs}
min_epochs: 1
check_val_every_n_epoch: 10
================================================
FILE: expts/hydra-configs/training/model/largemix_gated_gcn.yaml
================================================
# @package _global_
constants:
name: large_data_gated_gcn
wandb:
name: ${constants.name}
project: neurips2023-expts
    entity: multitask-gnn
seed: 42
max_epochs: 200
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
gnn_dim: 512
gnn_edge_dim: 128
norm: "layer_norm"
trainer:
model_checkpoint:
dirpath: model_checkpoints/large-dataset/gated_gcn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/largemix_gcn.yaml
================================================
# @package _global_
constants:
name: large_data_gcn
wandb:
name: ${constants.name}
project: neurips2023-expts
    entity: multitask-gnn
seed: 42
max_epochs: 200
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
norm: "layer_norm"
trainer:
model_checkpoint:
dirpath: model_checkpoints/large-dataset/gcn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/largemix_gin.yaml
================================================
# @package _global_
constants:
name: large_data_gin
wandb:
name: ${constants.name}
project: neurips2023-expts
    entity: multitask-gnn
seed: 42
max_epochs: 200
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
norm: "layer_norm"
trainer:
model_checkpoint:
dirpath: model_checkpoints/large-dataset/gin/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/largemix_gine.yaml
================================================
# @package _global_
constants:
name: large_data_gine
wandb:
name: ${constants.name}
project: neurips2023-expts
    entity: multitask-gnn
seed: 42
max_epochs: 200
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
gnn_dim: 512
gnn_edge_dim: 32
norm: "layer_norm"
trainer:
model_checkpoint:
dirpath: model_checkpoints/large-dataset/gine/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/largemix_mpnn.yaml
================================================
# @package _global_
constants:
name: large_data_mpnn
wandb:
name: ${constants.name}
project: neurips2023-expts
    entity: multitask-gnn
seed: 42
max_epochs: 200
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
gnn_dim: 512
gnn_edge_dim: 256
norm: "layer_norm"
trainer:
model_checkpoint:
dirpath: model_checkpoints/large-dataset/mpnn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/pcqm4m_gpspp.yaml
================================================
# @package _global_
# GPS++ model with the PCQMv2 dataset.
constants:
name: pcqm4mv2_gpspp_4layer
seed: 42
max_epochs: 100
raise_train_error: true # Whether the code should raise an error if it crashes during training
datacache_path: "/localdata/PCQM4Mv2/"
trainer:
model_checkpoint:
dirpath: models_checkpoints/PCMQ4Mv2/gpspp/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/pcqm4m_mpnn.yaml
================================================
# @package _global_
# MPNN model with the PCQMv2 dataset.
constants:
name: pcqm4mv2_mpnn_4layer
seed: 42
max_epochs: 100
raise_train_error: true # Whether the code should raise an error if it crashes during training
datacache_path: "/localdata/PCQM4Mv2/"
trainer:
model_checkpoint:
dirpath: models_checkpoints/PCMQ4Mv2/mpnn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/toymix_gcn.yaml
================================================
# @package _global_
constants:
name: neurips2023_small_data_gcn
seed: 42
max_epochs: 100
data_dir: expts/data/neurips2023/small-dataset
raise_train_error: true
datacache_path: ../datacache/neurips2023-small/
trainer:
model_checkpoint:
dirpath: models_checkpoints/small-dataset/gcn/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/model/toymix_gin.yaml
================================================
# @package _global_
constants:
name: neurips2023_small_data_gin
seed: 42
max_epochs: 100
data_dir: expts/data/neurips2023/small-dataset
raise_train_error: true
datacache_path: ../datacache/neurips2023-small/
trainer:
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-gin/${now:%Y-%m-%d_%H-%M-%S}/
================================================
FILE: expts/hydra-configs/training/pcqm4m.yaml
================================================
# @package _global_
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: ${constants.max_epochs}
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore: ignore nan values from loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01
trainer:
seed: ${constants.seed}
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQ4Mv2/${now:%Y-%m-%d_%H-%M-%S}/
filename: ${constants.name}
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: ${constants.max_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/hydra-configs/training/toymix.yaml
================================================
# @package _global_
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: ${constants.max_epochs}
warmup_epochs: 10
verbose: False
scheduler_kwargs: null
target_nan_mask: null
multitask_handling: flatten # flatten, mean-per-label
trainer:
seed: ${constants.seed}
model_checkpoint:
filename: ${constants.name}
save_last: True
trainer:
precision: 16
max_epochs: ${constants.max_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/main_run_get_fingerprints.py
================================================
# General imports
import os
from os.path import dirname, abspath
import yaml
import numpy as np
import pandas as pd
import torch
import fsspec
from lightning.pytorch.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import load_datamodule, load_trainer
from graphium.utils.fs import mkdir
from graphium.trainer.predictor import PredictorModule
# Set up the working directory: MAIN_DIR is the repository root, i.e. the
# parent of the installed `graphium` package, so relative paths resolve there.
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
# Example of a local checkpoint/config pair that could be used instead:
# MODEL_FILE = "models_checkpoints/micro_ZINC/model.ckpt"
# CONFIG_FILE = "expts/config_micro_ZINC.yaml"
def main() -> None:
    """
    Export molecular fingerprints from a pretrained model.

    For each dataset and each choice of concatenated last layers, restores the
    pretrained predictor, runs it over the dataset, and writes a dataframe of
    SMILES, targets and predictions to a compressed CSV file.
    """
    # Which model layers to concatenate into the fingerprint, tried in turn
    concat_layer_choices = [1, 0, [1, 2], [0, 1, 2]]
    dataset_names = ["molbace"]  # , "mollipo", "moltox21", "molHIV"]
    model_path = "https://storage.valencelabs.com/graphium/pretrained-models"
    model_name = "graphium-zinc-micro-dummy-test"
    checkpoint_uri = f"{model_path}/{model_name}/model.ckpt"
    model_config_uri = f"{model_path}/{model_name}/configs.yaml"

    predictor = PredictorModule.load_from_checkpoint(checkpoint_uri)
    print(predictor.model)
    print(ModelSummary(predictor, max_depth=4))

    for dataset_name in dataset_names:
        data_config_path = f"{MAIN_DIR}/expts/config_{dataset_name}_pretrained.yaml"
        with fsspec.open(data_config_path, "r") as fd:
            data_cfg = yaml.safe_load(fd)
        with fsspec.open(os.path.join(model_config_uri), "r") as fd:
            model_cfg = yaml.safe_load(fd)

        # Initialize the dataset with the featurization used at training time
        data_cfg["datamodule"]["args"]["featurization"] = model_cfg["datamodule"]["args"]["featurization"]
        data_module = load_datamodule(data_cfg)
        print("\ndatamodule:\n", data_module, "\n")

        for concat_last_layers in concat_layer_choices:
            export_dir = f"predictions/fingerprints-model-{model_name}"
            export_df_path = f"{export_dir}/{dataset_name}-concatlayers-{concat_last_layers}.csv.gz"
            predictor.model.concat_last_layers = concat_last_layers
            pl_trainer = load_trainer(data_cfg)

            # Run inference; move tensors to CPU numpy arrays before stacking
            preds = pl_trainer.predict(model=predictor, datamodule=data_module)
            if isinstance(preds[0], torch.Tensor):
                preds = [p.detach().cpu().numpy() for p in preds]
            preds = np.concatenate(preds, axis=0)

            # Assemble the output dataframe: SMILES, targets, then predictions
            columns = {"SMILES": data_module.dataset.smiles}
            target = data_module.dataset.labels
            for col in range(target.shape[1]):
                columns[f"Target-{col}"] = target[:, col]
            for col in range(preds.shape[1]):
                columns[f"Preds-{col}"] = preds[:, col]
            df = pd.DataFrame(columns)

            mkdir(export_dir)
            df.to_csv(export_df_path)
            print(df)
            print(f"file saved to:`{export_df_path}`")


if __name__ == "__main__":
    main()
================================================
FILE: expts/main_run_multitask.py
================================================
import hydra
from omegaconf import DictConfig
@hydra.main(version_base=None, config_path="hydra-configs", config_name="main")
def main(cfg: DictConfig) -> None:
    """Deprecated entry point; fails loudly with a pointer to the replacement CLI."""
    # Raising (rather than emitting a warning) makes the deprecation impossible
    # to miss for anyone still invoking this script.
    raise DeprecationWarning(
        "This script is deprecated. Use `python graphium/cli/train_finetune.py` (or `graphium-train`) instead!"
    )
if __name__ == "__main__":
    main()
================================================
FILE: expts/main_run_predict.py
================================================
# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import load_datamodule, load_trainer
from graphium.utils.fs import mkdir
from graphium.trainer.predictor import PredictorModule
# Set up the working directory: MAIN_DIR is the repository root, i.e. the
# parent of the installed `graphium` package, so relative paths resolve there.
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
# Dataset name used to locate the config file and to name the output file
DATA_NAME = "molhiv"
# Checkpoint to restore and the YAML config describing the dataset
MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt"
CONFIG_FILE = f"expts/config_{DATA_NAME}_pretrained.yaml"
# Example of a local checkpoint/config pair that could be used instead:
# MODEL_FILE = "models_checkpoints/micro_ZINC/model.ckpt"
# CONFIG_FILE = "expts/config_micro_ZINC.yaml"
def main(cfg: DictConfig) -> None:
    """
    Run inference with a pretrained checkpoint and export the predictions.

    Loads the datamodule described by ``cfg``, restores the predictor from
    ``MODEL_FILE``, predicts over the dataset, and writes a dataframe of
    SMILES, targets and predictions to a compressed CSV file.
    """
    # Work on a copy so the caller's configuration is never mutated
    cfg = deepcopy(cfg)

    # Load and initialize the dataset
    data_module = load_datamodule(cfg)
    print("\ndatamodule:\n", data_module, "\n")

    export_df_path = f"predictions/preds-{DATA_NAME}.csv.gz"

    predictor = PredictorModule.load_from_checkpoint(MODEL_FILE)
    print(predictor.model)
    print(ModelSummary(predictor, max_depth=4))

    pl_trainer = load_trainer(cfg)

    # Run inference; move tensors to CPU numpy arrays before stacking
    preds = pl_trainer.predict(model=predictor, datamodule=data_module)
    if isinstance(preds[0], torch.Tensor):
        preds = [p.detach().cpu().numpy() for p in preds]
    preds = np.concatenate(preds, axis=0)

    # Assemble the output dataframe: SMILES, targets, then predictions
    columns = {"SMILES": data_module.dataset.smiles}
    target = data_module.dataset.labels
    for col in range(target.shape[1]):
        columns[f"Target-{col}"] = target[:, col]
    for col in range(preds.shape[1]):
        columns[f"Preds-{col}"] = preds[:, col]
    df = pd.DataFrame(columns)

    mkdir("predictions")
    df.to_csv(export_df_path)
    print(df)
    print(f"file saved to:`{export_df_path}`")


if __name__ == "__main__":
    # Read the YAML configuration and hand it to the main routine
    with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as fd:
        cfg = yaml.safe_load(fd)
    main(cfg)
================================================
FILE: expts/main_run_test.py
================================================
# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig
from lightning.pytorch.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import load_datamodule, load_metrics, load_trainer
from graphium.trainer.predictor import PredictorModule
# Set up the working directory: MAIN_DIR is the repository root, i.e. the
# parent of the installed `graphium` package, so relative paths resolve there.
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
# Checkpoint to evaluate and the YAML config describing the dataset/metrics
MODEL_FILE = "models_checkpoints/ogb-molpcba/model-v2.ckpt"
CONFIG_FILE = "expts/config_molPCBA.yaml"
def main(cfg: DictConfig) -> None:
    """
    Evaluate a pretrained checkpoint on the test split.

    Loads the datamodule and metrics described by ``cfg``, restores the
    predictor from ``MODEL_FILE``, and runs the Lightning test loop.
    """
    # Work on a copy so the caller's configuration is never mutated
    cfg = deepcopy(cfg)

    # Load and initialize the dataset
    data_module = load_datamodule(cfg)
    print("\ndatamodule:\n", data_module, "\n")

    task_metrics = load_metrics(cfg)
    print(task_metrics)

    predictor = PredictorModule.load_from_checkpoint(MODEL_FILE)
    predictor.metrics = task_metrics
    print(predictor.model)
    print(ModelSummary(predictor, max_depth=4))

    pl_trainer = load_trainer(cfg)

    # Run the model testing
    pl_trainer.test(model=predictor, datamodule=data_module, ckpt_path=MODEL_FILE)


if __name__ == "__main__":
    # Read the YAML configuration and hand it to the main routine
    with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as fd:
        cfg = yaml.safe_load(fd)
    main(cfg)
================================================
FILE: expts/neurips2023_configs/base_config/large.yaml
================================================
# @package _global_
constants:
seed: 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn
datacache_path: "/localdata/neurips2023-large/"
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 30
batch_size_inference: 30
predictor:
metrics_every_n_train_steps: 1000
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
ipu_config:
- deviceIterations(30) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
# - Precision.enableFloatingPointExceptions(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(1)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
epoch_sampling_fraction: 1.0
l1000_mcf7:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
epoch_sampling_fraction: 1.0
pcba_1328:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
epoch_sampling_fraction: 1.0
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: 42
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
dataloading_from: disk
processed_graph_data_path: ${constants.datacache_path}
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
            normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 32 # -1 to use all
persistent_workers: True # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_vcap:
task_level: graph
out_dim: 2934
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
l1000_mcf7:
task_level: graph
out_dim: 2934
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_vcap: []
l1000_mcf7: []
pcba_1328: []
pcqm4m_g25: []
pcqm4m_n4: []
metrics_on_training_set:
l1000_vcap: []
l1000_mcf7: []
pcba_1328: []
pcqm4m_g25: []
pcqm4m_n4: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 3
alpha: 0.5
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 3
alpha: ${predictor.loss_fun.l1000_vcap.alpha}
pcba_1328: bce_logits_ipu
pcqm4m_g25: mae_ipu
pcqm4m_n4: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: averageprecision
num_classes: 3
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
l1000_mcf7: *classif_metrics
pcba_1328:
    # use auroc and averageprecision (non_ipu version) so that nans are handled correctly
- name: auroc
metric: auroc
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
- name: avpr
metric: averageprecision
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
pcqm4m_n4: *pcqm_metrics
trainer:
seed: ${constants.seed}
logger:
save_dir: logs/neurips2023-large/
name: ${constants.name}
project: ${constants.name}
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/base_config/large_pcba.yaml
================================================
# @package _global_
constants:
seed: 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn
datacache_path: "/localdata/neurips2023-large/"
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 30
batch_size_inference: 30
predictor:
metrics_every_n_train_steps: 1000
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
ipu_config:
- deviceIterations(30) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
# - Precision.enableFloatingPointExceptions(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(1)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# epoch_sampling_fraction: 1.0
# l1000_mcf7:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# epoch_sampling_fraction: 1.0
pcba_1328:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
epoch_sampling_fraction: 1.0
# pcqm4m_g25:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# # or set path as the URL directly
# smiles_col: "ordered_smiles"
# label_cols: graph_* # graph_* means all columns starting with "graph_"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# label_normalization:
# normalize_val_test: True
# method: "normal"
# epoch_sampling_fraction: 1.0
# pcqm4m_n4:
#df: null
#df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
## wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
## or set path as the URL directly
#smiles_col: "ordered_smiles"
#label_cols: node_* # node_* means all columns starting with "node_"
## sample_size: 2000 # use sample_size for test
#task_level: node
#splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
#seed: 42
#label_normalization:
# normalize_val_test: True
# method: "normal"
#epoch_sampling_fraction: 1.0
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
dataloading_from: disk
processed_graph_data_path: ${constants.datacache_path}
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 32 # -1 to use all
    persistent_workers: True # whether to keep dataloader workers alive between epochs.
    # Setting persistent_workers to False might make the start of each epoch very long.
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
# l1000_vcap:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# l1000_mcf7:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
# pcqm4m_g25:
# task_level: graph
# out_dim: 25
# hidden_dims: 32
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
#pcqm4m_n4:
# task_level: node
# out_dim: 4
# hidden_dims: 32
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
# l1000_vcap: []
# l1000_mcf7: []
pcba_1328: []
# pcqm4m_g25: []
#pcqm4m_n4: []
metrics_on_training_set:
# l1000_vcap: []
# l1000_mcf7: []
pcba_1328: []
# pcqm4m_g25: []
#pcqm4m_n4: []
loss_fun:
# l1000_vcap:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: 0.5
# l1000_mcf7:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: ${predictor.loss_fun.l1000_vcap.alpha}
pcba_1328: bce_logits_ipu
# pcqm4m_g25: mae_ipu
#pcqm4m_n4: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
# l1000_vcap: &classif_metrics
# - name: auroc
# metric: auroc
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# - name: avpr
# metric: averageprecision
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# l1000_mcf7: *classif_metrics
pcba_1328:
      # use auroc and averageprecision (non_ipu version) so that NaNs are handled correctly
- name: auroc
metric: auroc
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
- name: avpr
metric: averageprecision
task: binary
multitask_handling: mean-per-label
target_nan_mask: ignore
threshold_kwargs: null
# pcqm4m_n4: &pcqm_metrics
#- name: mae
#metric: mae_ipu
#target_nan_mask: null
#multitask_handling: mean-per-label
#threshold_kwargs: null
#- name: pearsonr
#metric: pearsonr_ipu
#threshold_kwargs: null
#target_nan_mask: null
#multitask_handling: mean-per-label
#- name: r2
#metric: r2_score_ipu
#threshold_kwargs: null
#target_nan_mask: null
#multitask_handling: mean-per-label
# pcqm4m_n4: *pcqm_metrics
trainer:
seed: ${constants.seed}
logger:
save_dir: logs/neurips2023-large/
name: ${constants.name}
project: ${constants.name}
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/base_config/large_pcqm_g25.yaml
================================================
# @package _global_
constants:
seed: 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn
datacache_path: "/localdata/neurips2023-large/"
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 30
batch_size_inference: 30
predictor:
metrics_every_n_train_steps: 1000
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
ipu_config:
- deviceIterations(30) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
# - Precision.enableFloatingPointExceptions(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(1)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# epoch_sampling_fraction: 1.0
# l1000_mcf7:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# epoch_sampling_fraction: 1.0
# pcba_1328:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# epoch_sampling_fraction: 1.0
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
# pcqm4m_n4:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# # or set path as the URL directly
# smiles_col: "ordered_smiles"
# label_cols: node_* # node_* means all columns starting with "node_"
# # sample_size: 2000 # use sample_size for test
# task_level: node
# splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# seed: 42
# label_normalization:
# normalize_val_test: True
# method: "normal"
# epoch_sampling_fraction: 1.0
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
dataloading_from: disk
processed_graph_data_path: ${constants.datacache_path}
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 32 # -1 to use all
    persistent_workers: True # whether to keep dataloader workers alive between epochs.
    # Setting persistent_workers to False might make the start of each epoch very long.
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
# l1000_vcap:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# l1000_mcf7:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# pcba_1328:
# task_level: graph
# out_dim: 1328
# hidden_dims: 64
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
# pcqm4m_n4:
# task_level: node
# out_dim: 4
# hidden_dims: 32
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
# l1000_vcap: []
# l1000_mcf7: []
# pcba_1328: []
pcqm4m_g25: []
# pcqm4m_n4: []
metrics_on_training_set:
# l1000_vcap: []
# l1000_mcf7: []
# pcba_1328: []
pcqm4m_g25: []
# pcqm4m_n4: []
loss_fun:
# l1000_vcap:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: 0.5
# l1000_mcf7:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: ${predictor.loss_fun.l1000_vcap.alpha}
# pcba_1328: bce_logits_ipu
pcqm4m_g25: mae_ipu
# pcqm4m_n4: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
# l1000_vcap: &classif_metrics
# - name: auroc
# metric: auroc
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# - name: avpr
# metric: averageprecision
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# l1000_mcf7: *classif_metrics
# pcba_1328:
    # # use auroc and averageprecision (non_ipu version) so that NaNs are handled correctly
# - name: auroc
# metric: auroc
# task: binary
# multitask_handling: mean-per-label
# target_nan_mask: ignore
# threshold_kwargs: null
# - name: avpr
# metric: averageprecision
# task: binary
# multitask_handling: mean-per-label
# target_nan_mask: ignore
# threshold_kwargs: null
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
# pcqm4m_n4: *pcqm_metrics
trainer:
seed: ${constants.seed}
logger:
save_dir: logs/neurips2023-large/
name: ${constants.name}
project: ${constants.name}
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/base_config/large_pcqm_n4.yaml
================================================
# @package _global_
constants:
seed: 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn
datacache_path: "/localdata/neurips2023-large/"
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 35 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 30
batch_size_inference: 30
predictor:
metrics_every_n_train_steps: 1000
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
ipu_config:
- deviceIterations(30) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 96)
- Precision.enableStochasticRounding(True)
# - Precision.enableFloatingPointExceptions(True)
ipu_inference_config:
# set device iteration and replication factor to 1 during inference
# gradient accumulation was set to 1 in the code
- deviceIterations(1)
- replicationFactor(1)
- Precision.enableStochasticRounding(False)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# epoch_sampling_fraction: 1.0
# l1000_mcf7:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-2_th2.csv.gz
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# epoch_sampling_fraction: 1.0
# pcba_1328:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# # or set path as the URL directly
# smiles_col: "SMILES"
# label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# epoch_sampling_fraction: 1.0
# pcqm4m_g25:
# df: null
# df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# # or set path as the URL directly
# smiles_col: "ordered_smiles"
# label_cols: graph_* # graph_* means all columns starting with "graph_"
# # sample_size: 2000 # use sample_size for test
# task_level: graph
# splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
# label_normalization:
# normalize_val_test: True
# method: "normal"
# epoch_sampling_fraction: 1.0
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: 42
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 1.0
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
dataloading_from: disk
processed_graph_data_path: ${constants.datacache_path}
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
        normalization: "none" # normalization already applied on the eigenvectors
        disconnected_comp: True # whether eigenvalues/eigenvectors of disconnected graph components are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 32 # -1 to use all
    persistent_workers: True # whether to keep dataloader workers alive between epochs.
    # Setting persistent_workers to False might make the start of each epoch very long.
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
# l1000_vcap:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# l1000_mcf7:
# task_level: graph
# out_dim: 2934
# hidden_dims: 128
# depth: 2
# activation: none
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# pcba_1328:
# task_level: graph
# out_dim: 1328
# hidden_dims: 64
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
# pcqm4m_g25:
# task_level: graph
# out_dim: 25
# hidden_dims: 32
# depth: 2
# activation: relu
# last_activation: none
# dropout: *dropout
# normalization: *normalization
# last_normalization: "none"
# residual_type: none
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
# l1000_vcap: []
# l1000_mcf7: []
# pcba_1328: []
# pcqm4m_g25: []
pcqm4m_n4: []
metrics_on_training_set:
# l1000_vcap: []
# l1000_mcf7: []
# pcba_1328: []
# pcqm4m_g25: []
pcqm4m_n4: []
loss_fun:
# l1000_vcap:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: 0.5
# l1000_mcf7:
# name: hybrid_ce_ipu
# n_brackets: 3
# alpha: ${predictor.loss_fun.l1000_vcap.alpha}
# pcba_1328: bce_logits_ipu
# pcqm4m_g25: mae_ipu
pcqm4m_n4: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
# l1000_vcap: &classif_metrics
# - name: auroc
# metric: auroc
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# - name: avpr
# metric: averageprecision
# num_classes: 3
# task: multiclass
# target_to_int: True
# target_nan_mask: -1000
# ignore_index: -1000
# multitask_handling: mean-per-label
# threshold_kwargs: null
# l1000_mcf7: *classif_metrics
# pcba_1328:
# # use auroc and averageprecision (non_ipu version) so that NaNs are handled correctly
# - name: auroc
# metric: auroc
# task: binary
# multitask_handling: mean-per-label
# target_nan_mask: ignore
# threshold_kwargs: null
# - name: avpr
# metric: averageprecision
# task: binary
# multitask_handling: mean-per-label
# target_nan_mask: ignore
# threshold_kwargs: null
pcqm4m_n4: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
# pcqm4m_n4: *pcqm_metrics
trainer:
seed: ${constants.seed}
logger:
save_dir: logs/neurips2023-large/
name: ${constants.name}
project: ${constants.name}
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: ${predictor.torch_scheduler_kwargs.max_num_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
h: 20 # NOTE(review): 'h' is not a recognized Trainer option — likely a stray edit; confirm intent or remove
================================================
FILE: expts/neurips2023_configs/base_config/small.yaml
================================================
# @package _global_
constants:
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
entity: multitask-gnn
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 80
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 80
# Data handling-related
batch_size_training: 50
batch_size_inference: 50
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16
accumulate_grad_batches: 4
ipu_config:
- deviceIterations(5) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
qm9:
df: null
df_path: data/neurips2023/small-dataset/qm9.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
tox21:
df: null
df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
seed: *seed
task_level: graph
zinc:
df: null
df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["SA", "logp", "score"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Setting persistent_workers to false might make the start of each epoch very long.
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null # Set as null to avoid a pre-nn network
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 96
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
qm9:
task_level: graph
out_dim: 19
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
tox21:
task_level: graph
out_dim: 12
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
zinc:
task_level: graph
out_dim: 3
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
zinc: ["mae"]
loss_fun:
qm9: mae_ipu
tox21: bce_ipu
zinc: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
metric: f1
multitask_handling: mean-per-label
target_to_int: True
num_classes: 2
average: micro
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
th_on_preds: True
th_on_target: True
- name: precision > 0.5
metric: precision
multitask_handling: mean-per-label
average: micro
threshold_kwargs: *threshold_05
zinc: *qm9_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-small/
name: ${constants.name}
project: ${constants.name}
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/${constants.name}/
filename: ${constants.name}
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml
================================================
# Testing the gcn model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_small_data_gcn
seed: &seed 3000
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 80
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 100 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 200
# Data handling-related
batch_size_training: 5
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 16
accumulate_grad_batches: 4
ipu_config:
- deviceIterations(5) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
qm9:
df: null
df_path: data/neurips2023/small-dataset/qm9.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
tox21:
df: null
df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
seed: *seed
task_level: graph
zinc:
df: null
df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
# or set path as the URL directly
smiles_col: "smiles"
label_cols: ["SA", "logp", "score"]
# sample_size: 2000 # use sample_size for test
splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
seed: *seed
task_level: graph
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Setting persistent_workers to false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null # Set as null to avoid a pre-nn network
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 128
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
qm9:
task_level: graph
out_dim: 19
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
tox21:
task_level: graph
out_dim: 12
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
zinc:
task_level: graph
out_dim: 3
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
zinc: ["mae"]
loss_fun:
qm9: mae_ipu
tox21: bce_ipu
zinc: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 300
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
metric: f1
multitask_handling: flatten
target_to_int: True
num_classes: 2
average: micro
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
th_on_preds: True
th_on_target: True
- name: precision > 0.5
metric: precision
multitask_handling: flatten
average: micro
threshold_kwargs: *threshold_05
zinc: *qm9_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-small/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-gcn/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml
================================================
# Testing the gin model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_small_data_gin
config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml"
seed: &seed 1000
architecture:
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 96
hidden_dims: *gnn_dim
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
trainer:
seed: *seed
logger:
name: *name
project: *name
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-gin/
filename: *name
================================================
FILE: expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml
================================================
# Testing the gine model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_small_data_gine
config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml"
seed: &seed 1000
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 96
hidden_dims: *gnn_dim
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
trainer:
seed: *seed
logger:
name: *name
project: *name
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small-gine/
filename: *name
================================================
FILE: expts/neurips2023_configs/config_classifigression_l1000.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_small_data_mpnn
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
#accelerator:
# type: ipu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# ipu_dataloader_training_opts:
# mode: async
# max_num_nodes_per_graph: 24 # train max nodes: 20, max_edges: 54
# max_num_edges_per_graph: 60
# ipu_dataloader_inference_opts:
# mode: async
# max_num_nodes_per_graph: 24 # valid max nodes: 51, max_edges: 118
# max_num_edges_per_graph: 60
# # Data handling-related
# batch_size_training: 50
# batch_size_inference: 50
## predictor:
## optim_kwargs:
## loss_scaling: 1024
# trainer:
# trainer:
# precision: 16
# accumulate_grad_batches: 4
#
# ipu_config:
# - deviceIterations(20) # IPU would require large batches to be ready for the model.
# - replicationFactor(16)
# # - enableProfiling("graph_analyser") # The folder where the profile will be stored
# # - enableExecutableCaching("pop_compiler_cache")
# - TensorLocations.numIOTiles(128)
# - _Popart.set("defaultBufferingDepth", 128)
# - Precision.enableStochasticRounding(True)
accelerator:
type: gpu # cpu or ipu or gpu
config_override:
datamodule:
batch_size_training: 64
batch_size_inference: 256
trainer:
trainer:
precision: 32
accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/small-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
l1000_mcf7:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/small-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 1
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 5 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Setting persistent_workers to false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 64
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: *gnn_dim
out_dim: *gnn_dim
in_dim_edges: 32
out_dim_edges: 32
attn_type: "none" # "full-attention", "none"
attn_kwargs: null
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_vcap:
task_level: graph
out_dim: 4890 # n_columns (978) x n_brackets (5)
hidden_dims: 128
depth: 2
activation: none
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
l1000_mcf7:
task_level: graph
out_dim: 4890 # n_columns (978) x n_brackets (5)
hidden_dims: 128
depth: 2
activation: none
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_vcap: []
l1000_mcf7: []
loss_fun:
l1000_vcap:
name: hybrid_ce
n_brackets: 5
l1000_mcf7:
name: hybrid_ce
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 4.e-3 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 0
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: ignore
multitask_handling: flatten
# Task-specific
metrics:
l1000_vcap: &classif_metrics
- name: auc
metric: auroc
target_nan_mask: ignore
multitask_handling: flatten
squeeze_targets: True
target_to_int: True
task: multiclass
num_classes: 5
- name: ap
metric: averageprecision
target_nan_mask: ignore
multitask_handling: flatten
squeeze_targets: True
target_to_int: True
task: multiclass
num_classes: 5
- name: accuracy
metric: accuracy
target_nan_mask: ignore
multitask_handling: flatten
squeeze_targets: True
target_to_int: True
task: multiclass
num_classes: 5
top_k: 1
l1000_mcf7: *classif_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-small/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 1
================================================
FILE: expts/neurips2023_configs/config_large_gcn.yaml
================================================
# Running the gcn model with the largemix dataset on IPU.
defaults:
- base_config: large
- _self_
constants:
name: neurips2023_large_data_gcn
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_large_gcn_g25.yaml
================================================
# Running the gcn model with the largemix dataset on IPU.
defaults:
# - base_config: large
- base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gcn
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_large_gcn_gpu.yaml
================================================
# Testing GCN on LargeMix with FP16/32 on GPU
defaults:
- base_config: large
- _self_
constants:
name: neurips2023_large_data_gcn_gpu
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
accelerator:
type: gpu
float32_matmul_precision: medium
config_override:
datamodule:
args:
batch_size_training: 80
batch_size_inference: 80
predictor:
metrics_every_n_train_steps: 1000
trainer:
trainer:
precision: 32 # 16-mixed
accumulate_grad_batches: 1
datamodule:
args:
task_specific_args:
l1000_vcap:
df_path: expts/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-4.csv.gz # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
splits_path: expts/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
l1000_mcf7:
df_path: expts/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-4.csv.gz # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
splits_path: expts/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
pcba_1328:
df_path: expts/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
splits_path: expts/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
pcqm4m_g25:
df_path: expts/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
splits_path: expts/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
pcqm4m_n4:
df_path: expts/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet # wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
splits_path: expts/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
featurization_n_jobs: 4 # 30
processed_graph_data_path: "../datacache/neurips2023-small/"
num_workers: 4 # 30
architecture:
task_heads:
pcba_1328:
last_activation: null
predictor:
loss_fun:
pcba_1328: bce_logits_ipu
torch_scheduler_kwargs:
max_num_epochs: &max_epochs 20
trainer:
trainer:
max_epochs: *max_epochs
check_val_every_n_epoch: 1
================================================
FILE: expts/neurips2023_configs/config_large_gcn_n4.yaml
================================================
# Running the gcn model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
- base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gcn
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_large_gcn_pcba.yaml
================================================
# Running the gcn model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- base_config: large_pcba
- _self_
constants:
name: neurips2023_large_data_gcn
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_large_gin.yaml
================================================
# Running the gin model with the largemix dataset on IPU.
defaults:
- base_config: large
- _self_
constants:
name: neurips2023_large_data_gin
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gin'
================================================
FILE: expts/neurips2023_configs/config_large_gin_g25.yaml
================================================
# Running the gin model with the largemix dataset on IPU.
defaults:
# - base_config: large
- base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gin
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gin'
================================================
FILE: expts/neurips2023_configs/config_large_gin_n4.yaml
================================================
# Running the gin model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
- base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gin
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gin'
================================================
FILE: expts/neurips2023_configs/config_large_gin_pcba.yaml
================================================
# Running the gin model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- base_config: large_pcba
- _self_
constants:
name: neurips2023_large_data_gin
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gin'
================================================
FILE: expts/neurips2023_configs/config_large_gine.yaml
================================================
# Running the gine model with the largemix dataset on IPU.
defaults:
- base_config: large
# - base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gine
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
layer_type: 'pyg:gine'
graph_output_nn:
graph:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
node:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
================================================
FILE: expts/neurips2023_configs/config_large_gine_g25.yaml
================================================
# Running the gine model with the largemix dataset on IPU.
defaults:
# - base_config: large
- base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gine
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
layer_type: 'pyg:gine'
graph_output_nn:
graph:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
node:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
================================================
FILE: expts/neurips2023_configs/config_large_gine_n4.yaml
================================================
# Running the gine model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
- base_config: large_pcqm_n4
- _self_
constants:
name: neurips2023_large_data_gine
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
layer_type: 'pyg:gine'
graph_output_nn:
graph:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
node:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
================================================
FILE: expts/neurips2023_configs/config_large_gine_pcba.yaml
================================================
# Running the gine model with the largemix dataset on IPU.
defaults:
# - base_config: large
# - base_config: large_pcqm_g25
# - base_config: large_pcqm_n4
- base_config: large_pcba
- _self_
constants:
name: neurips2023_large_data_gine
wandb:
name: ${constants.name}
project: neurips2023_large_graphcore
entity: multitask-gnn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn:
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
layer_type: 'pyg:gine'
graph_output_nn:
graph:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
node:
out_dim: *gnn_dim
hidden_dims: *gnn_dim
================================================
FILE: expts/neurips2023_configs/config_large_mpnn.yaml
================================================
# Running the mpnn model with the largemix dataset on IPU.
defaults:
- base_config: large
constants:
name: neurips2023_large_data_mpnn
architecture:
pre_nn:
out_dim: 160
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges:
out_dim: 64
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: 0.18
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn: # Set as null to avoid a post-nn network
in_dim: 160 # should be consistent with pre_nn.out_dim
out_dim: 256
hidden_dims: &gnn_dim 160 # should be consistent with pre_nn.out_dim when a multi-layer mpnn is used (ffn layer)
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gps'
layer_kwargs:
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: 160 # should be consistent with pre_nn.out_dim when a multi-layer mpnn is used (node_model layer)
out_dim: 160 # should be consistent with pre_nn.out_dim when a multi-layer mpnn is used (node_model layer)
in_dim_edges: 64 # should be consistent with pre_nn_edges.out_dim when a multi-layer mpnn is used (edge_model layer)
out_dim_edges: 64 # should be consistent with pre_nn_edges.out_dim when a multi-layer mpnn is used (edge_model layer)
attn_type: "none" # "full-attention", "none"
# biased_attention: false
attn_kwargs: null
================================================
FILE: expts/neurips2023_configs/config_luis_jama.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_small_data_mpnn
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
# accelerator:
# type: ipu # cpu or ipu or gpu
# config_override:
# datamodule:
# args:
# ipu_dataloader_training_opts:
# mode: async
# max_num_nodes_per_graph: 24 # train max nodes: 20, max_edges: 54
# max_num_edges_per_graph: 60
# ipu_dataloader_inference_opts:
# mode: async
# max_num_nodes_per_graph: 24 # valid max nodes: 51, max_edges: 118
# max_num_edges_per_graph: 60
# # Data handling-related
# batch_size_training: 50
# batch_size_inference: 50
# predictor:
# optim_kwargs:
# loss_scaling: 1024
# trainer:
# trainer:
# precision: 16
# accumulate_grad_batches: 4
# ipu_config:
# - deviceIterations(20) # IPU would require large batches to be ready for the model.
# - replicationFactor(16)
# # - enableProfiling("graph_analyser") # The folder where the profile will be stored
# # - enableExecutableCaching("pop_compiler_cache")
# - TensorLocations.numIOTiles(128)
# - _Popart.set("defaultBufferingDepth", 128)
# - Precision.enableStochasticRounding(True)
accelerator:
type: cpu # cpu or ipu or gpu
config_override:
datamodule:
batch_size_training: 64
batch_size_inference: 256
trainer:
trainer:
precision: 32
accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm20k_g13:
df: null
df_path: graphium/data/neurips2023/dummy-dataset/PCQM20k_G13_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Dummy-dataset/PCQM20k_G13_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
sample_size: 2000 # use sample_size for test
split_val: 0.1
split_test: 0.1
task_level: graph
label_normalization:
method: "normal"
pcqm20k_n4:
df: null
df_path: graphium/data/neurips2023/dummy-dataset/PCQM20k_G13_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Dummy-dataset/PCQM20k_G13_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
sample_size: 2000 # use sample_size for test
split_val: 0.1
split_test: 0.1
seed: *seed
task_level: node
label_normalization:
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization is already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors for disconnected graphs are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization is already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors for disconnected graphs are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 4 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Setting persistent_workers to False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 64
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
node_residual: false
mpnn_type: 'pyg:mpnnplus'
mpnn_kwargs:
in_dim: *gnn_dim
out_dim: *gnn_dim
in_dim_edges: 32
out_dim_edges: 32
attn_type: "none" # "full-attention", "none"
attn_kwargs: null
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm20k_g13:
task_level: graph
out_dim: 13
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcqm20k_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm20k_g13: []
pcqm20k_n4: []
loss_fun:
pcqm20k_g13: mae_ipu
pcqm20k_n4: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
#weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: ignore # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm20k_g13: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
pcqm20k_n4: *pcqm_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-small/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-small/
filename: *name
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/config_small_gated_gcn.yaml
================================================
# Testing the gated_gcn model with the PCQMv2 dataset on IPU.
defaults:
- base_config: small
- _self_
constants:
name: neurips2023_small_data_gated_gcn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gated-gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_small_gcn.yaml
================================================
# Testing the gcn model with the toymix dataset on IPU.
defaults:
- base_config: small
- _self_
constants:
name: neurips2023_small_data_gcn
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_small_gcn_gpu.yaml
================================================
# Testing GCN on ToyMix with FP16/32 on GPU
defaults:
- base_config: small
- _self_
constants:
name: neurips2023_small_data_gcn_gpu
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
accelerator:
type: gpu # cpu or ipu or gpu
float32_matmul_precision: medium
config_override:
datamodule:
args:
# Data handling-related
batch_size_training: 200
batch_size_inference: 200
predictor:
metrics_every_n_train_steps: 300
trainer:
trainer:
precision: 32 # 16-mixed
accumulate_grad_batches: 1
datamodule:
args:
task_specific_args: # To be replaced by a new class "DatasetParams"
qm9:
df_path: expts/data/neurips2023/small-dataset/qm9.csv.gz
splits_path: expts/data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
tox21:
df_path: expts/data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz
splits_path: expts/data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
zinc:
df_path: expts/data/neurips2023/small-dataset/ZINC12k.csv.gz
splits_path: expts/data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
featurization_n_jobs: 4 # 30
processed_graph_data_path: "../datacache/neurips2023-small/"
num_workers: 4 # 30
architecture:
task_heads:
tox21:
last_activation: none
predictor:
loss_fun:
tox21: bce_logits_ipu
torch_scheduler_kwargs:
max_num_epochs: &max_epochs 300
trainer:
trainer:
max_epochs: *max_epochs
================================================
FILE: expts/neurips2023_configs/config_small_gin.yaml
================================================
# Testing the gin model with the PCQMv2 dataset on IPU.
defaults:
- base_config: small
- _self_
constants:
name: neurips2023_small_data_gin
architecture:
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gin'
================================================
FILE: expts/neurips2023_configs/config_small_gine.yaml
================================================
# Testing the gine model with the PCQMv2 dataset on IPU.
defaults:
- base_config: small
- _self_
constants:
name: neurips2023_small_data_gine
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn: # Set as null to avoid a post-nn network
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
================================================
FILE: expts/neurips2023_configs/config_small_mpnn.yaml
================================================
# Testing the mpnn only model with the PCQMv2 dataset on IPU.
defaults:
- base_config: small
constants:
name: neurips2023_small_data_mpnn
architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 64
hidden_dims: *gnn_dim
layer_type: 'pyg:gps' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
layer_kwargs: # Parameters for the model itself. You could define dropout_attn: 0.1
mpnn_type: 'pyg:mpnnplus'
out_dim_edges: 32
================================================
FILE: expts/neurips2023_configs/single_task_gcn/config_large_gcn_mcf7.yaml
================================================
# Testing the gcn model
constants:
name: &name neurips2023_large_data_gcn_mcf7
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_mcf7:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/mcf7/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization is already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors for disconnected graphs are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization is already applied to the eigenvectors
disconnected_comp: True # whether eigenvalues/eigenvectors for disconnected graphs are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Setting persistent_workers to False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_mcf7:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_mcf7: []
metrics_on_training_set:
l1000_mcf7: []
loss_fun:
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_mcf7: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gcn/mcf7/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gcn/config_large_gcn_pcba.yaml
================================================
# Testing the gcn model
constants:
name: &name neurips2023_large_data_gcn_pcba
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 400
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcba_1328:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/pcba/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to skip this positional encoder
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcba_1328: []
metrics_on_training_set:
pcba_1328: []
loss_fun:
pcba_1328: bce_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcba_1328:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gcn/pcba/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gcn/config_large_gcn_vcap.yaml
================================================
# Testing the gcn model
constants:
name: &name neurips2023_large_data_gcn_vcap
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 150
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/vcap/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to skip this positional encoder
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 768
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gcn' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_vcap:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_vcap: []
metrics_on_training_set:
l1000_vcap: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gcn/vcap/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_g25.yaml
================================================
# Testing the gin model
constants:
name: &name neurips2023_large_data_gin_g25
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/g25/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to skip this positional encoder
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_g25: []
metrics_on_training_set:
pcqm4m_g25: []
loss_fun:
pcqm4m_g25: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gin/g25/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_mcf7.yaml
================================================
# Testing the gin model
constants:
name: &name neurips2023_large_data_gine_mcf7
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_mcf7:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/mcf7/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied to the eigenvectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to skip this positional encoder
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_mcf7:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_mcf7: []
metrics_on_training_set:
l1000_mcf7: []
loss_fun:
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_mcf7: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/mcf7/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_n4.yaml
================================================
# Testing the gin model
constants:
name: &name neurips2023_large_data_gin_n4
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: *seed
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/n4/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_n4: []
metrics_on_training_set:
pcqm4m_n4: []
loss_fun:
pcqm4m_n4: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_n4: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gin/n4/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_pcba.yaml
================================================
# Testing the gin model
constants:
name: &name neurips2023_large_data_gin_pcba
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 400
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcba_1328:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/pcba/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcba_1328: []
metrics_on_training_set:
pcba_1328: []
loss_fun:
pcba_1328: bce_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcba_1328:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gin/pcba/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_pcq.yaml
================================================
# Testing the gin model
constants:
name: &name neurips2023_large_data_gin_pcq
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: *seed
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/pcq/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_g25: []
pcqm4m_n4: []
metrics_on_training_set:
pcqm4m_g25: []
pcqm4m_n4: []
loss_fun:
pcqm4m_g25: mae_ipu
pcqm4m_n4: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
pcqm4m_n4: *pcqm_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gin/pcq/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gin/config_large_gin_vcap.yaml
================================================
# Testing the gin model with the LINCS L1000 VCAP dataset on IPU.
constants:
name: &name neurips2023_large_data_gin_vcap
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 150
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/vcap/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive between epochs.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
in_dim: 64 # or otherwise the correct value
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_vcap:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_vcap: []
metrics_on_training_set:
l1000_vcap: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/vcap/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_g25.yaml
================================================
# Testing the gine model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_g25
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/g25/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Using persistent_workers: False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_g25: []
metrics_on_training_set:
pcqm4m_g25: []
loss_fun:
pcqm4m_g25: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/g25/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_mcf7.yaml
================================================
# Testing the gine model with the LINCS L1000 MCF7 dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_mcf7
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_mcf7:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_MCF7_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_mcf7_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_mcf7_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/mcf7/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Using persistent_workers: False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_mcf7:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_mcf7: []
metrics_on_training_set:
l1000_mcf7: []
loss_fun:
l1000_mcf7:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_mcf7: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/mcf7/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_n4.yaml
================================================
# Testing the gine model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_n4
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: *seed
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/n4/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Using persistent_workers: False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_n4: []
metrics_on_training_set:
pcqm4m_n4: []
loss_fun:
pcqm4m_n4: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_n4: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/n4/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_pcba.yaml
================================================
# Testing the gine model with the PCBA_1328 dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_pcba
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 200 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 400
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcba_1328:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCBA_1328_1564k.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCBA_1328_1564k.parquet
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: assayID-* # assayID-* means all columns starting with "assayID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/pcba/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigenvectors
disconnected_comp: True # whether to include eigenvalues/eigenvectors for disconnected graph components
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Using persistent_workers: False might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcba_1328:
task_level: graph
out_dim: 1328
hidden_dims: 64
depth: 2
activation: relu
last_activation: sigmoid
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcba_1328: []
metrics_on_training_set:
pcba_1328: []
loss_fun:
pcba_1328: bce_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcba_1328:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/pcba/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_pcq.yaml
================================================
# Testing the gine model with the PCQMv2 dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_pcq
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 30 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 100
# Data handling-related
batch_size_training: 10
batch_size_inference: 10
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(10) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
pcqm4m_g25:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: graph_* # graph_* means all columns starting with "graph_"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
label_normalization:
normalize_val_test: True
method: "normal"
pcqm4m_n4:
df: null
df_path: graphium/data/neurips2023/large-dataset/PCQM4M_G25_N4.parquet
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/PCQM4M_G25_N4.parquet
# or set path as the URL directly
smiles_col: "ordered_smiles"
label_cols: node_* # node_* means all columns starting with "node_"
# sample_size: 2000 # use sample_size for test
task_level: node
splits_path: graphium/data/neurips2023/large-dataset/pcqm4m_g25_n4_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/pcqm4m_g25_n4_random_splits.pt`
seed: *seed
label_normalization:
normalize_val_test: True
method: "normal"
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/pcq/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
node:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
pcqm4m_g25:
task_level: graph
out_dim: 25
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
pcqm4m_n4:
task_level: node
out_dim: 4
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
pcqm4m_g25: []
pcqm4m_n4: []
metrics_on_training_set:
pcqm4m_g25: []
pcqm4m_n4: []
loss_fun:
pcqm4m_g25: mae_ipu
pcqm4m_n4: mae_ipu
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
pcqm4m_g25: &pcqm_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2
metric: r2_score_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
pcqm4m_n4: *pcqm_metrics
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/pcq/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/neurips2023_configs/single_task_gine/config_large_gine_vcap.yaml
================================================
# Testing the gine model with the LINCS L1000 VCAP dataset on IPU.
constants:
name: &name neurips2023_large_data_gine_vcap
seed: &seed 42
raise_train_error: true # Whether the code should raise an error if it crashes during training
accelerator:
type: ipu # cpu or ipu or gpu
config_override:
datamodule:
args:
ipu_dataloader_training_opts:
mode: async
max_num_nodes_per_graph: 60 # train max nodes: 20, max_edges: 54
max_num_edges_per_graph: 100
ipu_dataloader_inference_opts:
mode: async
max_num_nodes_per_graph: 60 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 150
# Data handling-related
batch_size_training: 10
batch_size_inference: 2
predictor:
optim_kwargs:
loss_scaling: 1024
trainer:
trainer:
precision: 32
accumulate_grad_batches: 8
ipu_config:
- deviceIterations(1) # IPU would require large batches to be ready for the model.
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
- TensorLocations.numIOTiles(128)
- _Popart.set("defaultBufferingDepth", 128)
- Precision.enableStochasticRounding(True)
# accelerator:
# type: cpu # cpu or ipu or gpu
# config_override:
# datamodule:
# batch_size_training: 64
# batch_size_inference: 256
# trainer:
# trainer:
# precision: 32
# accumulate_grad_batches: 1
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
l1000_vcap:
df: null
df_path: graphium/data/neurips2023/large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/LINCS_L1000_VCAP_0-4.csv.gz
# or set path as the URL directly
smiles_col: "SMILES"
label_cols: geneID-* # geneID-* means all columns starting with "geneID-"
# sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/l1000_vcap_random_splits.pt # Download with `wget https://storage.valencelabs.com/graphium/datasets/neurips_2023/Large-dataset/l1000_vcap_random_splits.pt`
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-large/vcap/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # if H is included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # normalization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16
num_workers: 30 # -1 to use all
persistent_workers: False # if use persistent worker at the start of each epoch.
# Using persistent_workers false might make the start of each epoch very long.
featurization_backend: "loky"
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: &dropout 0.18
normalization: &normalization layer_norm
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
pe_encoders:
out_dim: 32
pool: "sum" #"mean" "max"
last_norm: None #"batch_norm", "layer_norm"
encoders: #la_pos | rw_pos
la_pos: # Set as null to avoid a pre-nn network
encoder_type: "laplacian_pe"
input_keys: ["laplacian_eigvec", "laplacian_eigval"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
model_type: 'DeepSet' #'Transformer' or 'DeepSet'
num_layers: 2
num_layers_post: 1 # Num. layers to apply after pooling
dropout: 0.1
first_normalization: "none" #"batch_norm" or "layer_norm"
rw_pos:
encoder_type: "mlp"
input_keys: ["rw_return_probs"]
output_keys: ["feat"]
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: "layer_norm" #"batch_norm" or "layer_norm"
first_normalization: "layer_norm" #"batch_norm" or "layer_norm"
gnn: # Set as null to avoid a post-nn network
out_dim: &gnn_dim 704
hidden_dims: *gnn_dim
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: "layer_norm"
last_normalization: *normalization
residual_type: simple
virtual_node: 'none'
layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps
graph_output_nn:
graph:
pooling: [sum]
out_dim: *gnn_dim
hidden_dims: *gnn_dim
depth: 1
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads:
l1000_vcap:
task_level: graph
out_dim: 4890
hidden_dims: 128
depth: 2
activation: none
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
#Task-specific
predictor:
metrics_on_progress_bar:
l1000_vcap: []
metrics_on_training_set:
l1000_vcap: []
loss_fun:
l1000_vcap:
name: hybrid_ce_ipu
n_brackets: 5
random_seed: *seed
optim_kwargs:
lr: 1.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 20
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor qm9/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
multitask_handling: flatten # flatten, mean-per-label
# Task-specific
metrics:
l1000_vcap: &classif_metrics
- name: auroc
metric: auroc_ipu
num_classes: 5
task: multiclass
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
num_classes: 5
task: multiclass
target_to_int: True
target_nan_mask: -1000
ignore_index: -1000
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: *seed
logger:
save_dir: logs/neurips2023-large/
name: *name
project: *name
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/neurips2023-large-gine/vcap/
filename: *name
# monitor: *monitor
# mode: *mode
# save_top_k: 1
save_last: True
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
================================================
FILE: expts/run_validation_test.py
================================================
# General imports
import argparse
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig, OmegaConf
import timeit
from loguru import logger
from datetime import datetime
from pytorch_lightning.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import (
load_datamodule,
load_metrics,
load_architecture,
load_predictor,
load_trainer,
save_params_to_wandb,
load_accelerator,
)
from graphium.utils.safe_run import SafeRun
import hydra
# WandB
import wandb
# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
@hydra.main(version_base=None, config_path="hydra-configs", config_name="main")
def main(cfg: DictConfig) -> None:
    """Validate and test a pretrained Graphium model from a saved checkpoint.

    Builds the datamodule, architecture, predictor and trainer from the Hydra
    config, then runs ``trainer.validate`` and ``trainer.test``, both resuming
    from the checkpoint at ``<dirpath><seed>/<filename>.ckpt``.

    Returns:
        The trainer's ``callback_metrics`` collected during validation/testing.
    """
    cfg = OmegaConf.to_container(cfg, resolve=True)

    run_name: str = "main"
    add_date_time: bool = True

    st = timeit.default_timer()

    date_time_suffix = ""
    if add_date_time:
        date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")

    wandb.init(entity=cfg["constants"]["entity"], project=cfg["constants"]["name"], config=cfg)

    # Initialize the accelerator
    cfg, accelerator_type = load_accelerator(cfg)

    # Load and initialize the dataset
    datamodule = load_datamodule(cfg, accelerator_type)

    # Initialize the network
    model_class, model_kwargs = load_architecture(
        cfg,
        in_dims=datamodule.in_dims,
    )

    datamodule.prepare_data()

    metrics = load_metrics(cfg)
    logger.info(metrics)

    predictor = load_predictor(
        cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms
    )

    logger.info(predictor.model)
    logger.info(ModelSummary(predictor, max_depth=4))

    trainer = load_trainer(cfg, run_name, accelerator_type, date_time_suffix)
    save_params_to_wandb(trainer.logger, cfg, predictor, datamodule)

    # Both validation and testing resume from the same checkpoint file.
    ckpt_path = (
        f'{cfg["trainer"]["model_checkpoint"]["dirpath"]}'
        f'{cfg["trainer"]["seed"]}/'
        f'{cfg["trainer"]["model_checkpoint"]["filename"]}.ckpt'
    )

    # Determine the max num nodes and edges in validation
    predictor.set_max_nodes_edges_per_graph(datamodule, stages=["val"])

    # Run the model validation
    with SafeRun(name="VALIDATING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
        trainer.validate(model=predictor, ckpt_path=ckpt_path, datamodule=datamodule)

    # Run the model testing
    with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
        trainer.test(model=predictor, ckpt_path=ckpt_path, datamodule=datamodule)

    logger.info("--------------------------------------------")
    # BUG FIX: loguru treats extra positional args as `str.format` arguments;
    # without a `{}` placeholder the elapsed time was silently dropped.
    logger.info("total computation used {:.2f}s", timeit.default_timer() - st)
    logger.info("--------------------------------------------")

    wandb.finish()
    return trainer.callback_metrics


if __name__ == "__main__":
    main()
================================================
FILE: graphium/__init__.py
================================================
from ._version import __version__
from .config import load_config
from . import utils
from . import features
from . import data
from . import nn
from . import trainer
================================================
FILE: graphium/_version.py
================================================
from importlib.metadata import version
from importlib.metadata import PackageNotFoundError
try:
    # Resolve the version from the installed distribution's metadata.
    __version__ = version("graphium")
except PackageNotFoundError:
    # package is not installed (e.g. running from a source checkout)
    __version__ = "dev"
================================================
FILE: graphium/cli/__init__.py
================================================
from .data import data_app
from .parameters import param_app
from .finetune_utils import finetune_app
from .main import app
================================================
FILE: graphium/cli/__main__.py
================================================
from graphium.cli.main import app

# Allow launching the CLI via `python -m graphium.cli`.
if __name__ == "__main__":
    app()
================================================
FILE: graphium/cli/data.py
================================================
import timeit
from typing import List
from omegaconf import OmegaConf
import typer
import graphium
from loguru import logger
from hydra import initialize, compose
from .main import app
from graphium.config._loader import load_datamodule
data_app = typer.Typer(help="Graphium datasets.")
app.add_typer(data_app, name="data")
@data_app.command(name="download", help="Download a Graphium dataset.")
def download(name: str, output: str, progress: bool = True):
    """Fetch the named Graphium dataset, extract it, and report where it landed."""
    logger.info(f"Download dataset '{name}' into {output}.")
    fpath = graphium.data.utils.download_graphium_dataset(
        name=name,
        output_path=output,
        extract_zip=True,
        progress=progress,
    )
    logger.info(f"Dataset available at {fpath}.")
@data_app.command(name="list", help="List available Graphium dataset.")
def list():
    # Log the names of all datasets known to graphium.data.utils.
    # NOTE(review): this function shadows the builtin `list` in the module
    # namespace; the CLI command name is set via `name=`, so renaming the
    # function would be safe for the CLI — consider it.
    logger.info("Graphium datasets:")
    logger.info(graphium.data.utils.list_graphium_datasets())
@data_app.command(name="prepare", help="Prepare a Graphium dataset.")
def prepare_data(overrides: List[str]) -> None:
    """Compose the Hydra config, then run the datamodule's data preparation (featurization + caching)."""
    with initialize(version_base=None, config_path="../../expts/hydra-configs"):
        composed = compose(config_name="main", overrides=overrides)
    cfg = OmegaConf.to_container(composed, resolve=True)

    start_time = timeit.default_timer()

    # The caching directory is mandatory for data preparation.
    cache_dir = cfg["datamodule"]["args"].get("processed_graph_data_path", None)
    if cache_dir is None:
        raise ValueError(
            "Please provide `datamodule.args.processed_graph_data_path` to specify the caching dir."
        )
    logger.info(f"The caching dir is set to '{cache_dir}'")

    # Build the datamodule on CPU and trigger its preparation step.
    datamodule = load_datamodule(cfg, "cpu")
    datamodule.prepare_data()

    logger.info(f"Data preparation took {timeit.default_timer() - start_time:.2f} seconds.")
================================================
FILE: graphium/cli/finetune_utils.py
================================================
from typing import List, Literal, Optional
import fsspec
import numpy as np
import torch
import tqdm
import typer
import yaml
from datamol.utils import fs
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from loguru import logger
from omegaconf import OmegaConf
from graphium.config._loader import load_accelerator, load_datamodule
from graphium.finetuning.fingerprinting import Fingerprinter
from graphium.utils import fs
from graphium.trainer.predictor import PredictorModule
from .main import app
from .train_finetune_test import run_training_finetuning_testing
finetune_app = typer.Typer(help="Utility CLI for extra fine-tuning utilities.")
app.add_typer(finetune_app, name="finetune")
@finetune_app.command(name="admet")
def benchmark_tdc_admet_cli(
    overrides: List[str],
    name: Optional[List[str]] = None,
    inclusive_filter: bool = True,
):
    """
    Utility CLI to easily fine-tune a model on (a subset of) the benchmarks in the TDC ADMET group.

    A major limitation is that we cannot use all features of the Hydra CLI, such as multiruns.

    Args:
        overrides: Hydra override strings applied to every benchmark run.
        name: Benchmarks to run; when empty/None, all ADMET-group benchmarks are used.
        inclusive_filter: If False, `name` is treated as an exclusion list instead.
    """
    try:
        from tdc.utils import retrieve_benchmark_names
    except ImportError:
        raise ImportError("TDC needs to be installed to use this CLI. Run `pip install PyTDC`.")

    # Get the benchmarks to run this for.
    # BUG FIX: typer passes `None` when the option is omitted; `len(None)` crashed.
    if not name:
        name = retrieve_benchmark_names("admet_group")
    if not inclusive_filter:
        name = [n for n in retrieve_benchmark_names("admet_group") if n not in name]

    logger.info(f"Running fine-tuning for the following benchmarks: {name}")
    results = {}

    # Use the Compose API to construct the config
    for n in name:
        # BUG FIX: build a fresh override list per benchmark. The previous
        # `overrides += [...]` mutated the shared list, accumulating duplicate
        # `+finetuning=admet` and stale `constants.task=...` entries across
        # iterations, which breaks Hydra composition from the 2nd benchmark on.
        run_overrides = list(overrides) + ["+finetuning=admet", f"constants.task={n}"]
        with initialize(version_base=None, config_path="../../expts/hydra-configs"):
            cfg = compose(
                config_name="main",
                overrides=run_overrides,
            )

        # Run the training loop
        ret = run_training_finetuning_testing(cfg)
        ret = {k: v.item() for k, v in ret.items()}
        results[n] = ret

    # Save to the results_dir by default or to the Hydra output_dir if needed.
    # This distinction is needed, because Hydra's output_dir cannot be remote.
    save_dir = cfg["constants"].get("results_dir", HydraConfig.get()["runtime"]["output_dir"])
    fs.mkdir(save_dir, exist_ok=True)
    path = fs.join(save_dir, "results.yaml")
    logger.info(f"Saving results to {path}")
    with fsspec.open(path, "w") as f:
        yaml.dump(results, f)
@finetune_app.command(name="fingerprint")
def get_fingerprints_from_model(
    fingerprint_layer_spec: List[str],
    pretrained_model: str,
    save_destination: str,
    output_type: str = typer.Option("torch", help="Either numpy (.npy) or torch (.pt) output"),
    overrides: Optional[List[str]] = typer.Option(None, "--override", "-o", help="Hydra overrides"),
):
    """Endpoint for getting fingerprints from a pretrained model.

    The pretrained model should be a `.ckpt` path or pre-specified, named model within Graphium.
    The fingerprint layer specification should be of the format `module:layer`.
    If specified as a list, the fingerprints from all the specified layers will be concatenated.
    See the docs of the `graphium.finetuning.fingerprinting.Fingerprinter` class for more info.
    """
    if overrides is None:
        overrides = []

    with initialize(version_base=None, config_path="../../expts/hydra-configs"):
        cfg = compose(config_name="main", overrides=overrides)
    cfg = OmegaConf.to_container(cfg, resolve=True)

    ## == Instantiate all required objects from their respective configs ==

    # Accelerator
    cfg, accelerator_type = load_accelerator(cfg)

    # Data-module
    datamodule = load_datamodule(cfg, accelerator_type)
    datamodule.prepare_data()
    # The predict_dataloader() returns either predict or test, so we need to run both.
    datamodule.setup("test")
    datamodule.setup("predict")

    # Model
    predictor = PredictorModule.load_pretrained_model(
        pretrained_model,
        device=accelerator_type,
    )

    ## == Fingerprinter
    with Fingerprinter(model=predictor, fingerprint_spec=fingerprint_layer_spec, out_type=output_type) as fp:
        fps = fp.get_fingerprints_for_dataset(datamodule.predict_dataloader())

    fs.mkdir(save_destination, exist_ok=True)

    if output_type == "numpy":
        path = fs.join(save_destination, "fingerprints.npy")
        logger.info(f"Saving fingerprints to {path}")
        # BUG FIX: write through the fsspec file handle rather than the path
        # string. Previously `np.save(path, fps)` ignored the opened handle `f`,
        # so saving to a remote destination silently wrote to a local path.
        with fsspec.open(path, "wb") as f:
            np.save(f, fps)
    else:
        path = fs.join(save_destination, "fingerprints.pt")
        logger.info(f"Saving fingerprints to {path}")
        # NOTE(review): torch.save with a plain path may not support remote
        # destinations; consider routing through fsspec here too — confirm.
        torch.save(fps, path)
def get_tdc_task_specific(task: str, output: Literal["name", "mode", "last_activation"]):
    """Look up a TDC ADMET task-specific setting from the repo's Hydra config files.

    For `output="last_activation"` the task-head architecture config is read;
    otherwise the primary progress-bar metric is read and either its name or
    its optimization mode ("min"/"max") is returned.
    """
    if output == "last_activation":
        arch_path = "expts/hydra-configs/tasks/task_heads/admet.yaml"
        with open(arch_path, "r") as fh:
            arch_cfg = yaml.load(fh, Loader=yaml.FullLoader)
        return arch_cfg["architecture"]["task_heads"][task]["last_activation"]

    metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
    with open(metrics_path, "r") as fh:
        metrics_cfg = yaml.load(fh, Loader=yaml.FullLoader)
    # The first progress-bar metric is treated as the task's primary metric.
    metric = metrics_cfg["predictor"]["metrics_on_progress_bar"][task][0]

    if output == "name":
        return metric
    elif output == "mode":
        # Whether higher or lower values of the metric are better.
        metric_mode_map = {
            "mae": "min",
            "auroc": "max",
            "auprc": "max",
            "spearman": "max",
        }
        return metric_mode_map[metric]
# OmegaConf resolvers so Hydra configs can look up task-specific values,
# e.g. `${get_metric_name:${constants.task}}`.
OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output="name"))
OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output="mode"))
OmegaConf.register_new_resolver(
    "get_last_activation", lambda x: get_tdc_task_specific(x, output="last_activation")
)
# SECURITY NOTE: this resolver runs `eval` on config-provided strings (with
# numpy in scope). Only use with trusted configuration files.
OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))
================================================
FILE: graphium/cli/fingerprints.py
================================================
from .main import app
@app.command(name="fp")
def get_fingerprints_from_model(): ...
================================================
FILE: graphium/cli/main.py
================================================
import typer

# Root typer application; sub-apps (e.g. `data`, `params`, `finetune`) are
# attached onto it from their own modules via `app.add_typer`.
app = typer.Typer(add_completion=False)

# Allow running this module directly as the CLI entry point.
if __name__ == "__main__":
    app()
================================================
FILE: graphium/cli/parameters.py
================================================
import timeit
from typing import List
from omegaconf import DictConfig, OmegaConf
import typer
import numpy as np
from loguru import logger
from hydra import initialize, compose
from .main import app
from graphium.config._loader import (
load_accelerator,
load_architecture,
load_datamodule,
)
from graphium.trainer.predictor_options import ModelOptions
param_app = typer.Typer(help="Parameter counts.")
app.add_typer(param_app, name="params")
@param_app.command(name="infer", help="Infer parameter count.")
def infer_parameter_count(overrides: List[str] = []) -> int:
    """Build the model described by the Hydra config and return its trainable parameter count."""
    with initialize(version_base=None, config_path="../../expts/hydra-configs"):
        composed = compose(config_name="main", overrides=overrides)
    cfg = OmegaConf.to_container(composed, resolve=True)

    ## Accelerator
    cfg, accelerator_type = load_accelerator(cfg)

    ## Datamodule
    datamodule = load_datamodule(cfg, accelerator_type)

    ## Architecture
    model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims)
    options = ModelOptions(model_class=model_class, model_kwargs=model_kwargs)
    model = options.model_class(**options.model_kwargs)

    # Count only trainable parameters.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Number of parameters: {num_params}.")
    return num_params
================================================
FILE: graphium/cli/train_finetune_test.py
================================================
from typing import List, Literal, Union
import os
import time
import timeit
from datetime import datetime
import fsspec
import hydra
import numpy as np
import torch
import wandb
import yaml
from datamol.utils import fs
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode
from lightning.pytorch.utilities.model_summary import ModelSummary
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from graphium.config._loader import (
load_accelerator,
load_architecture,
load_datamodule,
load_metrics,
load_predictor,
load_trainer,
save_params_to_wandb,
get_checkpoint_path,
)
from graphium.finetuning import (
FINETUNING_CONFIG_KEY,
GraphFinetuning,
modify_cfg_for_finetuning,
)
from graphium.hyper_param_search import (
HYPER_PARAM_SEARCH_CONFIG_KEY,
extract_main_metric_for_hparam_search,
)
from graphium.trainer.predictor import PredictorModule
from graphium.utils.safe_run import SafeRun
import graphium.cli.finetune_utils
# Config key flagging an evaluation-only run.
# NOTE(review): presumably consumed inside run_training_finetuning_testing to
# skip the training phase — confirm (its usage is outside this view).
TESTING_ONLY_CONFIG_KEY = "testing_only"


@hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
def cli(cfg: DictConfig) -> None:
    """
    The main CLI endpoint for training, fine-tuning and evaluating Graphium models.

    Thin Hydra wrapper that delegates all work to `run_training_finetuning_testing`.
    """
    return run_training_finetuning_testing(cfg)
def get_replication_factor(cfg):
    """Extract N from a `replicationFactor(N)` entry in `accelerator.ipu_config`.

    Returns 1 when no such entry exists or anything goes wrong while parsing
    (the error is printed, not raised).
    """
    try:
        for entry in cfg.get("accelerator", {}).get("ipu_config", []):
            if "replicationFactor" not in entry:
                continue
            # Pull out the number between the parentheses.
            lparen = entry.find("(") + 1
            rparen = entry.find(")")
            if lparen != 0 and rparen != -1:
                return int(entry[lparen:rparen])
    except Exception as e:
        print(f"An error occurred: {e}")

    # Return default value if replicationFactor is not found or an error occurred
    return 1
def get_gradient_accumulation_factor(cfg):
    """Return `trainer.trainer.accumulate_grad_batches` as an int (default 1).

    WARNING: This MUST be called after accelerator overrides have been applied
    (i.e. after `load_accelerator` has been called)
    """
    try:
        trainer_cfg = cfg.get("trainer", {}).get("trainer", {})
        # Coerce to int so string-valued config entries still work.
        return int(trainer_cfg.get("accumulate_grad_batches", 1))
    except Exception as e:
        print(f"An error occurred: {e}")
        # Fall back to the default when parsing fails.
        return 1
def get_training_batch_size(cfg):
    """
    Get the micro batch size from `datamodule.args.batch_size_training`.

    WARNING: This MUST be called after accelerator overrides have been applied
    (i.e. after `load_accelerator` has been called)

    Parameters:
        cfg: The (resolved) config dictionary.

    Returns:
        The training batch size as an int, or 1 when it is absent or malformed.
    """
    try:
        # Navigate through the nested dictionaries and coerce the value to an integer
        batch_size_training = cfg.get("datamodule", {}).get("args", {}).get("batch_size_training", 1)
        return int(batch_size_training)
    except Exception as e:
        # Use the module logger for consistency with the rest of the file
        logger.warning(f"Error parsing the training batch size: {e}")
        # Fall back to a batch size of 1
        return 1
def get_training_device_iterations(cfg):
    """
    Extract the IPU device-iterations count from the accelerator config.

    Scans the `accelerator.ipu_config` list for an entry such as
    ``"deviceIterations(30)"`` and returns the integer between the parentheses.

    Parameters:
        cfg: The (resolved) config dictionary.

    Returns:
        The device iterations as an int, or 1 when it is absent or malformed.
    """
    try:
        ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
        for item in ipu_config:
            if "deviceIterations" in item:
                # Extract the number between the parentheses; search for the
                # closing paren *after* the opening one so an earlier ")" is not matched
                start = item.find("(") + 1
                end = item.find(")", start)
                if start != 0 and end != -1:
                    return int(item[start:end])
    except Exception as e:
        # Use the module logger for consistency with the rest of the file
        logger.warning(f"Error parsing the device iterations from the IPU config: {e}")
    # Return default value if deviceIterations is not found or an error occurred
    return 1
def run_training_finetuning_testing(cfg: DictConfig) -> None:
    """
    The main (pre-)training and fine-tuning loop.

    Resolves the Hydra config, instantiates the datamodule / model / trainer,
    optionally fine-tunes (when the finetuning key is present) or only tests a
    pretrained model (when `testing_only` is set), runs training then testing,
    dumps metrics to YAML in the Hydra output dir, and optionally copies that
    directory to a remote `constants.results_dir`.

    NOTE(review): annotated as returning None, but it returns `results`
    (consumed by the Hydra sweeper in multi-run hyper-parameter searches).
    """
    unresolved_cfg = OmegaConf.to_container(cfg, resolve=False)
    cfg = OmegaConf.to_container(cfg, resolve=True)

    # Get the current date and time
    now = datetime.now()
    # Format the datetime as a string
    filename_datetime_suffix = now.strftime("%Y%m%d_%H%M%S")
    # Append the datetime string to the existing filename in the cfg dictionary
    cfg["trainer"]["model_checkpoint"]["filename"] += f"_{filename_datetime_suffix}"
    # NOTE(review): `[:-1]` drops the last character of `dirpath` — presumably a
    # trailing "/" — before appending the suffix; confirm dirpath always ends with one.
    cfg["trainer"]["model_checkpoint"]["dirpath"] = (
        cfg["trainer"]["model_checkpoint"]["dirpath"][:-1] + f"_{filename_datetime_suffix}"
    )

    dst_dir = cfg["constants"].get("results_dir")
    hydra_cfg = HydraConfig.get()
    output_dir = hydra_cfg["runtime"]["output_dir"]

    # Warn early when the remote destination already has content, since the
    # final copy at the end of the run would then fail.
    if dst_dir is not None and fs.exists(dst_dir) and len(fs.get_mapper(dst_dir).fs.ls(dst_dir)) > 0:
        logger.warning(
            "The destination directory is not empty. "
            "If files already exist, this would lead to a crash at the end of training."
        )
        # We pause here briefly, to make sure the notification is seen as there's lots of logs afterwards
        time.sleep(5)

    # Modify the config for finetuning
    if FINETUNING_CONFIG_KEY in cfg:
        cfg = modify_cfg_for_finetuning(cfg)

    # Wall-clock start, reported at the end of the run
    st = timeit.default_timer()

    # Initialize wandb only on first rank
    if os.environ.get("RANK", "0") == "0":
        # Disable wandb if the user is not logged in.
        wandb_cfg = cfg["constants"].get("wandb")
        if wandb_cfg is not None and wandb.login() is False:
            logger.info(
                "Not logged in to wandb - disabling wandb logging.\n"
                + "To enable wandb, run `wandb login` from the command line."
            )
            wandb.init(mode="disabled")
        elif wandb_cfg is not None:
            wandb.init(config=cfg, **wandb_cfg)
    else:
        # Non-zero ranks never log to wandb
        wandb_cfg = None

    ## == Instantiate all required objects from their respective configs ==
    # Accelerator (may also apply `config_override` entries to `cfg`)
    cfg, accelerator_type = load_accelerator(cfg)

    ## Data-module
    datamodule = load_datamodule(cfg, accelerator_type)
    datamodule.prepare_data()

    testing_only = cfg.get(TESTING_ONLY_CONFIG_KEY, False)

    if testing_only:
        # Load pre-trained model
        predictor = PredictorModule.load_pretrained_model(
            name_or_path=get_checkpoint_path(cfg), device=accelerator_type
        )
    else:
        ## Architecture
        model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims)

        ## Metrics
        metrics = load_metrics(cfg)

        # Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
        replicas = get_replication_factor(cfg)
        gradient_acc = get_gradient_accumulation_factor(cfg)
        micro_bs = get_training_batch_size(cfg)
        device_iterations = get_training_device_iterations(cfg)

        # Effective global batch size across replicas / accumulation / device iterations
        global_bs = replicas * gradient_acc * micro_bs * device_iterations

        ## Predictor
        predictor = load_predictor(
            config=cfg,
            model_class=model_class,
            model_kwargs=model_kwargs,
            metrics=metrics,
            task_levels=datamodule.get_task_levels(),
            accelerator_type=accelerator_type,
            featurization=datamodule.featurization,
            task_norms=datamodule.task_norms,
            replicas=replicas,
            gradient_acc=gradient_acc,
            global_bs=global_bs,
        )

    logger.info(predictor.model)
    logger.info(ModelSummary(predictor, max_depth=4))

    ## Trainer
    date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
    trainer = load_trainer(cfg, accelerator_type, date_time_suffix)

    if not testing_only:
        # Add the fine-tuning callback to trainer
        if FINETUNING_CONFIG_KEY in cfg:
            finetuning_training_kwargs = cfg["finetuning"]["training_kwargs"]
            trainer.callbacks.append(GraphFinetuning(**finetuning_training_kwargs))

        if wandb_cfg is not None:
            save_params_to_wandb(trainer.logger, cfg, predictor, datamodule, unresolved_config=unresolved_cfg)

        # Determine the max num nodes and edges in training and validation
        logger.info("Computing the maximum number of nodes and edges per graph")
        predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])

        # When resuming training from a checkpoint, we need to provide the path to the checkpoint in the config
        resume_ckpt_path = cfg["trainer"].get("resume_from_checkpoint", None)

        # Run the model training
        with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
            trainer.fit(model=predictor, datamodule=datamodule, ckpt_path=resume_ckpt_path)

        # Save validation metrics - Base utility in case someone doesn't use a logger.
        results = trainer.callback_metrics
        results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}
        with fsspec.open(fs.join(output_dir, "val_results.yaml"), "w") as f:
            yaml.dump(results, f)

    # Determine the max num nodes and edges in testing
    predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"])

    # When checkpoints are logged during training, we can, e.g., use the best or last checkpoint for testing
    test_ckpt_path = None
    test_ckpt_name = cfg["trainer"].get("test_from_checkpoint", None)
    test_ckpt_dir = cfg["trainer"]["model_checkpoint"].get("dirpath", None)
    if test_ckpt_name is not None and test_ckpt_dir is not None:
        test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name)

    # Run the model testing
    with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
        trainer.test(model=predictor, datamodule=datamodule, ckpt_path=test_ckpt_path)

    logger.info("-" * 50)
    # NOTE(review): loguru ignores extra positional args when the message has no
    # "{}" placeholder, so the elapsed time is likely never printed here —
    # consider an f-string.
    logger.info("Total compute time:", timeit.default_timer() - st)
    logger.info("-" * 50)

    save_checkpoint_to_wandb = cfg["trainer"].get("save_checkpoint_to_wandb")
    if save_checkpoint_to_wandb is True:
        # Save initial model state - and upload checkpoint to wandb
        if cfg["trainer"]["model_checkpoint"]["save_last"] is True:
            checkpoint_path = f"{cfg['trainer']['model_checkpoint']['dirpath']}/{cfg['trainer']['model_checkpoint']['filename']}-v1.ckpt"
            # Log the initial model checkpoint to wandb
            wandb.save(checkpoint_path)

    wandb.finish()

    # Save test metrics - Base utility in case someone doesn't use a logger.
    results = trainer.callback_metrics
    results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}
    with fsspec.open(fs.join(output_dir, "test_results.yaml"), "w") as f:
        yaml.dump(results, f)

    # When part of of a hyper-parameter search, we are very specific about how we save our results
    # NOTE (cwognum): We also check if the we are in multi-run mode, as the sweeper is otherwise not active.
    if HYPER_PARAM_SEARCH_CONFIG_KEY in cfg and hydra_cfg.mode == RunMode.MULTIRUN:
        results = extract_main_metric_for_hparam_search(results, cfg[HYPER_PARAM_SEARCH_CONFIG_KEY])

    # Copy the current working directory to remote
    # By default, processes should just write results to Hydra's output directory.
    # However, this currently does not support remote storage, which is why we copy the results here if needed.
    # For more info, see also: https://github.com/facebookresearch/hydra/issues/993
    if dst_dir is not None:
        src_dir = hydra_cfg["runtime"]["output_dir"]
        dst_dir = fs.join(dst_dir, fs.get_basename(src_dir))
        fs.mkdir(dst_dir, exist_ok=True)
        fs.copy_dir(src_dir, dst_dir)

    return results
# Allow running this module directly as a script
if __name__ == "__main__":
    cli()
================================================
FILE: graphium/config/README.md
================================================
The Graph Of LIfe Library.
## What is in this folder?
- `_load.py`: loading from a `yaml` config file; `config_convert.py`: recursively converting `omegaconf` containers to plain `dict`/`list` for dumping
- `_loader.py`: retrieve relevant arguments from the config and loading different modules with the correct arguments
================================================
FILE: graphium/config/__init__.py
================================================
from ._load import load_config
from ._loader import load_architecture
from ._loader import load_datamodule
from ._loader import load_metrics
from ._loader import load_predictor
from ._loader import load_trainer
from ._loader import save_params_to_wandb
from ._loader import load_accelerator
================================================
FILE: graphium/config/_load.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import importlib.resources
import omegaconf
def load_config(name: str):
    """Load a default config file by its name.

    The file is looked up as ``graphium/config/<name>.yaml`` inside the
    installed package.

    Args:
        name: name of the config to load (without the ``.yaml`` extension).

    Returns:
        The parsed ``omegaconf`` config object.
    """
    # `importlib.resources.files` replaces `open_text`, which is deprecated
    # since Python 3.11.
    resource = importlib.resources.files("graphium.config").joinpath(f"{name}.yaml")
    with resource.open("r") as f:
        config = omegaconf.OmegaConf.load(f)

    return config
================================================
FILE: graphium/config/_loader.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
# Misc
import os
from copy import deepcopy
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union
import joblib
import mup
import omegaconf
# Torch
import torch
import yaml
# Lightning
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import Logger, WandbLogger
from loguru import logger
from graphium.data.datamodule import BaseDataModule, MultitaskFromSmilesDataModule
from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork
from graphium.ipu.ipu_dataloader import IPUDataloaderOptions
from graphium.ipu.ipu_utils import import_poptorch, load_ipu_options
from graphium.nn.architectures import FullGraphMultiTaskNetwork
from graphium.nn.utils import MupMixin
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor import PredictorModule
from graphium.utils.command_line_utils import get_anchors_and_aliases, update_config
# Graphium
from graphium.utils.mup import set_base_shapes
from graphium.utils.spaces import DATAMODULE_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT
from graphium.utils import fs
def get_accelerator(
    config_acc: Union[omegaconf.DictConfig, Dict[str, Any]],
) -> str:
    """
    Get the accelerator type from the config and ensure the corresponding
    hardware is consistent with what is available on this machine.

    Parameters:
        config_acc: The `accelerator` section of the config, with key `type`.

    Returns:
        The accelerator type, e.g. "cpu", "gpu" or "ipu". Defaults to "cpu"
        when `type` is None.

    Raises:
        ValueError: If GPUs or IPUs are requested but not available.
    """
    # Get the accelerator type
    accelerator_type = config_acc["type"]

    # Get the GPU info
    if (accelerator_type == "gpu") and (not torch.cuda.is_available()):
        raise ValueError("GPUs selected, but GPUs are not available on this device")

    # Get the IPU info
    if accelerator_type == "ipu":
        poptorch = import_poptorch()
        if poptorch is None:
            raise ValueError("IPUs selected, but PopTorch is not available")
        if not poptorch.ipuHardwareIsAvailable():
            raise ValueError(
                "IPUs selected, but no IPU is available/visible on this device. "
                "If you do have IPUs, please check that the IPUOF_VIPU_API_PARTITION_ID and "
                "IPUOF_VIPU_API_HOST environment variables are set."
            )

    # Fall back on cpu at the end
    if accelerator_type is None:
        accelerator_type = "cpu"
    return accelerator_type
def _get_ipu_opts(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Tuple[str, str]:
r"""
Get the paths of the IPU-specific config files from the main YAML config
"""
accelerator_options = config["accelerator"]
accelerator_type = accelerator_options["type"]
if accelerator_type != "ipu":
return None, None
ipu_opts = accelerator_options["ipu_config"]
ipu_inference_opts = accelerator_options.get("ipu_inference_config", None)
return ipu_opts, ipu_inference_opts
def load_datamodule(
    config: Union[omegaconf.DictConfig, Dict[str, Any]], accelerator_type: str
) -> BaseDataModule:
    """
    Load the datamodule from the specified configurations at the key
    `datamodule: args`.
    If the accelerator is IPU, load the IPU options as well.

    Parameters:
        config: The config file, with key `datamodule: args`
        accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
    Returns:
        datamodule: The datamodule used to process and load the data
    """
    cfg_data = config["datamodule"]["args"]

    # Instanciate the datamodule
    module_class = DATAMODULE_DICT[config["datamodule"]["module_type"]]

    # Non-IPU accelerators need no extra dataloader options
    if accelerator_type != "ipu":
        datamodule = module_class(
            **config["datamodule"]["args"],
        )
        return datamodule

    # IPU specific adjustments
    else:
        ipu_opts, ipu_inference_opts = _get_ipu_opts(config)

        # Default empty values for the IPU configurations
        ipu_training_opts = None

        # NOTE(review): `pop` mutates `datamodule.args` in-place, so these keys
        # are consumed here and not forwarded to the datamodule constructor below.
        ipu_dataloader_training_opts = cfg_data.pop("ipu_dataloader_training_opts", {})
        ipu_dataloader_inference_opts = cfg_data.pop("ipu_dataloader_inference_opts", {})

        # Replace the IPU config-file paths by the fully-loaded IPU options
        ipu_training_opts, ipu_inference_opts = load_ipu_options(
            ipu_opts=ipu_opts,
            seed=config["constants"]["seed"],
            model_name=config["constants"]["name"],
            gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None),
            ipu_inference_opts=ipu_inference_opts,
            precision=config["trainer"]["trainer"].get("precision"),
        )

        # Define the Dataloader options for the IPU on the training sets
        bz_train = cfg_data["batch_size_training"]
        ipu_dataloader_training_opts = IPUDataloaderOptions(
            batch_size=bz_train, **ipu_dataloader_training_opts
        )
        ipu_dataloader_training_opts.set_kwargs()

        # Define the Dataloader options for the IPU on the inference sets
        bz_test = cfg_data["batch_size_inference"]
        ipu_dataloader_inference_opts = IPUDataloaderOptions(
            batch_size=bz_test, **ipu_dataloader_inference_opts
        )
        ipu_dataloader_inference_opts.set_kwargs()

        datamodule = module_class(
            ipu_training_opts=ipu_training_opts,
            ipu_inference_opts=ipu_inference_opts,
            ipu_dataloader_training_opts=ipu_dataloader_training_opts,
            ipu_dataloader_inference_opts=ipu_dataloader_inference_opts,
            **config["datamodule"]["args"],
        )

        return datamodule
def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Dict[str, MetricWrapper]:
    """
    Loading the metrics to be tracked.

    Parameters:
        config: The config file, with key `metrics`, mapping each task name to
            a list of metric specifications. Each specification must contain a
            `name` key; the remaining keys are forwarded to `MetricWrapper`.

    Returns:
        metrics: A dictionary of all the metrics, keyed by task then metric name
    """
    task_metrics = {}

    cfg_metrics = config.get("metrics", None)
    if cfg_metrics is None:
        return task_metrics

    # Deep-copy so that popping `name` below does not mutate the caller's config
    cfg_metrics = {key: deepcopy(value) for key, value in cfg_metrics.items()}

    # Wrap every metric in the class `MetricWrapper` to standardize them
    for task in cfg_metrics:
        task_metrics[task] = {}
        if cfg_metrics[task] is None:
            cfg_metrics[task] = []
        for this_metric in cfg_metrics[task]:
            name = this_metric.pop("name")
            task_metrics[task][name] = MetricWrapper(**this_metric)

    return task_metrics
def load_architecture(
    config: Union[omegaconf.DictConfig, Dict[str, Any]],
    in_dims: Dict[str, int],
) -> Union[FullGraphMultiTaskNetwork, torch.nn.Module]:
    """
    Loading the architecture used for training.

    Parameters:
        config: The config file, with key `architecture`
        in_dims: Dictionary of the input dimensions (node/edge features and
            positional encodings), provided by the datamodule
    Returns:
        A tuple ``(model_class, model_kwargs)``: the selected architecture
        class and the keyword arguments needed to instantiate it.
    """
    # Finetuning configs stay as plain dicts; everything else becomes a DictConfig
    if isinstance(config, dict) and "finetuning" not in config:
        config = omegaconf.OmegaConf.create(config)
    cfg_arch = config["architecture"]

    # Select the architecture
    model_type = cfg_arch["model_type"].lower()
    if model_type == "fullgraphmultitasknetwork":
        model_class = FullGraphMultiTaskNetwork
    elif model_type == "fullgraphfinetuningnetwork":
        model_class = FullGraphFinetuningNetwork
    else:
        raise ValueError(f"Unsupported model_type=`{model_type}`")

    # Prepare the various kwargs
    pe_encoders_kwargs = (
        dict(cfg_arch["pe_encoders"]) if cfg_arch.get("pe_encoders", None) is not None else None
    )
    pre_nn_kwargs = dict(cfg_arch["pre_nn"]) if cfg_arch["pre_nn"] is not None else None
    pre_nn_edges_kwargs = dict(cfg_arch["pre_nn_edges"]) if cfg_arch["pre_nn_edges"] is not None else None
    gnn_kwargs = dict(cfg_arch["gnn"])
    graph_output_nn_kwargs = (
        dict(cfg_arch["graph_output_nn"]) if cfg_arch["graph_output_nn"] is not None else None
    )
    task_heads_kwargs = (
        cfg_arch["task_heads"] if cfg_arch["task_heads"] is not None else None
    )  # This is of type ListConfig containing TaskHeadParams

    # Initialize the input dimension for the positional encoders
    if pe_encoders_kwargs is not None:
        pe_encoders_kwargs = dict(pe_encoders_kwargs)
        for encoder in pe_encoders_kwargs["encoders"]:
            pe_encoders_kwargs["encoders"][encoder] = dict(pe_encoders_kwargs["encoders"][encoder])
        pe_encoders_kwargs.setdefault(
            "in_dims", in_dims
        )  # set the input dimensions of all pe with info from the data-module

    # The PE encoders' output dims are added on top of the raw feature dims
    pe_out_dim = 0 if pe_encoders_kwargs is None else pe_encoders_kwargs.get("out_dim", None)
    edge_pe_out_dim = 0 if pe_encoders_kwargs is None else pe_encoders_kwargs.get("edge_out_dim", None)

    # Set the default `node` input dimension for the pre-processing neural net and graph neural net
    in_dim = in_dims["feat"]
    if pe_out_dim is not None:
        in_dim += pe_out_dim
    if pre_nn_kwargs is not None:
        pre_nn_kwargs = dict(pre_nn_kwargs)
        pre_nn_kwargs.setdefault("in_dim", in_dim)
    else:
        gnn_kwargs.setdefault("in_dim", in_dim)

    # Set the default `edge` input dimension for the pre-processing neural net and graph neural net
    edge_in_dim = in_dims["edge_feat"]
    if edge_pe_out_dim is not None:
        edge_in_dim += edge_pe_out_dim
    if pre_nn_edges_kwargs is not None:
        pre_nn_edges_kwargs = dict(pre_nn_edges_kwargs)
        pre_nn_edges_kwargs.setdefault("in_dim", edge_in_dim)
    else:
        # NOTE(review): this sets the same "in_dim" key as the node branch above;
        # with `setdefault` it is a no-op whenever the node dim was already set.
        # Possibly "in_dim_edges" was intended — verify against the GNN layer kwargs.
        gnn_kwargs.setdefault("in_dim", edge_in_dim)

    # Set the parameters for the full network
    if "finetuning" not in config:
        task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs)

    # Set all the input arguments for the model
    model_kwargs = dict(
        gnn_kwargs=gnn_kwargs,
        pre_nn_kwargs=pre_nn_kwargs,
        pre_nn_edges_kwargs=pre_nn_edges_kwargs,
        pe_encoders_kwargs=pe_encoders_kwargs,
        graph_output_nn_kwargs=graph_output_nn_kwargs,
        task_heads_kwargs=task_heads_kwargs,
    )

    # Get accelerator_kwargs if they exist
    accelerator_kwargs = config["accelerator"].get("accelerator_kwargs", None)
    if accelerator_kwargs is not None:
        model_kwargs["accelerator_kwargs"] = accelerator_kwargs

    # For fine-tuning, wrap the pretrained-model kwargs together with the
    # finetuning-specific ones. NOTE(review): the `pop` calls mutate
    # `config["finetuning"]` in-place.
    if model_class is FullGraphFinetuningNetwork:
        finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None)
        pretrained_overwriting_kwargs = config["finetuning"].pop("overwriting_kwargs")
        pretrained_model = pretrained_overwriting_kwargs.pop("pretrained_model")

        model_kwargs = {
            "pretrained_model_kwargs": deepcopy(model_kwargs),
            "pretrained_overwriting_kwargs": pretrained_overwriting_kwargs,
            "pretrained_model": pretrained_model,
            "finetuning_head_kwargs": finetuning_head_kwargs,
        }

    return model_class, model_kwargs
def load_predictor(
    config: Union[omegaconf.DictConfig, Dict[str, Any]],
    model_class: Type[torch.nn.Module],
    model_kwargs: Dict[str, Any],
    metrics: Dict[str, MetricWrapper],
    task_levels: Dict[str, str],
    accelerator_type: str,
    featurization: Dict[str, str] = None,
    task_norms: Optional[Dict[Callable, Any]] = None,
    replicas: int = 1,
    gradient_acc: int = 1,
    global_bs: int = 1,
) -> PredictorModule:
    """
    Defining the predictor module, which handles the training logic from `lightning.LighningModule`

    Parameters:
        config: The config file, with key `predictor`
        model_class: The torch Module containing the main forward function
        model_kwargs: The keyword arguments used to instantiate `model_class`
        metrics: The per-task metrics, as returned by `load_metrics`
        task_levels: Mapping from task name to its level
        accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
        featurization: The featurization arguments taken from the datamodule
        task_norms: The per-task normalization taken from the datamodule
        replicas: Number of data-parallel replicas
        gradient_acc: Gradient-accumulation factor
        global_bs: Effective global batch size
    Returns:
        predictor: The predictor module
    """
    # IPU runs need the IPU-specific predictor subclass (imported lazily so
    # non-IPU environments don't require poptorch)
    if accelerator_type == "ipu":
        from graphium.ipu.ipu_wrapper import PredictorModuleIPU

        predictor_class = PredictorModuleIPU
    else:
        predictor_class = PredictorModule

    cfg_pred = dict(deepcopy(config["predictor"]))
    predictor = predictor_class(
        model_class=model_class,
        model_kwargs=model_kwargs,
        metrics=metrics,
        task_levels=task_levels,
        featurization=featurization,
        task_norms=task_norms,
        replicas=replicas,
        gradient_acc=gradient_acc,
        global_bs=global_bs,
        **cfg_pred,
    )

    # mup scaling: rebuild the predictor with width-scaled kwargs when requested.
    # NOTE(review): `pop` mutates `config["architecture"]` in-place.
    mup_scale_factor = config["architecture"].pop("mup_scale_factor", None)

    if mup_scale_factor is not None and mup_scale_factor != 1:
        unscaled_model = predictor.model
        scaled_model_kwargs = unscaled_model.scale_kwargs(scale_factor=mup_scale_factor)
        del predictor
        predictor = predictor_class(
            model_class=model_class,
            model_kwargs=scaled_model_kwargs,
            metrics=metrics,
            task_levels=task_levels,
            featurization=featurization,
            task_norms=task_norms,
            replicas=replicas,
            gradient_acc=gradient_acc,
            global_bs=global_bs,
            **cfg_pred,
        )

    # mup base shapes
    mup_base_path = config["architecture"].pop("mup_base_path", None)
    predictor = load_mup(mup_base_path, predictor)

    return predictor
def load_mup(mup_base_path: str, predictor: PredictorModule) -> PredictorModule:
    """
    Load the base shapes for the mup, based either on a `.ckpt` or `.yaml` file.
    If `.yaml`, it should be generated by `mup.save_base_shapes`

    Parameters:
        mup_base_path: Path to the base-shapes file, or None to derive the
            base model by shrinking the current model's kwargs by a factor of 2.
        predictor: The predictor whose model gets its base shapes set.
    Returns:
        The same predictor, with `predictor.model` re-based via `set_base_shapes`.

    Raises:
        TypeError: If the model does not implement `MupMixin`.
        ValueError: If the file extension is neither `.ckpt` nor `.yaml`.
    """
    model = predictor.model

    if not isinstance(model, MupMixin):
        raise TypeError("load_mup can only be applied to models that use the MupMixin")

    if mup_base_path is None:
        # Build a smaller base model from the current model's own kwargs
        base = model.__class__(**model.make_mup_base_kwargs(divide_factor=2))
    elif mup_base_path.endswith(".ckpt"):
        base = predictor.__class__.load_from_checkpoint(mup_base_path, map_location="cpu")
    elif mup_base_path.endswith(".yaml"):
        # `set_base_shapes` accepts the yaml path directly
        base = mup_base_path
    else:
        raise ValueError(f"Unrecognized file type {mup_base_path}")
    predictor.model = set_base_shapes(predictor.model, base, rescale_params=False)
    return predictor
def load_trainer(
    config: Union[omegaconf.DictConfig, Dict[str, Any]],
    accelerator_type: str,
    date_time_suffix: str = "",
) -> Trainer:
    """
    Defining the pytorch-lightning Trainer module.

    Parameters:
        config: The config file, with key `trainer`
        accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
        date_time_suffix: The date and time of the current run. To be used for logging.
    Returns:
        trainer: the trainer module
    """
    cfg_trainer = deepcopy(config["trainer"])

    # Define the IPU plugin if required
    strategy = cfg_trainer["trainer"].pop("strategy", "auto")
    if accelerator_type == "ipu":
        ipu_opts, ipu_inference_opts = _get_ipu_opts(config)
        training_opts, inference_opts = load_ipu_options(
            ipu_opts=ipu_opts,
            ipu_inference_opts=ipu_inference_opts,
            seed=config["constants"]["seed"],
            model_name=config["constants"]["name"],
            gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None),
            precision=config["trainer"]["trainer"].get("precision"),
        )
        # The IPU strategy is built from the poptorch options, so a
        # user-provided strategy would conflict with it
        if strategy != "auto":
            raise ValueError("IPUs selected, but strategy is not set to 'auto'")
        from lightning_graphcore import IPUStrategy

        strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)

    # Get devices
    devices = cfg_trainer["trainer"].pop("devices", 1)
    if accelerator_type == "ipu":
        devices = 1  # number of IPUs used is defined in the ipu options files

    # Remove the gradient accumulation from IPUs, since it's handled by the device
    if accelerator_type == "ipu":
        cfg_trainer["trainer"].pop("accumulate_grad_batches", None)

    # Define the early stopping parameters
    trainer_kwargs = {}
    callbacks = []
    if "early_stopping" in cfg_trainer.keys():
        callbacks.append(EarlyStopping(**cfg_trainer["early_stopping"]))

    # Define the model checkpointing parameters
    if "model_checkpoint" in cfg_trainer.keys():
        callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"]))

    # Always track the learning rate, with user-provided kwargs when present
    if "learning_rate_monitor" in cfg_trainer.keys():
        callbacks.append(LearningRateMonitor(**cfg_trainer["learning_rate_monitor"]))
    else:
        callbacks.append(LearningRateMonitor())

    # Define the logger parameters
    # NOTE(review): `wandb_cfg.pop` mutates the caller's config["constants"]["wandb"]
    # since only the `trainer` section was deep-copied above — confirm intended.
    wandb_cfg = config["constants"].get("wandb")
    if wandb_cfg is not None:
        name = wandb_cfg.pop("name", "main")
        if len(date_time_suffix) > 0:
            name += f"_{date_time_suffix}"
        trainer_kwargs["logger"] = WandbLogger(name=name, log_model=True, **wandb_cfg)

    trainer_kwargs["callbacks"] = callbacks
    trainer = Trainer(
        detect_anomaly=True,
        strategy=strategy,
        accelerator=accelerator_type,
        devices=devices,
        **cfg_trainer["trainer"],
        **trainer_kwargs,
    )

    return trainer
def save_params_to_wandb(
    logger: Logger,
    config: Union[omegaconf.DictConfig, Dict[str, Any]],
    predictor: PredictorModule,
    datamodule: MultitaskFromSmilesDataModule,
    unresolved_config: Optional[Union[omegaconf.DictConfig, Dict[str, Any]]] = None,
):
    """
    Save a few stuff to weights-and-biases WandB

    Saves the mup base shapes, the resolved (and optionally unresolved) configs
    as YAML, and the pickled featurizer into the wandb run directory, then
    registers those files with the wandb run.

    Parameters:
        logger: The object used to log the training. Usually WandbLogger
        config: The config file, with key `trainer`
        predictor: The predictor used to handle the train/val/test steps logic
        datamodule: The datamodule used to load the data into training
        unresolved_config: The unresolved config file
    """
    # Get the wandb runner and directory
    wandb_run = logger.experiment
    if wandb_run is None:
        # No active run: fall back to the current working directory
        wandb_dir = ""
    else:
        wandb_dir = wandb_run.dir

    # Save the mup base model to WandB as a yaml file
    mup.save_base_shapes(predictor.model, os.path.join(wandb_dir, "mup_base_params.yaml"))

    # Save the full configs as a YAML file
    with open(os.path.join(wandb_dir, "full_configs.yaml"), "w") as file:
        yaml.dump(config, file)

    if unresolved_config is not None:
        with open(os.path.join(wandb_dir, "unresolved_config.yaml"), "w") as file:
            yaml.dump(unresolved_config, file)

    # Save the featurizer into wandb
    featurizer_path = os.path.join(wandb_dir, "featurizer.pickle")
    joblib.dump(datamodule.smiles_transformer, featurizer_path)

    # Save the featurizer and configs into wandb
    if wandb_run is not None:
        wandb_run.save(os.path.join(wandb_dir, "*.yaml"), wandb_dir)
        wandb_run.save(os.path.join(wandb_dir, "*.pickle"), wandb_dir)
def load_accelerator(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Tuple[Dict[str, Any], str]:
    """
    Apply the accelerator-specific `config_override` entries to a copy of the
    config and resolve the accelerator type.

    Returns:
        A tuple ``(config, accelerator_type)`` with the updated config copy.
    """
    config = deepcopy(config)
    accelerator_section = config.get("accelerator", {})

    # Fold the accelerator's `config_override` into the main config (in-place)
    merge_dicts(config, accelerator_section.get("config_override", {}))

    accelerator_type = get_accelerator(accelerator_section)

    # Optionally tune the float32 matmul precision on GPU
    if accelerator_type == "gpu":
        precision = accelerator_section.get("float32_matmul_precision", None)
        if precision is not None:
            torch.set_float32_matmul_precision(precision)

    return config, accelerator_type
def load_config_override(
    config: Union[omegaconf.DictConfig, Dict[str, Any]], main_dir: Optional[Union[str, os.PathLike]] = None
) -> Dict[str, Any]:
    """
    Merge an optional override file, pointed to by `constants.config_override`,
    underneath the given config (the given config wins on conflicting keys).
    """
    config = deepcopy(config)
    override_path = config["constants"].get("config_override", None)
    if override_path is None:
        return config

    # Resolve the override path relative to the project's main directory
    if main_dir is not None:
        override_path = os.path.join(main_dir, override_path)

    with open(override_path, "r") as f:
        cfg_override = yaml.safe_load(f)

    # Merging the given config on top means its values take precedence
    return merge_dicts(cfg_override, config, on_exist="overwrite")
def load_yaml_config(
    config_path: Union[str, os.PathLike],
    main_dir: Optional[Union[str, os.PathLike]] = None,
    unknown_args=None,
) -> Dict[str, Any]:
    """
    Load a YAML config file and return it as a dictionary.
    Also returns the anchors `&` and aliases `*` of the YAML file.
    Then, update the config with the unknown arguments.
    Finally, update the config with the config override file specified in `constants.config_override`.

    Parameters:
        config_path: The path to the YAML config file
        main_dir: The main directory of the project. If specified, the config override file will be loaded from this directory
        unknown_args: The unknown arguments to update the config with, taken from `argparse.parse_known_args`
    Returns:
        config: The config dictionary
    """
    # Resolve the config path relative to the project's main directory
    if main_dir is not None:
        config_path = os.path.join(main_dir, config_path)
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Track the YAML anchors/aliases so CLI overrides can update every alias
    refs = get_anchors_and_aliases(config_path)
    if unknown_args is not None:
        config = update_config(config, unknown_args, refs)
    config = load_config_override(config, main_dir)  # This goes here to avoid overriding the hparam search

    return config
def merge_dicts(
    dict_a: Dict[str, Any], dict_b: Dict[str, Any], previous_dict_path: str = "", on_exist: str = "raise"
) -> Dict[str, Any]:
    """
    Recursively merges dict_b into dict_a. If a key is missing from dict_a,
    it is added from dict_b. If a key exists in both with differing values,
    the behaviour depends on `on_exist`. `dict_a` is modified in-place and
    also returned.

    Parameters:
        dict_a: The dictionary to merge into. Modified in-place.
        dict_b: The dictionary to merge from.
        previous_dict_path: The key path of the parent dictionary,
            used to track the recursive calls.
        on_exist: What to do if a key already exists in dict_a. Options are "raise", "overwrite", "ignore".

    Returns:
        dict_a, after the merge.

    Raises:
        ValueError: If `on_exist="raise"` and a key path exists in both dicts
            with differing non-dict values.
    """
    assert on_exist in [
        "raise",
        "overwrite",
        "ignore",
    ], f"on_exist must be one of ['raise', 'overwrite', 'ignore'], got {on_exist}"

    for key, value_b in dict_b.items():
        if key not in dict_a:
            dict_a[key] = value_b
            continue

        value_a = dict_a[key]
        # Build the path for this key only. The previous implementation mutated
        # `previous_dict_path` across loop iterations, so error messages for
        # later keys reported paths accumulated from earlier sibling keys.
        if previous_dict_path == "":
            current_path = key
        else:
            current_path = f"{previous_dict_path}/{key}"

        if isinstance(value_a, dict) and isinstance(value_b, dict):
            merge_dicts(value_a, value_b, previous_dict_path=current_path, on_exist=on_exist)
        elif value_a != value_b:
            if on_exist == "raise":
                raise ValueError(f"Dict path already exists: {current_path}")
            elif on_exist == "overwrite":
                dict_a[key] = value_b
            # on_exist == "ignore": keep dict_a's value

    return dict_a
def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> str:
    """
    Get the checkpoint path from a config file.

    A known pretrained-model name or an existing path is returned unchanged;
    anything else is resolved inside the configured checkpointing directory
    and must exist there.
    """
    cfg_trainer = config["trainer"]

    path = config.get("ckpt_name_for_testing", "last.ckpt")
    if path in GRAPHIUM_PRETRAINED_MODELS_DICT or fs.exists(path):
        return path

    # Resolve relative to the checkpointing directory, when configured
    if "model_checkpoint" in cfg_trainer.keys():
        path = fs.join(cfg_trainer["model_checkpoint"]["dirpath"], path)

    if not fs.exists(path):
        raise ValueError(f"Checkpoint path `{path}` does not exist")

    return path
================================================
FILE: graphium/config/config_convert.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import omegaconf
def recursive_config_reformating(configs):
    r"""
    Recursively convert a configuration object into plain Python types.

    All `DictConfig` are turned into `dict`, all `ListConfig` into `list`,
    and all `byte` values into `str`. This helps avoid errors when dumping
    a yaml file.
    """
    # Shallow conversion of OmegaConf containers; nested containers are
    # handled by the recursion below.
    if isinstance(configs, omegaconf.DictConfig):
        configs = dict(configs)
    elif isinstance(configs, omegaconf.ListConfig):
        configs = list(configs)

    if isinstance(configs, dict):
        entries = list(configs.items())
    elif isinstance(configs, list):
        entries = list(enumerate(configs))
    else:
        # Scalars (and any other type) are returned unchanged.
        return configs

    for key, val in entries:
        # NOTE: str(bytes) yields the repr form (e.g. "b'...'"), matching
        # the historical behavior of this function.
        if isinstance(val, bytes):
            configs[key] = str(val)
        else:
            configs[key] = recursive_config_reformating(val)

    return configs
================================================
FILE: graphium/config/dummy_finetuning_from_gnn.yaml
================================================
# Here, we are finetuning a FullGraphMultitaskNetwork
# trained on ToyMix. We finetune from the gnn on the
# TDC dataset lipophilicity_astrazeneca
# Here are the changes to the architecture:
# Change gnn:
# depth: 4 -> 4 - 2 + 3 = 5
#
# Keep modules after gnn and apply modifications
# graph_output_nn/graph:
# pooling: sum -> mean
# depth: 1 -> 2
# task_heads/zinc:
# new_sub_module: lipophilicity_astrazeneca
# out_dim: 3 -> 1
#
# Finetuning training:
# unfreeze one additional layer of pretrained gnn
# after 1 epochs, unfreeze all layers
###################################################
########### How to combine information ###########
###################################################
###########################
### FINETUNING-SPECIFIC ###
###########################
finetuning:
# New task
task: lipophilicity_astrazeneca
level: graph
# Pretrained model
pretrained_model: dummy-pretrained-model
finetuning_module: gnn
# Changes to finetuning_module
drop_depth: 2
added_depth: 3
keep_modules_after_finetuning_module: # optional
graph_output_nn-graph:
pooling: [mean]
depth: 2
task_heads-zinc:
new_sub_module: lipophilicity_astrazeneca
out_dim: 1
# Finetuning training
unfreeze_pretrained_depth: 1
epoch_unfreeze_all: 1
constants:
seed: 42
max_epochs: 2
accelerator:
float32_matmul_precision: medium
type: cpu
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5
scheduler_kwargs: null
target_nan_mask: null
multitask_handling: flatten # flatten, mean-per-label
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: 2
warmup_epochs: 1
verbose: False
metrics_on_progress_bar:
lipophilicity_astrazeneca: ["mae"]
loss_fun:
lipophilicity_astrazeneca: mae
metrics:
lipophilicity_astrazeneca:
- name: mae
metric: mae
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: spearman
metric: spearmanr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: pearson
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: ${constants.seed}
trainer:
precision: 32
max_epochs: 2
min_epochs: 1
check_val_every_n_epoch: 1
accumulate_grad_batches: 1
##################
### DATAMODULE ###
##################
datamodule:
### FROM FINETUNING ###
module_type: "ADMETBenchmarkDataModule"
args:
# TDC specific
tdc_benchmark_names: [lipophilicity_astrazeneca]
tdc_train_val_seed: ${constants.seed}
batch_size_training: 200
batch_size_inference: 200
featurization_n_jobs: 0
num_workers: 0
prepare_dict_or_graph: pyg:graph
featurization_progress: True
featurization_backend: "loky"
persistent_workers: False
================================================
FILE: graphium/config/dummy_finetuning_from_task_head.yaml
================================================
# Here, we are finetuning a FullGraphMultitaskNetwork
# trained on ToyMix. We finetune from the zinc task-head
# (graph-level) on the TDC dataset lipophilicity_astrazeneca
# Here are the changes to the architecture:
# Change zinc task-head:
# depth: 2 -> 2 - 1 + 2 = 3
# out_dim: 3 -> 8
#
# Add finetuning head
# model_type: FeedForwardNN
# out_dim: 1
# hidden_dims: 8
# depth: 2
#
# Finetuning training:
# after 1 epochs, unfreeze all layers
###################################################
########### How to combine information ###########
###################################################
###########################
### FINETUNING-SPECIFIC ###
###########################
finetuning:
# New task
task: lipophilicity_astrazeneca
level: graph
# Pretrained model
pretrained_model: dummy-pretrained-model
finetuning_module: task_heads
sub_module_from_pretrained: zinc # optional
new_sub_module: lipophilicity_astrazeneca # optional
# keep_modules_after_finetuning_module: # optional
# Changes to finetuning_module
drop_depth: 1
new_out_dim: 8
added_depth: 2
# Optional finetuning head appended to model after finetuning_module
finetuning_head: # none
task: lipophilicity_astrazeneca
previous_module: task_heads
incoming_level: graph
model_type: mlp
in_dim: 8
out_dim: 1
hidden_dims: 8
depth: 2
last_layer_is_readout: true
# Finetuning training
unfreeze_pretrained_depth: 0
epoch_unfreeze_all: 1
constants:
seed: 42
max_epochs: 2
accelerator:
float32_matmul_precision: medium
type: cpu
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5
scheduler_kwargs: null
target_nan_mask: null
multitask_handling: flatten # flatten, mean-per-label
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: 2
warmup_epochs: 1
verbose: False
metrics_on_progress_bar:
lipophilicity_astrazeneca: ["mae"]
loss_fun:
lipophilicity_astrazeneca: mae
metrics:
lipophilicity_astrazeneca:
- name: mae
metric: mae
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: spearman
metric: spearmanr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: pearson
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
trainer:
seed: ${constants.seed}
trainer:
precision: 32
max_epochs: 2
min_epochs: 1
check_val_every_n_epoch: 1
accumulate_grad_batches: 1
##################
### DATAMODULE ###
##################
datamodule:
### FROM FINETUNING ###
module_type: "ADMETBenchmarkDataModule"
args:
# TDC specific
tdc_benchmark_names: [lipophilicity_astrazeneca]
tdc_train_val_seed: ${constants.seed}
batch_size_training: 200
batch_size_inference: 200
featurization_n_jobs: 0
num_workers: 0
prepare_dict_or_graph: pyg:graph
featurization_progress: True
featurization_backend: "loky"
persistent_workers: False
================================================
FILE: graphium/config/fake_and_missing_multilevel_multitask_pyg.yaml
================================================
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
score:
df: null
task_level: "edge"
df_path: "./tests/fake_and_missing_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["edge_label_list", "edge_label_np"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
logp:
df: null
task_level: "node"
df_path: "./tests/fake_and_missing_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["node_label_list", "node_label_np"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
SA:
df: null
task_level: "graph"
df_path: "./tests/fake_and_missing_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["graph_label"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null # This may not always be provided
sample_size: null # This may not always be provided
idx_col: null # This may not always be provided
weights_col: null # This may not always be provided
# Featurization
featurization_n_jobs: 16
featurization_progress: True
featurization:
atom_property_list_onehot: ["atomic-number", "degree"]
atom_property_list_float: []
edge_property_list: ["in-ring", "bond-type-onehot"]
add_self_loop: False
explicit_H: False
use_bonds_weights: False
# Data handling-related
batch_size_training: 16
batch_size_inference: 16
architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml`
model_type: FullGraphMultiTaskNetwork
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 32
depth: 1
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization "batch_norm"
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 16
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: random
virtual_node: 'sum'
layer_type: 'pyg:pna-msgpass'
layer_kwargs:
# num_heads: 3
aggregators: [mean, max]
scalers: [identity, amplification, attenuation]
graph_output_nn:
node:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
graph:
pooling: [sum, max]
out_dim: 1
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
edge:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
nodepair:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads: # Set as null to avoid task heads. Recall that the arguments for the TaskHeads is a List of TaskHeadParams
task_1:
task_level: "node"
out_dim: 5
hidden_dims: [5, 6, 7]
#depth: none # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_2:
task_level: "edge"
out_dim: 3
hidden_dims: [8, 9, 10]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_3:
task_level: "graph"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_4:
task_level: "nodepair"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
================================================
FILE: graphium/config/fake_multilevel_multitask_pyg.yaml
================================================
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
score:
df: null
task_level: "edge"
df_path: "./tests/converted_fake_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["edge_label_list", "edge_label_np"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
logp:
df: null
task_level: "node"
df_path: "./tests/converted_fake_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["node_label_list", "node_label_np"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
SA:
df: null
task_level: "graph"
df_path: "./tests/converted_fake_multilevel_data.parquet"
smiles_col: "ordered_smiles"
label_cols: ["graph_label"]
split_val: 0.0
split_test: 0.0
seed: 19
splits_path: null # This may not always be provided
sample_size: null # This may not always be provided
idx_col: null # This may not always be provided
weights_col: null # This may not always be provided
# Featurization
featurization_n_jobs: 16
featurization_progress: True
featurization:
atom_property_list_onehot: ["atomic-number", "degree"]
atom_property_list_float: []
edge_property_list: ["in-ring", "bond-type-onehot"]
add_self_loop: False
explicit_H: False
use_bonds_weights: False
# Data handling-related
batch_size_training: 16
batch_size_inference: 16
architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml`
model_type: FullGraphMultiTaskNetwork
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 32
depth: 1
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization "batch_norm"
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 16
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: random
virtual_node: 'sum'
layer_type: 'pyg:pna-msgpass'
layer_kwargs:
# num_heads: 3
aggregators: [mean, max]
scalers: [identity, amplification, attenuation]
graph_output_nn:
node:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
graph:
pooling: [sum, max]
out_dim: 1
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
edge:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
nodepair:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads: # Set as null to avoid task heads. Recall that the arguments for the TaskHeads is a List of TaskHeadParams
task_1:
task_level: "node"
out_dim: 5
hidden_dims: [5, 6, 7]
#depth: none # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_2:
task_level: "edge"
out_dim: 3
hidden_dims: [8, 9, 10]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_3:
task_level: "graph"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_4:
task_level: "nodepair"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
================================================
FILE: graphium/config/zinc_default_multitask_pyg.yaml
================================================
datamodule:
module_type: "MultitaskFromSmilesDataModule"
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
SA:
df: null
task_level: "graph"
df_path: "graphium/data/multitask/tiny_ZINC_SA.csv"
smiles_col: "SMILES"
label_cols: ["SA"]
split_val: 0.2
split_test: 0.2
seed: 19
splits_path: null # This may not always be provided
sample_size: null # This may not always be provided
idx_col: null # This may not always be provided
weights_col: null # This may not always be provided
logp:
df: null
task_level: "graph"
df_path: "graphium/data/multitask/tiny_ZINC_logp.csv"
smiles_col: "SMILES"
label_cols: ["logp"]
split_val: 0.2
split_test: 0.2
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
score:
df: null
task_level: "graph"
df_path: "graphium/data/multitask/tiny_ZINC_score.csv"
smiles_col: "SMILES"
label_cols: ["score"]
split_val: 0.2
split_test: 0.2
seed: 19
splits_path: null
sample_size: null
idx_col: null
weights_col: null
weights_type: null
# Featurization
featurization_n_jobs: 16
featurization_progress: True
featurization:
atom_property_list_onehot: ["atomic-number", "degree"]
atom_property_list_float: []
edge_property_list: ["in-ring", "bond-type-onehot"]
add_self_loop: False
explicit_H: False
use_bonds_weights: False
# Data handling-related
batch_size_training: 16
batch_size_inference: 16
architecture: # The parameters for the full graph network are taken from `config_micro_ZINC.yaml`
model_type: FullGraphMultiTaskNetwork
pre_nn: # Set as null to avoid a pre-nn network
out_dim: 32
hidden_dims: 32
depth: 1
activation: relu
last_activation: none
dropout: &dropout 0.1
normalization: &normalization "batch_norm"
last_normalization: *normalization
residual_type: none
pre_nn_edges: # Set as null to avoid a pre-nn network
out_dim: 16
hidden_dims: 16
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: none
gnn: # Set as null to avoid a post-nn network
out_dim: 32
hidden_dims: 32
depth: 4
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: *normalization
residual_type: random
virtual_node: 'sum'
layer_type: 'pyg:pna-msgpass'
layer_kwargs:
# num_heads: 3
aggregators: [mean, max]
scalers: [identity, amplification, attenuation]
graph_output_nn:
node:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
graph:
pooling: [sum, max]
out_dim: 1
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
edge:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
nodepair:
out_dim: 16
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none
task_heads: # Set as null to avoid task heads. Recall that the arguments for the TaskHeads is a List of TaskHeadParams
task_1:
task_level: "node"
out_dim: 5
hidden_dims: [5, 6, 7]
#depth: none # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_2:
task_level: "edge"
out_dim: 3
hidden_dims: [8, 9, 10]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_3:
task_level: "graph"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
task_4:
task_level: "nodepair"
out_dim: 4
hidden_dims: [2, 2, 2]
activation: relu
last_activation: none
dropout: 0.2
normalization: none
residual_type: none
accelerator:
type: cpu
================================================
FILE: graphium/data/L1000/datasets_goli-L1000_parse_gctx_to_csv.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
"from os import makedirs\n",
"import pandas as pd\n",
"import numpy as np\n",
"from cmapPy.pandasGEXpress.parse import parse\n",
"from rdkit.Chem.AllChem import MolFromSmiles, MolToSmiles\n",
"from IPython.display import display\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"COMP_FILE = \"metadata/compoundinfo_beta.txt\"\n",
"GENE_FILE = \"metadata/geneinfo_beta.txt\"\n",
"INFO_FILE = \"data/instinfo_beta.txt\"\n",
"LEVEL5_FILE = \"data/level5_beta_trt_cp_n720216x12328.gctx\"\n",
"THRESHOLD = 4\n",
"THRESHOLD_lst = [-6, -3, 3, 6]\n",
"OUT_DIR_LST = []\n",
"OUT_DIR = f\"out/th_{THRESHOLD}/\"\n",
"OUT_DIR_LST.append(f\"out/less_th_{THRESHOLD_lst[0]}/\")\n",
"for i in range(len(THRESHOLD_lst)-1):\n",
" OUT_DIR_LST.append(f\"out/th_{THRESHOLD_lst[i]}_{THRESHOLD_lst[i+1]}/\")\n",
"OUT_DIR_LST.append(f\"out/bigger_th_{THRESHOLD_lst[3]}/\")\n",
"# U2OS: Bone cancer line with most data (20k mols).\n",
"# HA1E: Non-cancer kidney line with the most data (5k mols).\n",
"CHOSEN_CELL_LINES = [\"U2OS\", \"HA1E\", \"VCAP\", \"A549\", \"MCF7\", \"PC3\", \"A375\"]\n",
"MIN_ACTIVE_PER_COL = 15"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Function to convert a SMILES to a canonical SMILES, or return None in case of error\n",
"def canonical_smiles(smiles):\n",
" out = None\n",
" try:\n",
" out = MolToSmiles(MolFromSmiles(smiles))\n",
" except:\n",
" out = None\n",
"\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the compounds data. This allows to relate the perturbation identification \"pert_id\" to molecular SMILES\n",
"comp_df = pd.read_csv(COMP_FILE, sep=\"\\t\", index_col=\"pert_id\", \n",
" usecols=[\"pert_id\", \"canonical_smiles\", \"inchi_key\", \"compound_aliases\"])\n",
"comp_df = comp_df.rename(columns={\"canonical_smiles\": \"SMILES\"})\n",
"print(comp_df.shape)\n",
"comp_df"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(978, 2)\n"
]
},
{
"data": {
"text/html": [
"
## What is in this folder?
- ✅ `datamodule.py`: loading from disk and processing into a dataloader
- `collate.py` : defining the collate function to use for the dataloader
- `dataset.py`: defining the dataset class
- `normalization.py`: defining label normalization
- `sdf2csv.py`: conversion function from sdf to csv for the ogb pcqm dataset
- `smiles_transform.py`: transform smiles to molecule object
- `utils.py`: utility functions for dataset loading
================================================
FILE: graphium/data/__init__.py
================================================
from .utils import load_micro_zinc
from .utils import load_tiny_zinc
from .collate import graphium_collate_fn
from .datamodule import GraphOGBDataModule
from .datamodule import MultitaskFromSmilesDataModule
from .datamodule import ADMETBenchmarkDataModule
from .datamodule import FakeDataModule
from .dataset import SingleTaskDataset
from .dataset import MultitaskDataset
from .dataset import FakeDataset
================================================
FILE: graphium/data/collate.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from collections.abc import Mapping, Sequence
# from pprint import pprint
import torch
from numpy import ndarray
from scipy.sparse import spmatrix
from torch.utils.data.dataloader import default_collate
from typing import Union, List, Optional, Dict, Type, Any, Iterable
from torch_geometric.data import Data, Batch
from graphium.features import GraphDict, to_dense_array
from graphium.utils.packing import fast_packing, get_pack_sizes, node_to_pack_indices_mask
from loguru import logger
from graphium.data.utils import get_keys
def graphium_collate_fn(
    elements: Union[List[Any], Dict[str, List[Any]]],
    labels_size_dict: Optional[Dict[str, Any]] = None,
    labels_dtype_dict: Optional[Dict[str, Any]] = None,
    mask_nan: Union[str, float, Type[None]] = "raise",
    do_not_collate_keys: List[str] = [],
    batch_size_per_pack: Optional[int] = None,
) -> Union[Any, Dict[str, Any]]:
    """This collate function is identical to the default
    pytorch collate function but adds support for `pyg.data.Data` to batch graphs.

    Beside pyg graph collate, other objects are processed the same way
    as the original torch collate function. See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn
    for more details.

    Note:
        If graphium needs to manipulate other tricky-to-batch objects, support
        for them should be added to this single collate function.

    Parameters:

        elements:
            The elements to batch. See `torch.utils.data.dataloader.default_collate`.

        labels_size_dict:
            (Note): This is an attribute of the `MultitaskDataset`.
            A dictionary of the form Dict[tasks, sizes] which has task names as keys
            and the size of the label tensor as value. The size of the tensor corresponds to how many
            labels/values there are to predict for that task.

        labels_dtype_dict:
            (Note): This is an attribute of the `MultitaskDataset`.
            A dictionary of the form Dict[tasks, dtypes] which has task names as keys
            and the dtype of the label tensor as value. This is necessary to ensure the missing labels are added with NaNs of the right dtype

        mask_nan:
            Deal with the NaN/Inf when calling the function `make_pyg_graph`.
            Some values become `Inf` when changing data type. This allows to deal
            with that.

            - "raise": DEFAULT. Raise an error when there is a nan or inf in the featurization
            - "warn": Raise a warning when there is a nan or inf in the featurization
            - "None": Don't do anything
            - "Floating value": Replace nans or inf by the specified value

        do_not_collate_keys:
            Keys to ignore for the collate

        batch_size_per_pack: The number of graphs to pack together.
            This is useful for using packing with the Transformer.
            If None, no packing is done.
            Otherwise, indices are generated to map the nodes to the pack they belong to under the key `"pack_from_node_idx"`,
            with an additional mask to indicate which nodes are from the same graph under the key `"pack_attn_mask"`.

    Returns:
        The batched elements. See `torch.utils.data.dataloader.default_collate`.
    """
    # The first element determines how the whole batch is collated.
    elem = elements[0]

    if isinstance(elem, Mapping):
        batch = {}
        for key in elem:
            # Multitask setting: We have to pad the missing labels
            if key == "labels":
                labels = [d[key] for d in elements]
                batch[key] = collate_labels(labels, labels_size_dict, labels_dtype_dict)

            # If the features are a dictionary containing GraphDict elements,
            # Convert to pyg graphs and use the pyg batching.
            elif isinstance(elem[key], GraphDict):
                pyg_graphs = [d[key].make_pyg_graph(mask_nan=mask_nan) for d in elements]
                batch[key] = collage_pyg_graph(pyg_graphs)

            # If a PyG Graph is provided, use the PyG batching
            elif isinstance(elem[key], Data):
                pyg_graphs = [d[key] for d in elements]
                batch[key] = collage_pyg_graph(pyg_graphs, batch_size_per_pack=batch_size_per_pack)

            # Ignore the collate for specific keys
            # NOTE(review): because this branch comes after the "labels"/GraphDict/Data
            # branches, `do_not_collate_keys` is not honored for those keys — confirm
            # this precedence is intended.
            elif key in do_not_collate_keys:
                batch[key] = [d[key] for d in elements]

            # Otherwise, use the default torch batching
            else:
                batch[key] = default_collate([d[key] for d in elements])
        return batch
    # A batch of sequences: collate position-wise by mapping each position to a
    # temporary dict key, then restore the list structure.
    elif isinstance(elements, Sequence) and isinstance(elem, Sequence):
        temp_elements = [{ii: sub_elem for ii, sub_elem in enumerate(elem)} for elem in elements]
        batch = graphium_collate_fn(temp_elements)
        return list(batch.values())
    # A batch of non-mapping, non-sequence items: wrap in a single-key dict so the
    # Mapping branch above handles GraphDict/Data items, then unwrap.
    elif isinstance(elements, Sequence) and not isinstance(elem, Sequence):
        temp_elements = [{"temp_key": elem} for elem in elements]
        batch = graphium_collate_fn(temp_elements)
        return batch["temp_key"]
    else:
        return default_collate(elements)
def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pack: Optional[int] = None):
    """
    Function to collate pytorch geometric graphs.
    Convert all numpy types to torch
    Convert edge indices to int64

    Parameters:
        pyg_graphs: Iterable of PyG graphs.
            NOTE(review): this iterable is traversed twice, so it must be a
            re-iterable sequence (e.g. a list), not a one-shot generator — confirm callers.
        batch_size_per_pack: The number of graphs to pack together.
            This is useful for using packing with the Transformer.
            NOTE: any non-None value currently raises `NotImplementedError` (see below).

    Returns:
        A `torch_geometric.data.Batch` built from the (mutated in place) input graphs.
    """
    # Calculate maximum number of nodes per graph in current batch
    num_nodes_list = []
    for pyg_graph in pyg_graphs:
        num_nodes_list.append(pyg_graph["num_nodes"])
    max_num_nodes_per_graph = max(num_nodes_list)

    pyg_batch = []
    for pyg_graph in pyg_graphs:
        for pyg_key in get_keys(pyg_graph):
            tensor = pyg_graph[pyg_key]

            # Convert numpy/scipy to Pytorch
            if isinstance(tensor, (ndarray, spmatrix)):
                tensor = torch.as_tensor(to_dense_array(tensor, tensor.dtype))

            # pad nodepair-level positional encodings to the largest graph in the batch
            if pyg_key.startswith("nodepair_"):
                pyg_graph[pyg_key] = pad_nodepairs(tensor, pyg_graph["num_nodes"], max_num_nodes_per_graph)
            else:
                pyg_graph[pyg_key] = tensor

        # Convert edge index to int64
        pyg_graph.edge_index = pyg_graph.edge_index.to(torch.int64)
        pyg_batch.append(pyg_graph)

    # Apply the packing at the mini-batch level. This is useful for using packing with the Transformer,
    # especially in the case of the large graphs being much larger than the small graphs.
    # CAREFUL!!! This changes the order of the graphs in the batch, without changing the order of the labels or other objects.
    # An error is raised temporarily.
    if batch_size_per_pack is not None:
        raise NotImplementedError(
            "Packing is not yet functional, as it changes the order of the graphs in the batch without changing the label order"
        )
        # NOTE(review): everything below in this branch is unreachable until the
        # `raise` above is removed; kept as the intended packing implementation.
        num_nodes = [g.num_nodes for g in pyg_batch]
        packed_graph_idx = fast_packing(num_nodes, batch_size_per_pack)

        # Get the node to pack indices and the mask
        pack_from_node_idx, pack_attn_mask = node_to_pack_indices_mask(packed_graph_idx, num_nodes)
        for pyg_graph in pyg_batch:
            pyg_graph.pack_from_node_idx = pack_from_node_idx
            pyg_graph.pack_attn_mask = pack_attn_mask

    return Batch.from_data_list(pyg_batch)
def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[int]):
    """Pad ``labels`` with ``torch.nan`` up to the expected shape ``label_size``.

    Missing trailing dimensions are added first, then every dimension is
    right-padded with NaNs so the output shape equals ``label_size``.

    Parameters:
        labels: The label tensor to pad.
        label_size: The expected shape, with the same number of dims or more
            than ``labels``.

    Returns:
        ``labels`` unchanged if the shape already matches, otherwise a new
        NaN-padded tensor of shape ``label_size``.
    """
    if label_size == list(labels.shape):
        return labels

    # Add any missing trailing dimensions.
    # BUGFIX: `Tensor.unsqueeze` is not in-place, so the result must be re-assigned.
    missing_dims = len(label_size) - len(labels.shape)
    for _ in range(missing_dims):
        labels = labels.unsqueeze(-1)

    # `torch.nn.functional.pad` expects a flat (before, after) sequence starting
    # with the *last* dimension. BUGFIX: reverse per-dimension *pairs* (not the
    # flat list) so each dimension is padded at the end, not at the beginning.
    pad_sizes = [(0, expected - actual) for expected, actual in zip(label_size, labels.shape)]
    pad_sizes = [item for before_after in reversed(pad_sizes) for item in before_after]

    # Negative pads truncate; warn since data is lost in that case.
    if any([s < 0 for s in pad_sizes]):
        logger.warning(f"More labels available than expected. Will remove data to fit expected size.")

    return torch.nn.functional.pad(labels, pad_sizes, value=torch.nan)
def collate_pyg_graph_labels(pyg_labels: List[Data]):
    """
    Collate a list of pytorch geometric label objects into a single `Batch`.
    All numpy/scipy values are converted to torch tensors beforehand.

    Parameters:
        pyg_labels: Iterable of PyG label Data objects
    """
    batch_list = []
    for label_data in pyg_labels:
        # Only label attributes are converted; skip the structural keys.
        label_keys = set(get_keys(label_data)) - set(["x", "edge_index"])
        for label_key in label_keys:
            value = label_data[label_key]
            # Convert numpy arrays / scipy sparse matrices to torch tensors
            if isinstance(value, (ndarray, spmatrix)):
                value = torch.as_tensor(to_dense_array(value, value.dtype))
            label_data[label_key] = value
        batch_list.append(label_data)

    return Batch.from_data_list(batch_list)
def get_expected_label_size(label_data: "Data", task: str, label_size: List[int]):
    """Determines expected label size based on the specific graph properties
    and the number of targets in the task-dataset.

    Parameters:
        label_data: A PyG ``Data``-like object; only ``.x`` (node tasks) or
            ``.edge_index`` (edge tasks) is read, depending on the task level.
        task: Task name prefixed by its level: "graph_", "node_", "edge_" or "nodepair_".
        label_size: The per-entity label size of the task.

    Returns:
        ``[num_labels] + label_size``, where ``num_labels`` is the number of
        labeled entities (1 for graph tasks, one per node, or one per edge).

    Raises:
        NotImplementedError: For "nodepair_" tasks.
        ValueError: If the task name has no recognized level prefix.
    """
    if task.startswith("graph_"):
        # A single label row per graph.
        num_labels = 1
    elif task.startswith("node_"):
        # One label row per node.
        num_labels = label_data.x.size(0)
    elif task.startswith("edge_"):
        # One label row per edge.
        num_labels = label_data.edge_index.size(1)
    elif task.startswith("nodepair_"):
        raise NotImplementedError()
    else:
        # BUGFIX: previously an unrecognized prefix fell through all branches
        # and crashed with a `NameError`; raise an explicit error instead.
        raise ValueError(f"Unrecognized task level prefix for task `{task}`")
    return [num_labels] + label_size
def collate_labels(
    labels: List[Data],
    labels_size_dict: Optional[Dict[str, Any]] = None,
    labels_dtype_dict: Optional[Dict[str, Any]] = None,
):
    """Collate labels for multitask learning.

    Parameters:
        labels: List of labels
        labels_size_dict: Dict of the form Dict[tasks, sizes] which has task names as keys
            and the size of the label tensor as value. The size of the tensor corresponds to how many
            labels/values there are to predict for that task.
            NOTE(review): this dict is mutated in place below (entries are re-assigned
            per label); callers should not rely on its original contents afterwards.
        labels_dtype_dict:
            (Note): This is an attribute of the `MultitaskDataset`.
            A dictionary of the form Dict[tasks, dtypes] which has task names as keys
            and the dtype of the label tensor as value. This is necessary to ensure the missing labels are added with NaNs of the right dtype.
            NOTE(review): assumed to be provided whenever some label is missing a task.

    Returns:
        The result of `collate_pyg_graph_labels` on the (NaN-filled, padded) labels,
        i.e. a collated batch of the label objects.
    """
    if labels_size_dict is not None:
        for this_label in labels:
            for task in labels_size_dict.keys():
                labels_size_dict[task] = list(labels_size_dict[task])
                # Strip the leading entity dimension (num graphs/nodes/edges), keeping
                # only the per-entity target size. The entity dim is re-prepended by
                # `get_expected_label_size` below on every iteration of the outer loop,
                # so this slice undoes the previous iteration's prepend.
                if len(labels_size_dict[task]) >= 2:
                    labels_size_dict[task] = labels_size_dict[task][1:]
                elif not task.startswith("graph_"):
                    # Non-graph tasks with a flat stored size default to 1 target per entity
                    labels_size_dict[task] = [1]
            label_keys_set = set(get_keys(this_label))
            # Tasks for which this label object carries no data at all
            empty_task_labels = set(labels_size_dict.keys()) - label_keys_set
            for task in empty_task_labels:
                # Fill fully-missing tasks with NaNs of the expected shape and dtype
                labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task])
                dtype = labels_dtype_dict[task]
                this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype)
            # Tasks that do carry data: convert to tensor, normalize rank, pad to size
            for task in label_keys_set - set(["x", "edge_index"]) - empty_task_labels:
                labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task])
                if not isinstance(this_label[task], (torch.Tensor)):
                    this_label[task] = torch.as_tensor(this_label[task])
                # Ensure explicit task dimension also for single task labels
                if len(this_label[task].shape) == 1:
                    # Distinguish whether target dim or entity dim is missing
                    if labels_size_dict[task][0] == this_label[task].shape[0]:
                        # num graphs/nodes/edges/nodepairs already matching
                        this_label[task] = this_label[task].unsqueeze(1)
                    else:
                        # data lost unless entity dim is supposed to be 1
                        if labels_size_dict[task][0] == 1:
                            this_label[task] = this_label[task].unsqueeze(0)
                        else:
                            raise ValueError(
                                f"Labels for {labels_size_dict[task][0]} nodes/edges/nodepairs expected, got 1."
                            )
                this_label[task] = pad_to_expected_label_size(this_label[task], labels_size_dict[task])
    return collate_pyg_graph_labels(labels)
def pad_nodepairs(pe: torch.Tensor, num_nodes: int, max_num_nodes_per_graph: int):
    """
    Zero-pad a nodepair-level positional encoding along its second dimension so
    that all graphs in the current batch share the same width, conforming with
    the batching logic.

    Parameters:
        pe (torch.Tensor, [num_nodes, num_nodes, num_feat]): Nodepair pe
        num_nodes (int): Number of nodes of processed graph
        max_num_nodes_per_graph (int): Maximum number of nodes among graphs in current batch

    Returns:
        padded_pe (torch.Tensor, [num_nodes, max_num_nodes_per_graph, num_feat]): padded nodepair pe tensor
    """
    num_feat = pe.size(-1)
    out = torch.zeros((num_nodes, max_num_nodes_per_graph, num_feat), dtype=pe.dtype)
    # Slicing the source to `num_nodes` discards any stale zero-padding that may
    # remain on `pe` from a previous epoch before copying the encodings over.
    out[:, :num_nodes] = pe[:, :num_nodes]
    return out
================================================
FILE: graphium/data/datamodule.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import tempfile
from contextlib import redirect_stderr, redirect_stdout
from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable, Literal
from os import PathLike as Path
from dataclasses import dataclass
import os
from functools import partial
import importlib.resources
import zipfile
from copy import deepcopy
import time
import gc
import platformdirs
import re
from graphium.data.utils import get_keys
from loguru import logger
import fsspec
import omegaconf
import pandas as pd
import numpy as np
import datamol as dm
from tqdm import tqdm
import os.path as osp
from fastparquet import ParquetFile
from sklearn.model_selection import train_test_split
import lightning
from lightning.pytorch.trainer.states import RunningStage
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from torch.utils.data import Subset
from graphium.utils import fs
from graphium.features import (
mol_to_graph_dict,
GraphDict,
mol_to_pyggraph,
)
from graphium.data.sampler import DatasetSubSampler
from graphium.data.utils import graphium_package_path, found_size_mismatch
from graphium.utils.arg_checker import check_arg_iterator
from graphium.utils.hashing import get_md5_hash
from graphium.data.smiles_transform import (
did_featurization_fail,
BatchingSmilesTransform,
smiles_to_unique_mol_ids,
)
from graphium.data.collate import graphium_collate_fn
import graphium.data.dataset as Datasets
from graphium.data.normalization import LabelNormalization
from graphium.data.multilevel_utils import extract_labels
# Share tensors between dataloader workers through the file system.
# NOTE(review): presumably chosen to avoid file-descriptor exhaustion with many
# workers (the default strategy passes descriptors) — confirm with the authors.
torch.multiprocessing.set_sharing_strategy("file_system")

# Download/description metadata for the OGB-LSC PCQM4M dataset (KDD Cup 2021).
PCQM4M_meta = {
    "num tasks": 1,
    "eval metric": "mae",
    "download_name": "pcqm4m_kddcup2021",
    "url": "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip",  # TODO: Allow PyG
    "data type": "mol",
    "has_node_attr": True,
    "has_edge_attr": True,
    "task type": "regression",
    "num classes": -1,
    "split": "scaffold",
    "additional node files": "None",
    "additional edge files": "None",
    "binary": False,
    "version": 1,
}

# PCQM4M-v2 reuses the same metadata with an updated name, URL and version.
PCQM4Mv2_meta = deepcopy(PCQM4M_meta)
PCQM4Mv2_meta.update(
    {
        "download_name": "pcqm4m-v2",
        "url": "https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip",  # TODO: Allow PyG
        "version": 2,
    }
)
class BaseDataModule(lightning.LightningDataModule):
    """
    Base Lightning datamodule shared by the Graphium datamodules.

    Holds the batch-size and dataloader options plus the collate function, and
    builds the train/val/test/predict dataloaders from the `train_ds`, `val_ds`,
    `test_ds` and `predict_ds` datasets. Subclasses must implement
    `prepare_data` and `setup` to populate those datasets.
    """

    def __init__(
        self,
        batch_size_training: int = 16,
        batch_size_inference: int = 16,
        batch_size_per_pack: Optional[int] = None,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[str] = None,
        collate_fn: Optional[Callable] = None,
    ):
        """
        base dataset module for all datasets (to be inherited)

        Parameters:
            batch_size_training: batch size for training
            batch_size_inference: batch size for inference
            batch_size_per_pack: if provided, the number of graphs per pack;
                must be a divisor of both batch sizes
            num_workers: number of workers for data loading. Use -1 to use all
                available cores.
            pin_memory: whether to pin memory
            persistent_workers: whether to use persistent workers
            multiprocessing_context: multiprocessing context for data worker creation
            collate_fn: collate function for batching
        """
        super().__init__()
        self.batch_size_training = batch_size_training
        self.batch_size_inference = batch_size_inference
        self.batch_size_per_pack = batch_size_per_pack
        if self.batch_size_per_pack is not None:
            # Check that batch_size_per_pack is a divisor of batch_size_training and batch_size_inference
            assert (
                self.batch_size_training % self.batch_size_per_pack == 0
            ), f"batch_size_training must be a multiple of batch_size_per_pack, provided batch_size_training={self.batch_size_training}, batch_size_per_pack={self.batch_size_per_pack}"
            assert (
                self.batch_size_inference % self.batch_size_per_pack == 0
            ), f"batch_size_inference must be a multiple of batch_size_per_pack, provided batch_size_inference={self.batch_size_inference}, batch_size_per_pack={self.batch_size_per_pack}"
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.multiprocessing_context = multiprocessing_context
        self.collate_fn = self.get_collate_fn(collate_fn)
        # Datasets are populated later by `setup()` in subclasses
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None
        self._predict_ds = None
        # Caching flags, maintained by subclasses
        self._data_is_prepared = False
        self._data_is_cached = False

    def prepare_data(self):
        """Download / preprocess the data. Must be implemented by subclasses."""
        raise NotImplementedError()

    def setup(self):
        """Build the train/val/test datasets. Must be implemented by subclasses."""
        raise NotImplementedError()

    def train_dataloader(self, **kwargs):
        """
        return the training dataloader
        """
        return self.get_dataloader(
            dataset=self.train_ds,  # type: ignore
            shuffle=True,
            stage=RunningStage.TRAINING,
            **kwargs,
        )

    def val_dataloader(self, **kwargs):
        r"""
        return the validation dataloader
        """
        return self.get_dataloader(
            dataset=self.val_ds,  # type: ignore
            shuffle=False,
            stage=RunningStage.VALIDATING,
            **kwargs,
        )

    def test_dataloader(self, **kwargs):
        r"""
        return the test dataloader
        """
        return self.get_dataloader(
            dataset=self.test_ds,  # type: ignore
            shuffle=False,
            stage=RunningStage.TESTING,
            **kwargs,
        )

    def predict_dataloader(self, **kwargs):
        """
        return the dataloader for prediction
        """
        return self.get_dataloader(
            dataset=self.predict_ds,  # type: ignore
            shuffle=False,
            stage=RunningStage.PREDICTING,
            **kwargs,
        )

    def get_collate_fn(self, collate_fn):
        """Return `collate_fn` unchanged, or the default `graphium_collate_fn`
        (with NaN masking and the configured pack size) when it is None."""
        if collate_fn is None:
            # Some values become `inf` when changing data type. `mask_nan` deals with that
            collate_fn = partial(
                graphium_collate_fn, mask_nan=0, batch_size_per_pack=self.batch_size_per_pack
            )
            collate_fn.__name__ = graphium_collate_fn.__name__
        return collate_fn

    @property
    def is_prepared(self):
        """Whether `prepare_data` has completed. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def is_setup(self):
        """Whether `setup` has completed. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def num_node_feats(self):
        """Number of node features. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def num_edge_feats(self):
        """Number of edge features. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def predict_ds(self):
        """Get the dataset used for the prediction"""
        if self._predict_ds is None:
            # Fall back to the test set when no prediction set was assigned
            return self.test_ds
        else:
            return self._predict_ds

    @property
    def get_num_workers(self):
        """
        get the number of workers to use
        """
        if self.num_workers == -1:
            # -1 means "use all available cores"; `os.cpu_count()` may return None
            num_workers = os.cpu_count()
            num_workers = num_workers if num_workers is not None else 0
        else:
            num_workers = self.num_workers
        return num_workers

    @predict_ds.setter
    def predict_ds(self, value):
        """Set the dataset for the prediction"""
        self._predict_ds = value

    def get_fake_graph(self):
        """Return a minimal graph for shape inference. Must be implemented by subclasses."""
        raise NotImplementedError()

    # Private methods

    @staticmethod
    def _read_csv(
        path: str,
        **kwargs,
    ) -> pd.DataFrame:
        """
        private method for reading a csv file

        Parameters:
            path: path to the csv file
            kwargs: keyword arguments for pd.read_csv
        Returns:
            pd.DataFrame: the panda dataframe storing molecules
        """
        path = str(path)
        if path.endswith((".csv", ".csv.gz", ".csv.zip", ".csv.bz2")):
            sep = ","
        elif path.endswith((".tsv", ".tsv.gz", ".tsv.zip", ".tsv.bz2")):
            sep = "\t"
        else:
            raise ValueError(f"unsupported file `{path}`")
        kwargs.setdefault("sep", sep)
        # Resolve `graphium://` URIs to the packaged data files
        if path.startswith("graphium://"):
            path = graphium_package_path(path)
        df = pd.read_csv(path, **kwargs)
        return df

    @staticmethod
    def _get_data_file_type(path):
        """Classify a data file by its (possibly double) extension.

        Returns one of "parquet", "sdf", "csv", "tsv", "pkl" or "pt";
        raises ValueError for anything else.
        """
        # Extract the extension
        name, ext = os.path.splitext(path)
        # support compressed files, e.g. `.csv.gz` -> `.csv.gz`
        _, ext2 = os.path.splitext(name)
        if ext2 != "":
            ext = f"{ext2}{ext}"
        if ext.endswith((".parquet")):  # Support parquet files. Compression is implicit
            return "parquet"
        elif ".sdf" in ext:  # support compressed sdf files
            return "sdf"
        elif ".csv" in ext:  # support compressed csv files
            return "csv"
        elif ".tsv" in ext:  # support compressed tsv files
            return "tsv"
        elif ".pkl" in ext:  # support compressed pickle files
            return "pkl"
        elif ext.endswith(".pt"):  # Pytorch tensor files, used for storing index splits
            return "pt"
        else:
            raise ValueError(f"unsupported file `{path}`")

    @staticmethod
    def _get_table_columns(path: str) -> List[str]:
        """
        Get the columns of a table without reading all the data.
        Might be slow to decompress the file if the file is compressed.

        Parameters:
            path: path to the table file
        Returns:
            List[str]: the column names
        Raises:
            ValueError: for file types with no tabular schema (e.g. `.pkl`, `.pt`)
        """
        datafile_type = BaseDataModule._get_data_file_type(path)
        if datafile_type == "parquet":
            # Read the schema of a parquet file
            file = ParquetFile(path)
            schema = file.pandas_metadata["columns"]
            column_names = [s["name"] for s in schema if s["name"] is not None]
        elif datafile_type == "sdf":
            # Read a few molecules to infer the columns.
            # Fixed typo: was `discard_invalud`, which would have been forwarded
            # as an unknown keyword argument.
            df = BaseDataModule._read_sdf(path, max_num_mols=5, discard_invalid=True, n_jobs=1)
            column_names = df.columns
        elif datafile_type in ["csv", "tsv"]:
            # Read the schema of a csv / tsv file
            df = BaseDataModule._read_csv(path, nrows=5)
            column_names = df.columns
        else:
            # `pkl` / `pt` files carry no tabular schema; previously this fell
            # through and raised an obscure `UnboundLocalError` on the return.
            raise ValueError(f"Cannot read table columns from `{datafile_type}` file `{path}`")
        return column_names

    @staticmethod
    def _read_parquet(path, **kwargs):
        """Read a parquet file column by column, down-casting float columns to
        float16 to reduce memory consumption, and return the merged DataFrame."""
        kwargs.pop("dtype", None)  # Only useful for csv
        column_names = BaseDataModule._get_table_columns(path)

        # Change the 'usecols' parameter to 'columns'
        columns = kwargs.pop("columns", None)
        if "usecols" in kwargs.keys():
            assert columns is None, "Ambiguous value of `columns`"
            columns = kwargs.pop("usecols")
        if columns is None:
            columns = column_names
        for column in columns:
            assert (
                column in column_names
            ), f"Column `{column}` is not in the parquet file with columns {column_names}"

        # Read the parquet file per column, and convert the data to float16 to reduce memory consumption
        all_series = {}
        progress = tqdm(columns)
        for col in progress:
            # Read single column
            progress.set_description(f"Reading parquet column `{col}`")
            this_series = pd.read_parquet(path, columns=[col], engine="fastparquet", **kwargs)[col]

            # Check if the data is float
            # NOTE(review): assumes a non-empty column; `.values[0]` would fail
            # on an empty table — confirm upstream guarantees
            first_elem = this_series.values[0]
            is_float = False
            if isinstance(first_elem, (list, tuple)):
                is_float = isinstance(first_elem[0], float) or (first_elem[0].dtype.kind == "f")
            elif isinstance(first_elem, np.ndarray):
                is_float = isinstance(first_elem, float) or (first_elem.dtype.kind == "f")

            # Convert floats to float16
            if is_float:
                if isinstance(first_elem, (np.ndarray, list)):
                    this_series.update(
                        pd.Series([np.asarray(elem).astype(np.float16) for elem in this_series])
                    )
                else:
                    this_series = this_series.astype(np.float16)
            all_series[col] = this_series
            gc.collect()  # Reset memory after each column

        # Merge columns into a dataframe
        df = pd.concat(all_series, axis=1)
        return df

    @staticmethod
    def _read_sdf(path: str, mol_col_name: str = "_rdkit_molecule_obj", **kwargs):
        r"""
        read a given sdf file into a pandas dataframe
        uses datamol.read_sdf to read the sdf file

        Parameters:
            path: path to the sdf file
            mol_col_name: name of the column containing the molecule object
            kwargs: arguments to pass to datamol.read_sdf
        Returns:
            pandas dataframe containing the molecules and their conformer coordinates from the sdf file

        Note: please change mol_col_name to be the column name corresponding to the molecule object in your sdf file
        """
        # Set default arguments for reading the SDF
        kwargs.setdefault("smiles_column", "smiles")
        kwargs.setdefault("sanitize", False)
        kwargs.setdefault("include_private", True)
        kwargs.setdefault("include_computed", True)
        kwargs.setdefault("remove_hs", True)
        kwargs.setdefault("n_jobs", -1)
        kwargs.setdefault("max_num_mols", kwargs.pop("sample_size", None))
        kwargs.setdefault("discard_invalid", False)

        # Get the interesting columns
        mol_cols = mol_col_name  # adjust this to be the column name in your sdf file
        kwargs.setdefault("mol_column", mol_cols)
        usecols = kwargs.pop("usecols", None)
        dtype = kwargs.pop("dtype", None)
        smiles_col = kwargs["smiles_column"]

        # Read the SDF
        df = dm.read_sdf(path, as_df=True, **kwargs)

        # Keep only the columns needed
        if usecols is not None:
            df = df[usecols + [mol_cols]]

        # Convert the dtypes. `DataFrame.astype` returns a new object; the
        # original code discarded the result, so `dtype` was silently ignored.
        if dtype is not None:
            label_columns = list(set(usecols) - set([mol_cols, smiles_col]))
            dtype_mapper = {col: dtype for col in label_columns}
            df = df.astype(dtype=dtype_mapper, copy=False)
        return df

    @staticmethod
    def _glob(path: str) -> List[str]:
        """
        glob a given path

        Parameters:
            path: path to glob
        Returns:
            List[str]: list of paths
        """
        files = dm.fs.glob(path)
        files = [f.replace("file://", "") for f in files]
        return files

    def _read_table(self, path: str, **kwargs) -> pd.DataFrame:
        """
        a general read file function which determines if which function to use, either _read_csv or _read_parquet

        Parameters:
            path: path to the file to read
            kwargs: keyword arguments for pd.read_csv or pd.read_parquet
        Returns:
            pd.DataFrame: the panda dataframe storing molecules
        """
        files = self._glob(path)
        if len(files) == 0:
            raise FileNotFoundError(f"No such file or directory `{path}`")
        if len(files) > 1:
            files = tqdm(sorted(files), desc=f"Reading files at `{path}`")
        dfs = []
        for file in files:
            file_type = self._get_data_file_type(file)
            if file_type == "parquet":
                df = self._read_parquet(file, **kwargs)
            elif file_type == "sdf":  # support compressed sdf files
                df = self._read_sdf(file, **kwargs)
            elif file_type in ["csv", "tsv"]:  # support compressed csv and tsv files
                df = self._read_csv(file, **kwargs)
            else:
                raise ValueError(f"unsupported file `{file}`")
            dfs.append(df)
        return pd.concat(dfs, ignore_index=True)

    def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **kwargs) -> Dict[str, Any]:
        """
        Get the options for the dataloader depending on the current stage.

        Parameters:
            stage: Whether in Training, Validating, Testing, Sanity-checking, Predicting, or Tuning phase.
            shuffle: set to ``True`` to have the data reshuffled at every epoch.
        Returns:
            Arguments to pass to the `DataLoader` during initialization
        """
        loader_kwargs = {}

        # Get batch size and IPU options for training set
        # if stage in [RunningStage.TRAINING, RunningStage.TUNING]:
        if stage in [RunningStage.TRAINING]:
            loader_kwargs["batch_size"] = self.batch_size_training
        # Get batch size and IPU options for validation / testing sets
        elif stage in [RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]:
            loader_kwargs["batch_size"] = self.batch_size_inference
        else:
            raise ValueError(f"Wrong value for `stage`. Provided `{stage}`")

        # Set default parameters
        loader_kwargs["shuffle"] = shuffle
        loader_kwargs["collate_fn"] = self.collate_fn
        loader_kwargs["num_workers"] = self.get_num_workers
        loader_kwargs["pin_memory"] = self.pin_memory
        loader_kwargs["persistent_workers"] = self.persistent_workers
        loader_kwargs["multiprocessing_context"] = self.multiprocessing_context

        # Update from provided parameters
        loader_kwargs.update(**kwargs)
        return loader_kwargs

    def get_dataloader(self, dataset: Dataset, shuffle: bool, stage: RunningStage) -> DataLoader:
        """
        Get the dataloader for a given dataset

        Parameters:
            dataset: The dataset from which to load the data
            shuffle: set to ``True`` to have the data reshuffled at every epoch.
            stage: Whether in Training, Validating, Testing, Sanity-checking, Predicting, or Tuning phase.
        Returns:
            The dataloader to sample from
        """
        kwargs = self.get_dataloader_kwargs(stage=stage, shuffle=shuffle)
        return self._dataloader(dataset=dataset, shuffle=shuffle, stage=stage, **kwargs)

    def _dataloader(self, dataset: Dataset, **kwargs) -> DataLoader:
        r"""
        Get a dataloader for a given dataset

        Parameters:
            dataset: The dataset from which to load the data
            kwargs: keyword arguments for DataLoader
        Returns:
            The dataloader to sample from
        """
        loader = DataLoader(
            dataset=dataset,
            **kwargs,
        )
        return loader

    def get_max_num_nodes_datamodule(self, stages: Optional[List[str]] = None) -> int:
        """
        Get the maximum number of nodes across all datasets from the datamodule

        Parameters:
            stages: The stages from which to extract the max num nodes.
                Possible values are ["train", "val", "test", "predict"].
                If None, all stages are considered.

        Returns:
            max_num_nodes: The maximum number of nodes across all datasets from the datamodule
        """
        allowed_stages = ["train", "val", "test", "predict"]
        if stages is None:
            stages = allowed_stages
        for stage in stages:
            assert stage in allowed_stages, f"stage value `{stage}` not allowed."

        max_num_nodes = 0
        # Max number of nodes in the training dataset
        # (None-check added for consistency with the other stages)
        if (
            (self.train_ds is not None)
            and ("train" in stages)
            and (self.train_ds.max_num_nodes_per_graph is not None)
        ):
            logger.info("Max num nodes being calculated train")
            max_num_nodes = max(max_num_nodes, self.train_ds.max_num_nodes_per_graph)

        # Max number of nodes in the validation dataset
        if (
            (self.val_ds is not None)
            and ("val" in stages)
            and (self.val_ds.max_num_nodes_per_graph is not None)
        ):
            logger.info("Max num nodes being calculated val")
            max_num_nodes = max(max_num_nodes, self.val_ds.max_num_nodes_per_graph)

        # Max number of nodes in the test dataset
        if (
            (self.test_ds is not None)
            and ("test" in stages)
            and (self.test_ds.max_num_nodes_per_graph is not None)
        ):
            logger.info("Max num nodes being calculated test")
            max_num_nodes = max(max_num_nodes, self.test_ds.max_num_nodes_per_graph)

        # Max number of nodes in the predict dataset
        if (
            (self.predict_ds is not None)
            and ("predict" in stages)
            and (self.predict_ds.max_num_nodes_per_graph is not None)
        ):
            max_num_nodes = max(max_num_nodes, self.predict_ds.max_num_nodes_per_graph)

        return max_num_nodes

    def get_max_num_edges_datamodule(self, stages: Optional[List[str]] = None) -> int:
        """
        Get the maximum number of edges across all datasets from the datamodule

        Parameters:
            stages: The stages from which to extract the max num nodes.
                Possible values are ["train", "val", "test", "predict"].
                If None, all stages are considered.

        Returns:
            max_num_edges: The maximum number of edges across all datasets from the datamodule
        """
        allowed_stages = ["train", "val", "test", "predict"]
        if stages is None:
            stages = allowed_stages
        for stage in stages:
            assert stage in allowed_stages, f"stage value `{stage}` not allowed."

        max_num_edges = 0
        # Max number of nodes/edges in the training dataset
        if (
            (self.train_ds is not None)
            and ("train" in stages)
            and (self.train_ds.max_num_edges_per_graph is not None)
        ):
            max_num_edges = max(max_num_edges, self.train_ds.max_num_edges_per_graph)

        # Max number of nodes/edges in the validation dataset
        if (
            (self.val_ds is not None)
            and ("val" in stages)
            and (self.val_ds.max_num_edges_per_graph is not None)
        ):
            max_num_edges = max(max_num_edges, self.val_ds.max_num_edges_per_graph)

        # Max number of nodes/edges in the test dataset
        if (
            (self.test_ds is not None)
            and ("test" in stages)
            and (self.test_ds.max_num_edges_per_graph is not None)
        ):
            max_num_edges = max(max_num_edges, self.test_ds.max_num_edges_per_graph)

        # Max number of nodes/edges in the predict dataset
        if (
            (self.predict_ds is not None)
            and ("predict" in stages)
            and (self.predict_ds.max_num_edges_per_graph is not None)
        ):
            max_num_edges = max(max_num_edges, self.predict_ds.max_num_edges_per_graph)

        return max_num_edges
@dataclass
class DatasetProcessingParams:
    # NOTE(review): the explicit `__init__` below takes precedence over the one
    # `@dataclass` would generate. The decorator is kept for backward
    # compatibility (it still contributes the field-less `__repr__`/`__eq__`).
    def __init__(
        self,
        task_level: Optional[str] = None,
        df: Optional[pd.DataFrame] = None,
        df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]] = None,
        smiles_col: Optional[str] = None,
        label_cols: Optional[List[str]] = None,
        weights_col: Optional[str] = None,  # Not needed
        weights_type: Optional[str] = None,  # Not needed
        idx_col: Optional[str] = None,
        mol_ids_col: Optional[str] = None,
        sample_size: Union[int, float, Type[None]] = None,
        split_val: float = 0.2,
        split_test: float = 0.2,
        seed: Optional[int] = None,
        epoch_sampling_fraction: float = 1.0,
        splits_path: Optional[Union[str, os.PathLike]] = None,
        split_names: Optional[List[str]] = None,
        label_normalization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
    ):
        """
        object to store the parameters for the dataset processing

        Parameters:
            task_level: The task level, wether it is graph, node, edge or nodepair
            df: The dataframe containing the data
            df_path: The path to the dataframe containing the data. If list, will read all files, sort them alphabetically and concatenate them.
            smiles_col: The column name of the smiles
            label_cols: The column names of the labels
            weights_col: The column name of the weights
            weights_type: The type of weights
            idx_col: The column name of the indices
            mol_ids_col: The column name of the molecule ids
            sample_size: The size of the sample
            split_val: The fraction of the data to use for validation
            split_test: The fraction of the data to use for testing
            seed: The seed to use for the splits and subsampling
            epoch_sampling_fraction: Fraction of the dataset to sample each epoch, in (0, 1]
            splits_path: The path to the splits, or a dictionary with the splits
            split_names: Names of the splits; defaults to ["train", "val", "test"]
            label_normalization: Parameters for the label normalization

        Raises:
            ValueError: if neither `df` nor `df_path` is provided, or if
                `epoch_sampling_fraction` is outside (0, 1].
        """
        if df is None and df_path is None:
            raise ValueError("Either `df` or `df_path` must be provided")
        if epoch_sampling_fraction <= 0 or epoch_sampling_fraction > 1:
            raise ValueError("The value of epoch_sampling_fraction must be in the range of (0, 1].")
        self.df = df
        self.task_level = task_level
        self.df_path = df_path
        self.smiles_col = smiles_col
        self.label_cols = label_cols
        self.weights_col = weights_col
        self.weights_type = weights_type
        self.idx_col = idx_col
        self.mol_ids_col = mol_ids_col
        self.sample_size = sample_size
        self.split_val = split_val
        self.split_test = split_test
        self.seed = seed
        self.splits_path = splits_path
        # Use a None sentinel instead of a mutable default list, so each instance
        # gets its own list (the previous shared default could be mutated globally)
        self.split_names = split_names if split_names is not None else ["train", "val", "test"]
        self.label_normalization = label_normalization
        self.epoch_sampling_fraction = epoch_sampling_fraction
class IPUDataModuleModifier:
    def __init__(
        self,
        ipu_inference_opts: Optional["poptorch.Options"] = None,
        ipu_training_opts: Optional["poptorch.Options"] = None,
        ipu_dataloader_training_opts: Optional["IPUDataloaderOptions"] = None,
        ipu_dataloader_inference_opts: Optional["IPUDataloaderOptions"] = None,
        *args,
        **kwargs,
    ) -> None:
        r"""
        Mixin that adds IPU support (poptorch options and dataloaders) to a
        `DataModule`. Meant for dual inheritance, for example:
        ```
        IPUDataModule(BaseDataModule, IPUDataModuleModifier):
            def __init__(self, **kwargs):
                BaseDataModule.__init__(self, **kwargs)
                IPUDataModuleModifier.__init__(self, **kwargs)
        ```

        Parameters:
            ipu_inference_opts: Options for the IPU in inference mode. Ignore if not using IPUs
            ipu_training_opts: Options for the IPU in training mode. Ignore if not using IPUs
            ipu_dataloader_training_opts: Options for the training dataloader on IPU. Ignore if not using IPUs
            ipu_dataloader_inference_opts: Options for the inference dataloader on IPU. Ignore if not using IPUs
            args: Arguments for the `DataModule`
            kwargs: Keyword arguments for the `DataModule`
        """
        self.ipu_inference_opts = ipu_inference_opts
        self.ipu_training_opts = ipu_training_opts
        self.ipu_dataloader_training_opts = ipu_dataloader_training_opts
        self.ipu_dataloader_inference_opts = ipu_dataloader_inference_opts

    def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoader":
        """
        Build a poptorch dataloader for the given dataset.

        Parameters:
            dataset: The dataset to use
            kwargs: Keyword arguments for the dataloader; must contain a
                non-None `ipu_options` entry
        Returns:
            The poptorch dataloader
        Raises:
            ValueError: when `ipu_options` is missing or None
        """
        # `dict.get` covers both the missing-key and the explicit-None cases
        if kwargs.get("ipu_options") is None:
            raise ValueError(f"No IPU options provided.")

        # Imported lazily so non-IPU setups never touch poptorch
        from graphium.ipu.ipu_dataloader import create_ipu_dataloader

        return create_ipu_dataloader(dataset=dataset, **kwargs)
class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier):
    def __init__(
        self,
        task_specific_args: Union[Dict[str, DatasetProcessingParams], Dict[str, Any]],
        processed_graph_data_path: Optional[Union[str, os.PathLike]] = None,
        dataloading_from: str = "ram",
        featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
        batch_size_training: int = 16,
        batch_size_inference: int = 16,
        batch_size_per_pack: Optional[int] = None,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[str] = None,
        featurization_n_jobs: int = -1,
        featurization_progress: bool = False,
        featurization_backend: str = "loky",
        featurization_batch_size: int = 1000,
        collate_fn: Optional[Callable] = None,
        prepare_dict_or_graph: str = "pyg:graph",
        **kwargs,
    ):
        """
        only for parameters beginning with task_*, we have a dictionary where the key is the task name
        and the value is specified below.

        Parameters:
            task_specific_args: A dictionary where the key is the task name (for the multi-task setting), and
                the value is a `DatasetProcessingParams` object. The `DatasetProcessingParams` object
                contains multiple parameters to define how to load and process the files, such as:
                - `task_level`
                - `df`
                - `df_path`
                - `smiles_col`
                - `label_cols`
            processed_graph_data_path: path where the pre-processed graphs are cached
            dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data
                must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data
                will be loaded in RAM and the `processed_graph_data_path` will be ignored.
            featurization: args to apply to the SMILES to Graph featurizer.
            batch_size_training: batch size for training and val dataset.
            batch_size_inference: batch size for test dataset.
            num_workers: Number of workers for the dataloader. Use -1 to use all available
                cores.
            pin_memory: Whether to pin on paginated CPU memory for the dataloader.
            featurization_n_jobs: Number of cores to use for the featurization.
            featurization_progress: whether to show a progress bar during featurization.
            featurization_backend: The backend to use for the molecular featurization.
                - "multiprocessing": Found to cause less memory issues.
                - "loky": joblib's Default. Found to cause memory leaks.
                - "threading": Found to be slow.
            featurization_batch_size: Batch size to use for the featurization.
            collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn`
            prepare_dict_or_graph: Whether to preprocess all molecules as Graph dict or PyG graphs.
                Possible options:
                - "pyg:dict": Process molecules as a `dict`. It's faster and requires less RAM during
                  pre-processing. It is slower during training with with `num_workers=0` since
                  pyg `Data` will be created during data-loading, but faster with large
                  `num_workers`, and less likely to cause memory issues with the parallelization.
                - "pyg:graph": Process molecules as `pyg.data.Data`.
        """
        BaseDataModule.__init__(
            self,
            batch_size_training=batch_size_training,
            batch_size_inference=batch_size_inference,
            batch_size_per_pack=batch_size_per_pack,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            multiprocessing_context=multiprocessing_context,
            collate_fn=collate_fn,
        )
        # Remaining kwargs are forwarded to the IPU mixin (e.g. poptorch options)
        IPUDataModuleModifier.__init__(self, **kwargs)
        self.task_specific_args = task_specific_args
        # Normalize per-task args and key them by "<task_level>_<task>"
        self.task_dataset_processing_params = {}
        for task, ds_args in task_specific_args.items():
            if not isinstance(ds_args, DatasetProcessingParams):
                # This is needed as long as not all classes have been migrated
                # to use the new `DatasetProcessingParams` class
                ds_args = DatasetProcessingParams(**ds_args)
            key = self._get_task_key(ds_args.task_level, task)
            self.task_dataset_processing_params[key] = ds_args
        # Per-task fraction of the dataset sampled each epoch (used by the sub-sampler)
        self.sampler_task_dict = {
            task: self.task_dataset_processing_params[task].epoch_sampling_fraction
            for task in self.task_dataset_processing_params.keys()
        }
        self.featurization_n_jobs = featurization_n_jobs
        self.featurization_progress = featurization_progress
        self.featurization_backend = featurization_backend
        self.featurization_batch_size = featurization_batch_size
        # Split indices and datasets, populated later during prepare/setup
        self.task_train_indices = None
        self.task_val_indices = None
        self.task_test_indices = None
        self.single_task_datasets = None
        self.train_singletask_datasets = None
        self.val_singletask_datasets = None
        self.test_singletask_datasets = None
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None
        # Validates and stores `processed_graph_data_path` / `dataloading_from`
        self._parse_caching_args(processed_graph_data_path, dataloading_from)
        # Per-task label normalizations, filled during data preparation
        self.task_norms = {}
        if featurization is None:
            featurization = {}
        self.featurization = featurization
        # Whether to transform the smiles into a pyg `Data` graph or a dictionary compatible with pyg
        if prepare_dict_or_graph == "pyg:dict":
            self.smiles_transformer = partial(mol_to_graph_dict, **featurization)
        elif prepare_dict_or_graph == "pyg:graph":
            self.smiles_transformer = partial(mol_to_pyggraph, **featurization)
        else:
            raise ValueError(
                f"`prepare_dict_or_graph` should be either 'pyg:dict' or 'pyg:graph', Provided: `{prepare_dict_or_graph}`"
            )
        # Hash of the data configuration, used to key the on-disk cache
        self.data_hash = self.get_data_hash()
        # If a complete cache already exists on disk, skip preparation
        if self.processed_graph_data_path is not None:
            if self._ready_to_load_all_from_file():
                self._data_is_prepared = True
                self._data_is_cached = True
def _parse_caching_args(self, processed_graph_data_path, dataloading_from):
"""
Parse the caching arguments, and raise errors if the arguments are invalid.
"""
# Whether to load the data from RAM or from disk
dataloading_from = dataloading_from.lower()
if dataloading_from not in ["disk", "ram"]:
raise ValueError(
f"`dataloading_from` should be either 'disk' or 'ram', Provided: `{dataloading_from}`"
)
# If loading from disk, the path to the cached data must be provided
if dataloading_from == "disk" and processed_graph_data_path is None:
raise ValueError(
"When `dataloading_from` is 'disk', `processed_graph_data_path` must be provided."
)
self.processed_graph_data_path = processed_graph_data_path
self.dataloading_from = dataloading_from
def _get_task_key(self, task_level: str, task: str):
task_prefix = f"{task_level}_"
if not task.startswith(task_prefix):
task = task_prefix + task
return task
def get_task_levels(self):
    """
    Return a mapping from each task name to its task level
    (e.g. 'graph', 'node', 'edge'), read from `self.task_specific_args`.
    """
    task_level_map = {}
    for task, task_args in self.task_specific_args.items():
        # Entries may be `DatasetProcessingParams` objects or plain dicts;
        # read attributes through the instance dict for uniform access.
        if isinstance(task_args, DatasetProcessingParams):
            task_args = vars(task_args)
        task_level_map[task] = task_args["task_level"]
    return task_level_map
def prepare_data(self, save_smiles_and_ids: bool = False):
    """Called only from a single process in distributed settings. Steps:

    - If each cache is set and exists, reload from cache and return. Otherwise,
    - For each single-task dataset:
        - Load its dataframe from a path (if provided)
        - Subsample the dataframe
        - Extract the smiles, labels from the dataframe
    - In the previous step, we were also able to get the unique smiles, which we use to compute the features
    - For each single-task dataframe and associated data (smiles, labels, etc.):
        - Filter out the data corresponding to molecules which failed featurization.
        - Create a corresponding SingletaskDataset
        - Split the SingletaskDataset according to the task-specific splits for train, val and test

    Parameters:
        save_smiles_and_ids: Whether SMILES strings and unique mol IDs are also
            persisted when caching the featurized data to disk.
    """

    def has_atoms_after_h_removal(smiles):
        # Returns True when the SMILES still contains atom letters after stripping all 'H' characters.
        # NOTE(review): defined but never called in this method — the "Filtering"
        # log messages below do not actually invoke it.
        # Remove all 'H' characters from the SMILES
        smiles_without_h = re.sub("H", "", smiles)
        # Check if any letters are remaining in the modified string
        has_atoms = bool(re.search("[a-zA-Z]", smiles_without_h))
        if has_atoms == False:
            logger.info(f"Removed Hydrogen molecule: {smiles}")
        return has_atoms

    # Fast path: the cache from a previous run is complete; only reload the
    # label statistics and return.
    if self._data_is_prepared:
        logger.info("Data is already prepared.")
        self.get_label_statistics(self.processed_graph_data_path, self.data_hash, dataset=None)
        return

    """Load all single-task dataframes."""
    task_df = {}
    for task, args in self.task_dataset_processing_params.items():
        if args.label_normalization is None:
            args.label_normalization = {}
        label_normalization = LabelNormalization(**args.label_normalization)
        logger.info(f"Reading data for task '{task}'")
        if args.df is None:
            # Only load the useful columns, as some datasets can be very large when loading all columns.
            label_cols = self._parse_label_cols(
                df=None, df_path=args.df_path, label_cols=args.label_cols, smiles_col=args.smiles_col
            )
            usecols = (
                check_arg_iterator(args.smiles_col, enforce_type=list)
                + label_cols
                + check_arg_iterator(args.idx_col, enforce_type=list)
                + check_arg_iterator(args.weights_col, enforce_type=list)
            )
            # Labels are read as float32 to reduce the memory footprint.
            label_dtype = {col: np.float32 for col in label_cols}
            task_df[task] = self._read_table(args.df_path, usecols=usecols, dtype=label_dtype)

        else:
            # A dataframe was provided in memory; only resolve the label columns.
            label_cols = self._parse_label_cols(
                df=args.df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col
            )
            task_df[task] = args.df
        task_df[task] = task_df[task]  # NOTE(review): no-op self-assignment, kept as-is
        args.label_cols = label_cols
        self.task_norms[task] = label_normalization
    logger.info("Done reading datasets")

    """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras)."""
    task_dataset_args = {}
    for task in task_df.keys():
        task_dataset_args[task] = {}

    for task, df in task_df.items():
        # Subsample all the dataframes
        sample_size = self.task_dataset_processing_params[task].sample_size
        df = self._sub_sample_df(df, sample_size, self.task_dataset_processing_params[task].seed)

        logger.info(f"Prepare single-task dataset for task '{task}' with {len(df)} data points.")

        logger.info("Filtering the molecules for Hydrogen")
        logger.info(f"Looking at column {df.columns[0]}")
        logger.info("Filtering done")
        # Extract smiles, labels, extras
        args = self.task_dataset_processing_params[task]
        smiles, labels, sample_idx, extras = self._extract_smiles_labels(
            df,
            task_level=args.task_level,
            smiles_col=args.smiles_col,
            label_cols=args.label_cols,
            idx_col=args.idx_col,
            weights_col=args.weights_col,
            weights_type=args.weights_type,
        )

        # Store the relevant information for each task's dataset
        task_dataset_args[task]["smiles"] = smiles
        task_dataset_args[task]["labels"] = labels
        task_dataset_args[task]["sample_idx"] = sample_idx
        task_dataset_args[task]["extras"] = extras

    """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found."""
    all_smiles = []
    all_tasks = []
    idx_per_task = {}
    total_len = 0
    for task, dataset_args in task_dataset_args.items():
        all_smiles.extend(dataset_args["smiles"])
        num_smiles = len(dataset_args["smiles"])
        # Remember the [start, end) range of this task within the concatenated SMILES list
        idx_per_task[task] = (total_len, total_len + num_smiles)
        total_len += num_smiles
        for count in range(len(dataset_args["smiles"])):
            all_tasks.append(task)
    # Get all unique mol ids
    all_unique_mol_ids = smiles_to_unique_mol_ids(
        all_smiles,
        n_jobs=self.featurization_n_jobs,
        featurization_batch_size=self.featurization_batch_size,
        backend=self.featurization_backend,
    )
    # Deduplicate so each molecule is featurized exactly once, even when it
    # appears in several tasks; `unique_ids_inv` maps back to the originals.
    _, unique_ids_idx, unique_ids_inv = np.unique(
        all_unique_mol_ids, return_index=True, return_inverse=True
    )

    smiles_to_featurize = [all_smiles[ii] for ii in unique_ids_idx]

    # Convert SMILES to features
    features, _ = self._featurize_molecules(smiles_to_featurize)

    # Store the features (including Nones, which will be filtered in the next step)
    for task in task_dataset_args.keys():
        task_dataset_args[task]["features"] = []
        task_dataset_args[task]["idx_none"] = []
    # Create a list of features matching up with the original smiles
    all_features = [features[unique_idx] for unique_idx in unique_ids_inv]

    # Add the features to the task-specific data
    for all_idx, task in enumerate(all_tasks):
        task_dataset_args[task]["features"].append(all_features[all_idx])

    """Filter data based on molecules which failed featurization. Create single task datasets as well."""
    self.single_task_datasets = {}
    for task, args in task_dataset_args.items():
        # Find out which molecule failed featurization, and filter them out
        idx_none = []
        for idx, (feat, labels, smiles) in enumerate(
            zip(args["features"], args["labels"], args["smiles"])
        ):
            if did_featurization_fail(feat) or found_size_mismatch(task, feat, labels, smiles):
                idx_none.append(idx)
        this_unique_ids = all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]]
        df, features, smiles, labels, sample_idx, extras, this_unique_ids = self._filter_none_molecules(
            idx_none,
            task_df[task],
            args["features"],
            args["smiles"],
            args["labels"],
            args["sample_idx"],
            args["extras"],
            this_unique_ids,
        )

        task_dataset_args[task]["smiles"] = smiles
        task_dataset_args[task]["labels"] = labels
        task_dataset_args[task]["features"] = features
        task_dataset_args[task]["sample_idx"] = sample_idx
        task_dataset_args[task]["extras"] = extras

        # We have the necessary components to create single-task datasets.
        self.single_task_datasets[task] = Datasets.SingleTaskDataset(
            features=task_dataset_args[task]["features"],
            labels=task_dataset_args[task]["labels"],
            smiles=task_dataset_args[task]["smiles"],
            unique_ids=this_unique_ids,
            indices=task_dataset_args[task]["sample_idx"],
            **task_dataset_args[task]["extras"],
        )

    """We split the data up to create train, val and test datasets"""
    self.task_train_indices = {}
    self.task_val_indices = {}
    self.task_test_indices = {}

    for task, df in task_df.items():
        train_indices, val_indices, test_indices = self._get_split_indices(
            len(df),
            split_val=self.task_dataset_processing_params[task].split_val,
            split_test=self.task_dataset_processing_params[task].split_test,
            split_seed=self.task_dataset_processing_params[task].seed,
            splits_path=self.task_dataset_processing_params[task].splits_path,
            split_names=self.task_dataset_processing_params[task].split_names,
            sample_idx=task_dataset_args[task]["sample_idx"],
        )
        self.task_train_indices[task] = train_indices
        self.task_val_indices[task] = val_indices
        self.task_test_indices[task] = test_indices

    (
        self.train_singletask_datasets,
        self.val_singletask_datasets,
        self.test_singletask_datasets,
    ) = self.get_subsets_of_datasets(
        self.single_task_datasets, self.task_train_indices, self.task_val_indices, self.task_test_indices
    )

    # Persist the featurized datasets when a cache directory was configured.
    if self.processed_graph_data_path is not None:
        self._save_data_to_files(save_smiles_and_ids)
        self._data_is_cached = True

    self._data_is_prepared = True
def setup(
    self,
    stage: Optional[str] = None,
    save_smiles_and_ids: bool = False,
):
    """
    Prepare the torch dataset. Called on every GPUs. Setting state here is ok.

    Parameters:
        stage: Either 'fit', 'test', or None (None prepares both fit and test datasets).
        save_smiles_and_ids: Whether the multitask datasets also keep SMILES strings and unique IDs.
    """

    # Can possibly get rid of setup because a single dataset will have molecules exclusively in train, val or test
    # Produce the label sizes to update the collate function
    labels_size = {}
    labels_dtype = {}
    if stage == "fit" or stage is None:
        # Build the train/val datasets only once; later calls reuse them.
        if self.train_ds is None:
            self.train_ds = self._make_multitask_dataset(
                self.dataloading_from, "train", save_smiles_and_ids=save_smiles_and_ids
            )
        if self.val_ds is None:
            self.val_ds = self._make_multitask_dataset(
                self.dataloading_from, "val", save_smiles_and_ids=save_smiles_and_ids
            )

        logger.info(self.train_ds)
        logger.info(self.val_ds)
        labels_size.update(
            self.train_ds.labels_size
        )  # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements.
        labels_size.update(self.val_ds.labels_size)
        labels_dtype.update(self.train_ds.labels_dtype)
        labels_dtype.update(self.val_ds.labels_dtype)

    if stage == "test" or stage is None:
        if self.test_ds is None:
            self.test_ds = self._make_multitask_dataset(
                self.dataloading_from, "test", save_smiles_and_ids=save_smiles_and_ids
            )

        logger.info(self.test_ds)

        labels_size.update(self.test_ds.labels_size)
        labels_dtype.update(self.test_ds.labels_dtype)

    # Inject label sizes/dtypes into the collate function only when the caller
    # did not already provide them via the partial's keywords.
    default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None)

    if default_labels_size_dict is None:
        self.collate_fn.keywords["labels_size_dict"] = labels_size

    default_labels_dtype_dict = self.collate_fn.keywords.get("labels_dtype_dict", None)

    if default_labels_dtype_dict is None:
        self.collate_fn.keywords["labels_dtype_dict"] = labels_dtype
def _make_multitask_dataset(
    self,
    dataloading_from: Literal["disk", "ram"],
    stage: Literal["train", "val", "test"],
    save_smiles_and_ids: bool,
) -> Datasets.MultitaskDataset:
    """
    Create a MultitaskDataset for the given stage using single task datasets
    The single task datasets must exist before this can be used

    Parameters:
        dataloading_from: Whether the dataset reads featurized data from 'disk' or 'ram'
        stage: Stage to create multitask dataset for
        save_smiles_and_ids: Whether to save SMILES strings and unique IDs

    Returns:
        The multitask dataset for the requested stage, with labels normalized
        unless the data was loaded from an already-normalized cache.
    """

    allowed_stages = ["train", "val", "test"]
    assert stage in allowed_stages, f"Multitask dataset stage `{stage}` not in {allowed_stages}"

    if stage == "train":
        singletask_datasets = self.train_singletask_datasets
        about = "training set"
    elif stage == "val":
        singletask_datasets = self.val_singletask_datasets
        about = "validation set"
    elif stage == "test":
        singletask_datasets = self.test_singletask_datasets
        about = "test set"
    else:
        # Unreachable given the assert above; kept as a defensive guard.
        raise ValueError(f"Unknown stage {stage}")

    processed_graph_data_path = self.processed_graph_data_path

    multitask_dataset = Datasets.MultitaskDataset(
        singletask_datasets,
        n_jobs=self.featurization_n_jobs,
        backend=self.featurization_backend,
        featurization_batch_size=self.featurization_batch_size,
        progress=self.featurization_progress,
        about=about,
        save_smiles_and_ids=save_smiles_and_ids,
        data_path=self._path_to_load_from_file(stage) if processed_graph_data_path else None,
        dataloading_from=dataloading_from,
        data_is_cached=self._data_is_cached,
    )  # type: ignore

    # calculate statistics for the train split and used for all splits normalization
    if stage == "train":
        self.get_label_statistics(
            self.processed_graph_data_path, self.data_hash, multitask_dataset, train=True
        )
    # Normalization has already been applied in cached data
    if not self._data_is_prepared:
        self.normalize_label(multitask_dataset, stage)

    return multitask_dataset
def _ready_to_load_all_from_file(self) -> bool:
"""
Check if the data for all stages is ready to be loaded from files
"""
paths = [self._path_to_load_from_file(stage) for stage in ["train", "val", "test"]]
ready = all(self._data_ready_at_path(path) for path in paths)
return ready
def _path_to_load_from_file(self, stage: Literal["train", "val", "test"]) -> Optional[str]:
"""
Get path from which to load the data from files
"""
if self.processed_graph_data_path is None:
return None
return osp.join(self.processed_graph_data_path, f"{stage}_{self.data_hash}")
def _data_ready_at_path(self, path: str) -> bool:
"""
Check if data can be loaded from this path
"""
can_load_from_file = osp.exists(path) and self.get_folder_size(path) > 0
return can_load_from_file
def _save_data_to_files(self, save_smiles_and_ids: bool = False) -> None:
    """
    Save data to files so that they can be loaded from file during training/validation/test

    Parameters:
        save_smiles_and_ids: Whether SMILES strings and unique IDs are saved
            alongside features and labels.
    """
    stages = ["train", "val", "test"]

    # At the moment, we need to merge the `SingleTaskDataset`'s into `MultitaskDataset`s in order to save to file
    # This is because the combined labels need to be stored together. We can investigate not doing this if this is a problem
    temp_datasets = {
        stage: self._make_multitask_dataset(
            dataloading_from="ram", stage=stage, save_smiles_and_ids=save_smiles_and_ids
        )
        for stage in stages
    }
    for stage in stages:
        self.save_featurized_data(temp_datasets[stage], self._path_to_load_from_file(stage))
        temp_datasets[stage].save_metadata(self._path_to_load_from_file(stage))
    # self.train_ds, self.val_ds, self.test_ds will be created during `setup()`

    # When loading from disk the in-RAM copies are redundant and freed here;
    # otherwise keep them so `setup()` does not rebuild the datasets.
    if self.dataloading_from == "disk":
        del temp_datasets
    else:
        self.train_ds = temp_datasets["train"]
        self.val_ds = temp_datasets["val"]
        self.test_ds = temp_datasets["test"]
def get_folder_size(self, path):
    """Return the combined size in bytes of all entries directly inside `path`."""
    # check if the data items are actually saved into the folders
    total = 0
    for entry in os.listdir(path):
        total += os.path.getsize(osp.join(path, entry))
    return total
def calculate_statistics(self, dataset: Datasets.MultitaskDataset, train: bool = False):
    """
    Calculate the statistics of the labels for each task, and overwrites the `self.task_norms` attribute.

    Parameters:
        dataset: the dataset to calculate the statistics from
        train: whether the dataset is the training set
    """

    # Statistics are only computed when normalizations are configured AND this
    # is the training split; other splits reuse the training statistics.
    if self.task_norms and train:
        for task in dataset.labels_size.keys():
            # if the label type is graph_*, we need to stack them as the tensor shape is (num_labels, )
            if task.startswith("graph"):
                labels = np.stack(
                    np.array([datum["labels"][task] for datum in dataset if task in datum["labels"]]),
                    axis=0,
                )
            # for other tasks with node_ and edge_, the label shape is [num_nodes/num_edges, num_labels]
            # we can concatenate them directly
            else:
                labels = np.concatenate(
                    [datum["labels"][task] for datum in dataset if task in datum["labels"]], axis=0
                )

            # Delegate per-task statistics (in-place) to the normalization object
            self.task_norms[task].calculate_statistics(labels)
def get_label_statistics(
    self,
    data_path: Union[str, os.PathLike],
    data_hash: str,
    dataset: Datasets.MultitaskDataset,
    train: bool = False,
):
    """
    Get the label statistics from the dataset, and save them to file, if needed.
    `self.task_norms` will be modified in-place with the label statistics.

    Parameters:
        data_path: the path to save and load the label statistics to. If None, no saving and loading will be done.
        data_hash: the hash of the dataset generated by `get_data_hash()`
        dataset: the dataset to calculate the statistics from
        train: whether the dataset is the training set
    """
    if data_path is None:
        # No caching requested: always recompute in memory.
        self.calculate_statistics(dataset, train=train)
    else:
        path_with_hash = os.path.join(data_path, data_hash)
        os.makedirs(path_with_hash, exist_ok=True)
        filename = os.path.join(path_with_hash, "task_norms.pkl")
        # Compute and persist the statistics only for the training split when
        # normalizations are configured and no cached file exists yet.
        if self.task_norms and train and not os.path.isfile(filename):
            self.calculate_statistics(dataset, train=train)
            torch.save(self.task_norms, filename, pickle_protocol=4)
        # if any of the above three condition does not satisfy, we load from file.
        else:
            self.task_norms = torch.load(filename)
def normalize_label(self, dataset: Datasets.MultitaskDataset, stage: str) -> Datasets.MultitaskDataset:
    """
    Normalize the labels in the dataset using the statistics in `self.task_norms`.

    Parameters:
        dataset: the dataset to normalize the labels from
        stage: one of 'train', 'val' or 'test'

    Returns:
        the dataset with normalized labels (labels are modified in-place)
    """
    for task in dataset.labels_size.keys():
        # we normalize the dataset if (it is train split) or (it is val/test splits and normalize_val_test is set to true)
        if (stage == "train") or (stage in ["val", "test"] and self.task_norms[task].normalize_val_test):
            for i in range(len(dataset)):
                if task in dataset[i]["labels"]:
                    dataset[i]["labels"][task] = self.task_norms[task].normalize(
                        dataset[i]["labels"][task]
                    )
    return dataset
def save_featurized_data(self, dataset: Datasets.MultitaskDataset, processed_data_path):
    """
    Write each datum of `dataset` to `processed_data_path` as an individual
    pickle file, grouped into sub-folders of 1000 items (named 0000, 0001, ...).

    Parameters:
        dataset: the featurized dataset to persist
        processed_data_path: destination folder; must not already exist
    """
    # NOTE(review): no `exist_ok=True` here, so this raises if the folder
    # already exists — presumably a guard against clobbering a cache; confirm.
    os.makedirs(processed_data_path)  # In case the len(dataset) is 0
    # Pre-create one sub-folder per 1000 items
    for i in range(0, len(dataset), 1000):
        os.makedirs(os.path.join(processed_data_path, format(i // 1000, "04d")), exist_ok=True)
    process_params = [(index, datum, processed_data_path) for index, datum in enumerate(dataset)]

    # Check if "about" is in the Dataset object
    about = ""
    if hasattr(dataset, "about"):
        about = dataset.about
    for param in tqdm(process_params, desc=f"Saving featurized data {about}"):
        self.process_func(param)
    return
def process_func(self, param):
    """
    Save a single featurized datum to disk as a pickle file.

    Parameters:
        param: a tuple `(index, datum, folder)`. The file is written to
            `<folder>/<index // 1000, 4 digits>/<index, 7 digits>.pkl` and holds
            the datum's features and labels.
    """
    index, datum, folder = param
    subfolder = format(index // 1000, "04d")
    fname = format(index, "07d") + ".pkl"
    target = os.path.join(folder, subfolder, fname)
    payload = {"graph_with_features": datum["features"], "labels": datum["labels"]}
    torch.save(payload, target, pickle_protocol=4)
    return
def get_dataloader_kwargs(self, stage: RunningStage, shuffle: bool, **kwargs) -> Dict[str, Any]:
    """
    Get the options for the dataloader depending on the current stage.

    Parameters:
        stage: Whether in Training, Validating, Testing, Sanity-checking, Predicting, or Tuning phase.
        shuffle: set to ``True`` to have the data reshuffled at every epoch.

    Returns:
        Arguments to pass to the `DataLoader` during initialization

    Raises:
        ValueError: when `stage` is none of the recognized running stages.
    """
    loader_kwargs = super().get_dataloader_kwargs(stage=stage, shuffle=shuffle, **kwargs)

    # Get batch size and IPU options for training set
    # if stage in [RunningStage.TRAINING, RunningStage.TUNING]:
    if stage in [RunningStage.TRAINING]:
        loader_kwargs["ipu_dataloader_options"] = self.ipu_dataloader_training_opts
        loader_kwargs["ipu_options"] = self.ipu_training_opts
    # Get batch size and IPU options for validation / testing sets
    elif stage in [RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]:
        loader_kwargs["ipu_dataloader_options"] = self.ipu_dataloader_inference_opts
        loader_kwargs["ipu_options"] = self.ipu_inference_opts
    else:
        raise ValueError(f"Wrong value for `stage`. Provided `{stage}`")

    # Remove the IPU options if not available (e.g. running on CPU/GPU);
    # warn when dataloader options were given but will have no effect.
    if loader_kwargs["ipu_options"] is None:
        loader_kwargs.pop("ipu_options")
        if loader_kwargs["ipu_dataloader_options"] is not None:
            logger.warning(
                "`ipu_dataloader_options` will be ignored since it is provided without `ipu_options`."
            )
        loader_kwargs.pop("ipu_dataloader_options")
    return loader_kwargs
def get_dataloader(
    self, dataset: Dataset, shuffle: bool, stage: RunningStage
) -> Union[DataLoader, "poptorch.DataLoader"]:
    """
    Get the poptorch dataloader for a given dataset

    Parameters:
        dataset: The dataset from which to load the data
        shuffle: set to ``True`` to have the data reshuffled at every epoch.
        stage: Whether in Training, Validating, Testing, Sanity-checking, Predicting, or Tuning phase.

    Returns:
        The poptorch dataloader to sample from
    """
    kwargs = self.get_dataloader_kwargs(stage=stage, shuffle=shuffle)
    sampler = None
    # use sampler only when sampler_task_dict is set in the config and during training
    if DatasetSubSampler.check_sampling_required(self.sampler_task_dict) and stage in [
        RunningStage.TRAINING
    ]:
        sampler = DatasetSubSampler(
            dataset, self.sampler_task_dict, self.processed_graph_data_path, self.data_hash
        )
        # turn shuffle off when sampler is used as sampler option is mutually exclusive with shuffle
        kwargs["shuffle"] = False
    # Dispatch to the IPU dataloader when IPU options were attached by
    # `get_dataloader_kwargs`; otherwise use the base torch DataLoader.
    is_ipu = ("ipu_options" in kwargs.keys()) and (kwargs.get("ipu_options") is not None)
    if is_ipu:
        loader = IPUDataModuleModifier._dataloader(self, dataset=dataset, sampler=sampler, **kwargs)
    else:
        loader = BaseDataModule._dataloader(self, dataset=dataset, sampler=sampler, **kwargs)

    return loader
def get_collate_fn(self, collate_fn):
    """
    Return the collate function to use for batching; when `collate_fn` is None,
    build the default `graphium_collate_fn` partial.
    """
    if collate_fn is not None:
        return collate_fn

    # Some values become `inf` when changing data type. `mask_nan` deals with that
    default_fn = partial(
        graphium_collate_fn,
        mask_nan=0,
        do_not_collate_keys=["smiles", "mol_ids"],
        batch_size_per_pack=self.batch_size_per_pack,
    )
    default_fn.__name__ = graphium_collate_fn.__name__
    return default_fn
# Cannot be used as is for the multitask version, because sample_idx does not apply.
def _featurize_molecules(self, smiles: Iterable[str]) -> Tuple[List, List]:
    """
    Precompute the features (graphs, fingerprints, etc.) from the SMILES.
    Features are computed from `self.smiles_transformer`.
    A warning is issued to mention which molecules failed featurization.

    Note:
        (hadim): in case of very large dataset we could:
        - or cache the data and read from it during `next(iter(dataloader))`
        - or compute the features on-the-fly during `next(iter(dataloader))`
        For now we compute in advance and hold everything in memory.

    Parameters:
        smiles: A list of all the molecular SMILES to featurize

    Returns:
        features: A list of all the featurized molecules
        idx_none: A list of the indexes that failed featurization
    """
    batch_size = BatchingSmilesTransform.parse_batch_size(
        numel=len(smiles),
        desired_batch_size=self.featurization_batch_size,
        n_jobs=self.featurization_n_jobs,
    )

    # Loop all the smiles and compute the features
    features = dm.parallelized_with_batches(
        BatchingSmilesTransform(self.smiles_transformer),
        smiles,
        batch_size=batch_size,
        progress=True,
        n_jobs=self.featurization_n_jobs,
        backend=self.featurization_backend,
        tqdm_kwargs={"desc": f"featurizing_smiles, batch={batch_size}"},
    )

    # Warn about None molecules
    idx_none = [ii for ii, feat in enumerate(features) if did_featurization_fail(feat)]
    if len(idx_none) > 0:
        # BUGFIX: was `[:-200]`, which dropped the LAST 200 characters of the
        # error message (showing nothing at all for short messages). The intent,
        # matching the label, is to truncate to the FIRST 200 characters.
        mols_to_msg = [
            f"idx={idx} - smiles={smiles[idx]} - Error_msg[:200]=\n{str(features[idx])[:200]}"
            for idx in idx_none
        ]
        msg = "\n".join(mols_to_msg)
        logger.warning(
            (f"{len(idx_none)} molecules will be removed since they failed featurization:\n" + msg)
        )

    return features, idx_none
@staticmethod
def _filter_none_molecules(
idx_none: Iterable,
*args: Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]],
) -> List[Union[pd.DataFrame, pd.Series, np.ndarray, torch.Tensor, list, tuple, Dict[Any, Iterable]]]:
"""
Filter the molecules, labels, etc. for the molecules that failed featurization.
Parameters:
idx_none: A list of the indexes that failed featurization
args: Any argument from which to filter the failed SMILES.
Can be a `list`, `tuple`, `Tensor`, `np.array`, `Dict`, `pd.DataFrame`, `pd.Series`.
Otherwise, it is not filtered.
WARNING: If a `pd.DataFrame` or `pd.Series` is passed, it filters by the row indexes,
NOT by the `DataFrame.index` or `Series.index`! Be careful!
Returns:
out: All the `args` with the indexes from `idx_none` removed.
"""
if len(idx_none) == 0:
return args
idx_none = np.asarray(idx_none)
out = []
for arg in args:
if isinstance(arg, pd.DataFrame):
new = arg.drop(arg.index[idx_none], axis=0)
elif isinstance(arg, pd.Series):
new = arg.drop(arg.index[idx_none], axis=0)
elif isinstance(arg, np.ndarray):
new = np.delete(arg, idx_none, axis=0)
elif isinstance(arg, torch.Tensor):
not_none = torch.ones(arg.shape[0], dtype=bool)
not_none[idx_none] = False
new = arg[not_none]
elif isinstance(arg, (list, tuple)):
arg = list(arg)
new = [elem for ii, elem in enumerate(arg) if ii not in idx_none]
elif isinstance(arg, dict):
new = {}
for key, val in arg.items():
new[key] = MultitaskFromSmilesDataModule._filter_none_molecules(idx_none, val) # Careful
else:
new = arg
out.append(new)
out = tuple(out) if len(out) > 1 else out[0]
return out
def _parse_label_cols(
    self,
    df: pd.DataFrame,
    df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]],
    label_cols: Union[Type[None], str, List[str]],
    smiles_col: str,
) -> List[str]:
    r"""
    Parse the choice of label columns depending on the type of input.
    The input parameters `label_cols` and `smiles_col` are described in
    the `__init__` method.

    Parameters:
        df: The dataframe containing the labels.
        df_path: The path to the dataframe containing the labels. If list, the first file is used.
        label_cols: The columns to use as labels.
        smiles_col: The column to use as SMILES

    Returns:
        the parsed label columns
    """
    if df is not None:
        cols = list(df.columns)
    else:
        # No dataframe in memory: read the column names from the file(s)
        # and make sure all files agree on them.
        files = self._glob(df_path)
        if len(files) == 0:
            raise FileNotFoundError(f"No such file or directory `{df_path}`")

        cols = BaseDataModule._get_table_columns(files[0])
        for file in files[1:]:
            _cols = BaseDataModule._get_table_columns(file)
            if set(cols) != set(_cols):
                raise RuntimeError(
                    f"Multiple data files have different columns. \nColumn set 1: {cols}\nColumn set 2: {_cols}"
                )

    # A star `*` at the beginning or end of the string specifies to look for all
    # columns that starts/end with a specific string
    if label_cols is None:
        label_cols = [col for col in cols if col != smiles_col]
    elif isinstance(label_cols, str):
        if label_cols[0] == "*":
            suffix = label_cols[1:]
            label_cols = [col for col in cols if str(col).endswith(suffix)]
        elif label_cols[-1] == "*":
            prefix = label_cols[:-1]
            label_cols = [col for col in cols if str(col).startswith(prefix)]
        else:
            label_cols = [label_cols]

    return check_arg_iterator(label_cols, enforce_type=list)
@property
def is_prepared(self):
    """Whether `prepare_data` has populated the `dataset` attribute."""
    return getattr(self, "dataset", None) is not None
@property
def is_setup(self):
    """
    Whether `setup` has created at least one of the train/test datasets.

    BUGFIX: the previous implementation called `getattr(self, "train_ds")`
    without a default, which raised `AttributeError` when only `test_ds`
    existed as an attribute. Using `getattr` with a default handles every
    combination of missing/None/set attributes.
    """
    train_ds = getattr(self, "train_ds", None)
    test_ds = getattr(self, "test_ds", None)
    return (train_ds is not None) or (test_ds is not None)
@property
def num_node_feats(self):
    """Return the number of node features in the first graph"""
    fake_graph = self.get_fake_graph()
    # Second dimension of the node-feature matrix is the feature count
    return fake_graph.feat.shape[1]
@property
def in_dims(self):
    """
    Return all input dimensions for the set of graphs.
    Including node/edge features, and
    raw positional encoding dimensions such eigval, eigvec, rwse and more

    Returns:
        A dict mapping each array-valued key of a featurized fake graph to
        the size of its last dimension.
    """

    graph = self.get_fake_graph()
    if isinstance(graph, (GraphDict)):
        graph = graph.data

    # get list of all keys corresponding to positional encoding
    pe_dim_dict = {}
    g_keys = get_keys(graph)
    # ignore the normal keys for node feat and edge feat etc.
    for key in g_keys:
        prop = graph.get(key, None)
        # Only keys holding array-like values (with a `.shape`) contribute a dimension
        if hasattr(prop, "shape"):
            pe_dim_dict[key] = prop.shape[-1]
    return pe_dim_dict
@property
def num_edge_feats(self):
    """Return the number of edge features in the first graph"""
    fake_graph = self.get_fake_graph()
    # Fall back to an empty tensor (last dim 0) when no edge features exist
    edge_feat = fake_graph.get("edge_feat", torch.Tensor([]))
    return edge_feat.shape[-1]
def get_fake_graph(self):
    """
    Low memory footprint method to get the featurization of a fake graph
    without reading the dataset. Useful for getting the number of node/edge features.

    Returns:
        graph: A fake graph (benzene) with the right featurization
    """
    benzene = "C1=CC=CC=C1"
    # Copy the transformer so the defaults below don't leak into `self`
    transformer = deepcopy(self.smiles_transformer)
    # Make failures loud and NaN handling deterministic for this probe molecule
    transformer.keywords.setdefault("on_error", "raise")
    transformer.keywords.setdefault("mask_nan", 0.0)
    return transformer(benzene)
########################## Private methods ######################################
def _save_to_cache(self):
    # Legacy caching hook — intentionally unsupported; disk caching is done
    # elsewhere via `_save_data_to_files`.
    raise NotImplementedError()
def _load_from_cache(self):
    # Legacy caching hook — intentionally unsupported; cached data is loaded
    # through the `processed_graph_data_path` mechanism instead.
    raise NotImplementedError()
def _extract_smiles_labels(
    self,
    df: pd.DataFrame,
    task_level: str,
    smiles_col: Optional[str] = None,
    label_cols: List[str] = [],
    idx_col: Optional[str] = None,
    mol_ids_col: Optional[str] = None,
    weights_col: Optional[str] = None,
    weights_type: Optional[str] = None,
) -> Tuple[
    np.ndarray, np.ndarray, Union[Type[None], np.ndarray], Dict[str, Union[Type[None], np.ndarray]]
]:
    """
    For a given dataframe extract the SMILES and labels columns. Smiles is returned as a list
    of string while labels are returned as a 2D numpy array.

    Parameters:
        df: Pandas dataframe
        task_level: Label granularity, one of 'graph', 'node', 'edge' or 'nodepair'
        smiles_col: Name of the column containing the SMILES. If None, a single
            column whose name contains 'smile' (case-insensitive) is auto-detected.
        label_cols: List of column names containing the labels
        idx_col: Name of the column containing the index
        mol_ids_col: Name of the column containing the molecule ids
        weights_col: Name of the column containing the weights
        weights_type: Type of weights to use ('sample_label_balanced' or 'sample_balanced');
            only valid for binary labels and when `weights_col` is not given.

    Returns:
        smiles, labels, sample_idx, extras

    Raises:
        ValueError: when no (or multiple) SMILES columns can be auto-detected,
            when `task_level` or `weights_type` is unknown, or when
            `weights_type` is used with non-binary labels.
    """

    if smiles_col is None:  # Should we specify which dataset has caused the potential issue?
        smiles_col_all = [col for col in df.columns if "smile" in str(col).lower()]
        if len(smiles_col_all) == 0:
            raise ValueError(f"No SMILES column found in dataframe. Columns are {df.columns}")
        elif len(smiles_col_all) > 1:
            raise ValueError(
                f"Multiple SMILES column found in dataframe. SMILES Columns are {smiles_col_all}"
            )
        smiles_col = smiles_col_all[0]

    if label_cols is None:
        # Default: every column except the SMILES column is a label
        label_cols = df.columns.drop(smiles_col)

    label_cols = check_arg_iterator(label_cols, enforce_type=list)

    smiles = df[smiles_col].values
    if len(label_cols) > 0:
        # Dispatch label extraction on the task granularity
        if task_level == "graph":
            labels = extract_labels(df, "graph", label_cols)

        elif task_level == "node":
            labels = extract_labels(df, "node", label_cols)

        elif task_level == "edge":
            labels = extract_labels(df, "edge", label_cols)

        elif task_level == "nodepair":
            labels = extract_labels(df, "nodepair", label_cols)

        else:
            raise ValueError(f"Unknown task level: {task_level}")
    else:
        # No label columns: produce an all-NaN array of shape [num_smiles, 0]
        labels = float("nan") + np.zeros([len(smiles), 0])

    # Get the indices, used for sub-sampling and splitting the dataset
    if idx_col is not None:
        df = df.set_index(idx_col)
    sample_idx = df.index.values

    # Get the molecule ids
    mol_ids = None
    if mol_ids_col is not None:
        mol_ids = df[mol_ids_col].values

    # Extract the weights
    weights = None
    if weights_col is not None:
        weights = df[weights_col].values
    elif weights_type is not None:
        # Derived weights require strictly binary labels
        if not np.all((labels == 0) | (labels == 1)):
            raise ValueError("Labels must be binary for `weights_type`")

        if weights_type == "sample_label_balanced":
            # Per-label balancing: negatives weighted by the positive ratio, positives by its inverse
            # NOTE(review): divides by the positive ratio — a column with no
            # positives (ratio 0) would divide by zero; confirm inputs.
            ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0]
            weights = np.zeros(labels.shape)
            weights[labels == 0] = ratio_pos_neg
            weights[labels == 1] = ratio_pos_neg**-1

        elif weights_type == "sample_balanced":
            # Same per-label weights, then multiplied across labels into one weight per sample
            ratio_pos_neg = np.sum(labels, axis=0, keepdims=1) / labels.shape[0]
            weights = np.zeros(labels.shape)
            weights[labels == 0] = ratio_pos_neg
            weights[labels == 1] = ratio_pos_neg**-1
            weights = np.prod(weights, axis=1)

        else:
            raise ValueError(f"Undefined `weights_type` {weights_type}")

        weights /= np.max(weights)  # Put the max weight to 1

    extras = {"weights": weights, "mol_ids": mol_ids}
    return smiles, labels, sample_idx, extras
def _get_split_indices(
self,
dataset_size: int,
split_val: float,
split_test: float,
sample_idx: Optional[Iterable[int]] = None,
split_seed: int = None,
splits_path: Union[str, os.PathLike, Dict[str, Iterable[int]]] = None,
split_names: Optional[List[str]] = ["train", "val", "test"],
):
r"""
Compute indices of random splits.
Parameters:
dataset_size: Size of the dataset
split_val: Fraction of the dataset to use for validation
split_test: Fraction of the dataset to use for testing
sample_idx: Indices of the samples to use for splitting
split_seed: Seed for the random splitting
splits_path: Path to a file containing the splits
Returns:
train_indices, val_indices, test_indices
"""
if sample_idx is None:
sample_idx = np.arange(dataset_size)
if splits_path is None:
# Random splitting
if split_test + split_val > 0:
train_indices, val_test_indices = train_test_split(
sample_idx,
test_size=split_val + split_test,
random_state=split_seed,
)
sub_split_test = split_test / (split_test + split_val)
else:
train_indices = sample_idx
val_test_indices = np.array([])
sub_split_test = 0
if split_test > 0:
val_indices, test_indices = train_test_split(
val_test_indices,
test_size=sub_split_test,
random_state=split_seed,
)
else:
val_indices = val_test_indices
test_indices = np.array([])
else:
train, val, test = split_names
if isinstance(splits_path, (Dict, pd.DataFrame)):
# Split from a dataframe
splits = splits_path
else:
# Split from an indices file
file_type = self._get_data_file_type(splits_path)
train, val, test = split_names
if file_type == "pt":
splits = torch.load(splits_path)
elif file_type in ["csv", "tsv"]:
with fsspec.open(str(splits_path)) as f:
splits = self._read_csv(splits_path)
else:
raise ValueError(
f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv"
)
train_indices = np.asarray(splits[train]).astype("int")
train_indices = train_indices[~np.isnan(train_indices)].tolist()
val_indices = np.asarray(splits[val]).astype("int")
val_indices = val_indices[~np.isnan(val_indices)].tolist()
test_indices = np.asarray(splits[test]).astype("int")
test_indices = test_indices[~np.isnan(test_indices)].tolist()
# Filter train, val and test indices
_, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True)
train_indices = train_idx.tolist()
_, valid_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True)
val_indices = valid_idx.tolist()
_, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True)
test_indices = test_idx.tolist()
return train_indices, val_indices, test_indices
def _sub_sample_df(
self, df: pd.DataFrame, sample_size: Union[int, float, None], seed: Optional[int] = None
) -> pd.DataFrame:
r"""
subsample from a pandas dataframe
Parameters:
df: pandas dataframe to subsample
sample_size: number of samples to subsample
Returns:
subsampled pandas dataframe
"""
# Sub-sample the dataframe
if isinstance(sample_size, int):
n = min(sample_size, df.shape[0])
df = df.sample(n=n, random_state=seed)
elif isinstance(sample_size, float):
df = df.sample(frac=sample_size, random_state=seed)
elif sample_size is None:
pass
else:
raise ValueError(
f"Wrong value for `sample_size`: {sample_size}"
) # Maybe specify which task it was for?
return df
def get_data_hash(self):
    """
    Get a hash specific to a dataset and smiles_transformer.
    Useful to cache the pre-processed data.
    """
    # Build a reduced copy of the task arguments so the hash is stable:
    # `epoch_sampling_fraction` is dropped so that changing it does not force
    # the cache to be regenerated, and any in-memory dataframe is truncated
    # to its first 5 rows.
    cleaned_args = {}
    for task_key, task_args in deepcopy(self.task_specific_args).items():
        if isinstance(task_args, DatasetProcessingParams):
            # Convert the class to a dictionary
            task_args = task_args.__dict__
        # Keep only first 5 rows of a dataframe
        if task_args.get("df") is not None:
            task_args["df"] = task_args["df"].iloc[:5]
        # Remove the `epoch_sampling_fraction`
        task_args.pop("epoch_sampling_fraction", None)
        cleaned_args[task_key] = task_args
    return get_md5_hash(
        {
            "smiles_transformer": self.smiles_transformer,
            "task_specific_args": cleaned_args,
        }
    )
def get_data_cache_fullname(self, compress: bool = False) -> str:
    """
    Create a hash for the dataset, and use it to generate a file name
    Parameters:
        compress: Whether to compress the data
    Returns:
        full path to the data cache file, or `None` when no
        `processed_graph_data_path` is configured
    """
    if self.processed_graph_data_path is None:
        return
    # `.gz` suffix signals the compressed variant of the cache file
    suffix = ".datacache.gz" if compress else ".datacache"
    return fs.join(self.processed_graph_data_path, self.data_hash + suffix)
def load_data_from_cache(self, verbose: bool = True, compress: bool = False) -> bool:
    """
    Load the datasets from cache. First create a hash for the dataset, and verify if that
    hash is available at the path given by `self.processed_graph_data_path`.
    Parameters:
        verbose: Whether to print the progress
        compress: Whether to compress the data
    Returns:
        cache_data_exists: Whether the cache exists (if the hash matches) and the loading succeeded
    """
    # The cache filename embeds `self.data_hash`; `None` means no cache dir configured.
    full_cache_data_path = self.get_data_cache_fullname(compress=compress)
    if full_cache_data_path is None:
        logger.info("No cache data path specified. Skipping loading the data from cache.")
        return False
    cache_data_exists = fs.exists(full_cache_data_path)
    if cache_data_exists:
        try:
            logger.info(f"Loading the data from cache at path `{full_cache_data_path}`")
            now = time.time()
            # compression="infer" lets fsspec pick the codec from the file extension
            with fsspec.open(full_cache_data_path, mode="rb", compression="infer") as file:
                load_params = torch.load(file)
                # Restore all cached attributes directly onto this datamodule instance
                self.__dict__.update(load_params)
                # Rebuild the per-task train/val/test subsets from the restored
                # datasets and split indices
                (
                    self.train_singletask_datasets,
                    self.val_singletask_datasets,
                    self.test_singletask_datasets,
                ) = self.get_subsets_of_datasets(
                    self.single_task_datasets,
                    self.task_train_indices,
                    self.task_val_indices,
                    self.task_test_indices,
                )
            elapsed = round(time.time() - now)
            logger.info(
                f"Successfully loaded the data from cache in {elapsed}s at path: `{full_cache_data_path}`"
            )
            return True
        except Exception as e:
            # A corrupt or incompatible cache is not fatal: fall back to
            # re-preparing the data from scratch.
            if verbose:
                logger.warning(
                    f"Data cache failed to load path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs."
                )
                logger.warning(e.__str__())
            return False
    else:
        if verbose:
            logger.info(
                f"Data cache not found at path: `{full_cache_data_path}`.\nThe data will be prepared and cache will be created for future runs."
            )
        return False
def get_subsets_of_datasets(
    self,
    single_task_datasets: Dict[str, Datasets.SingleTaskDataset],
    task_train_indices: Dict[str, Iterable],
    task_val_indices: Dict[str, Iterable],
    task_test_indices: Dict[str, Iterable],
) -> Tuple[Subset, Subset, Subset]:
    """
    From a dictionary of datasets and their associated indices, subset the train/val/test sets
    Parameters:
        single_task_datasets: Dictionary of datasets
        task_train_indices: Dictionary of train indices
        task_val_indices: Dictionary of val indices
        task_test_indices: Dictionary of test indices
    Returns:
        train_singletask_datasets: Dictionary of train subsets
        val_singletask_datasets: Dictionary of val subsets
        test_singletask_datasets: Dictionary of test subsets
    """
    train_subsets, val_subsets, test_subsets = {}, {}, {}
    # One Subset per task and per split, all views over the same task dataset
    for task in task_train_indices:
        dataset = single_task_datasets[task]
        train_subsets[task] = Subset(dataset, task_train_indices[task])
        val_subsets[task] = Subset(dataset, task_val_indices[task])
        test_subsets[task] = Subset(dataset, task_test_indices[task])
    return train_subsets, val_subsets, test_subsets
def __len__(self) -> int:
    r"""
    Returns the number of elements of the current DataModule, which is the combined size of all single-task datasets given.
    Returns:
        num_elements: Number of elements in the current DataModule
    """
    total = 0
    for task, args in self.task_dataset_processing_params.items():
        if args.df is not None:
            total += len(args.df)
        else:
            # Load only the SMILES column: counting rows does not require
            # reading the label columns from disk.
            table = self._read_table(args.df_path, usecols=[args.smiles_col])
            total += len(table)
    return total
def to_dict(self) -> Dict[str, Any]:
    """
    Returns a dictionary representation of the current DataModule
    Returns:
        obj_repr: Dictionary representation of the current DataModule
    """
    # TODO: Change to make more multi-task friendly
    def _size_or_none(indices):
        # Splits that were never computed are reported as None
        return len(indices) if indices is not None else None

    return {
        "name": self.__class__.__name__,
        "len": len(self),
        "train_size": _size_or_none(self.train_indices),
        "val_size": _size_or_none(self.val_indices),
        "test_size": _size_or_none(self.test_indices),
        "batch_size_training": self.batch_size_training,
        "batch_size_inference": self.batch_size_inference,
        "batch_size_per_pack": self.batch_size_per_pack,
        "num_node_feats": self.num_node_feats,
        "num_node_feats_with_positional_encoding": self.num_node_feats_with_positional_encoding,
        "num_edge_feats": self.num_edge_feats,
        "num_tasks": len(self.task_dataset_processing_params),
        "num_labels": len(self.label_cols),
        "collate_fn": self.collate_fn.__name__,
        "featurization": self.featurization,
    }
def __repr__(self) -> str:
    r"""
    Controls how the class is printed
    Returns:
        A YAML-formatted string of the dictionary produced by `self.to_dict()`
    """
    return omegaconf.OmegaConf.to_yaml(self.to_dict())
class GraphOGBDataModule(MultitaskFromSmilesDataModule):
    def __init__(
        self,
        task_specific_args: Dict[str, Union[DatasetProcessingParams, Dict[str, Any]]],
        processed_graph_data_path: Optional[Union[str, os.PathLike]] = None,
        dataloading_from: str = "ram",
        featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
        batch_size_training: int = 16,
        batch_size_inference: int = 16,
        batch_size_per_pack: Optional[int] = None,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[str] = None,
        featurization_n_jobs: int = -1,
        featurization_progress: bool = False,
        featurization_backend: str = "loky",
        collate_fn: Optional[Callable] = None,
        prepare_dict_or_graph: str = "pyg:graph",
        **kwargs,
    ):
        r"""
        Load an OGB (Open-graph-benchmark) GraphProp dataset.
        Parameters:
            task_specific_args: Arguments related to each task, with the task-name being the key,
                and the specific arguments being the values. The arguments must be a
                Dict containing the following keys:

                - "dataset_name": Name of the OGB dataset to load. Examples of possible datasets are
                  "ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv".
                - "sample_size": The number of molecules to sample from the dataset. Default=None,
                  meaning that all molecules will be considered.
            processed_graph_data_path: Path to the processed graph data. If None, the data will be
                downloaded from the OGB website.
            dataloading_from: Whether to load the data from RAM or disk. Default is "ram".
            featurization: args to apply to the SMILES to Graph featurizer.
            batch_size_training: batch size for training and val dataset.
            batch_size_inference: batch size for test dataset.
            num_workers: Number of workers for the dataloader. Use -1 to use all available
                cores.
            pin_memory: Whether to pin on paginated CPU memory for the dataloader.
            featurization_n_jobs: Number of cores to use for the featurization.
            featurization_progress: whether to show a progress bar during featurization.
            featurization_backend: The backend to use for the molecular featurization.

                - "multiprocessing": Found to cause less memory issues.
                - "loky": joblib's Default. Found to cause memory leaks.
                - "threading": Found to be slow.
            collate_fn: A custom torch collate function. Default is to `graphium.data.graphium_collate_fn`
            sample_size:

                - `int`: The maximum number of elements to take from the dataset.
                - `float`: Value between 0 and 1 representing the fraction of the dataset to consider
                - `None`: all elements are considered.
        """
        new_task_specific_args = {}
        self.metadata = {}
        for task_name, task_args in task_specific_args.items():
            # Get OGB metadata
            this_metadata = self._get_dataset_metadata(task_args["dataset_name"])
            # Get dataset
            df, mol_ids_col, smiles_col, label_cols, splits_path = self._load_dataset(
                this_metadata, sample_size=task_args.get("sample_size", None)
            )
            new_task_specific_args[task_name] = {
                "df": df,
                "mol_ids_col": mol_ids_col,
                "smiles_col": smiles_col,
                "label_cols": label_cols,
                "splits_path": splits_path,
                "task_level": task_args["task_level"],
            }
            self.metadata[task_name] = this_metadata

        # Config for datamodule
        dm_args = {}
        dm_args["task_specific_args"] = new_task_specific_args
        dm_args["processed_graph_data_path"] = processed_graph_data_path
        dm_args["dataloading_from"] = dataloading_from
        dm_args["featurization"] = featurization
        dm_args["batch_size_training"] = batch_size_training
        dm_args["batch_size_inference"] = batch_size_inference
        dm_args["batch_size_per_pack"] = batch_size_per_pack
        dm_args["num_workers"] = num_workers
        dm_args["pin_memory"] = pin_memory
        dm_args["featurization_n_jobs"] = featurization_n_jobs
        dm_args["featurization_progress"] = featurization_progress
        dm_args["featurization_backend"] = featurization_backend
        dm_args["persistent_workers"] = persistent_workers
        dm_args["multiprocessing_context"] = multiprocessing_context
        dm_args["collate_fn"] = collate_fn
        dm_args["prepare_dict_or_graph"] = prepare_dict_or_graph
        super().__init__(**dm_args, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        r"""
        Generate a dictionary representation of the class
        Returns:
            dict: dictionary representation of the class
        """
        # TODO: Change to make more multi-task friendly
        obj_repr = {}
        # One OGB dataset is loaded per task; report the names of all of them.
        # (A single `self.dataset_name` attribute is never set by `__init__`.)
        obj_repr["dataset_name"] = list(self.metadata.keys())
        obj_repr.update(super().to_dict())
        return obj_repr

    # Private methods
    def _load_dataset(
        self,
        metadata: dict,
        sample_size: Optional[Union[int, float]] = None,
    ) -> Tuple[pd.DataFrame, str, str, List[str], str]:
        """
        Download, extract and load an OGB dataset.
        Parameters:
            metadata: Metadata for the dataset to load.
            sample_size: The number of molecules (int) or fraction (float) to sample
                from the dataset. Default=None, meaning that all molecules will be considered.
        Returns:
            df: Pandas dataframe containing the dataset.
            mol_ids_col: Name of the column containing the molecule ids.
            smiles_col: Name of the column containing the SMILES.
            label_cols: List of column names containing the labels.
            splits_path: Path to the file containing the train/val/test splits.
        """
        base_dir = fs.get_cache_dir("ogb")
        dataset_dir = base_dir / metadata["download_name"]

        if not dataset_dir.exists():
            # Create cache filepath for zip file and associated folder
            dataset_path = base_dir / f"{metadata['download_name']}.zip"

            # Download it
            if not dataset_path.exists():
                logger.info(f"Downloading {metadata['url']} to {dataset_path}")
                fs.copy(metadata["url"], dataset_path, progress=True)

            # Extract: use a context manager so the zip file handle is closed
            with zipfile.ZipFile(dataset_path) as zf:
                zf.extractall(base_dir)

        # Load CSV file
        if metadata["download_name"].startswith("pcqm4m"):
            df_path = dataset_dir / "raw" / "data.csv.gz"
        else:
            df_path = dataset_dir / "mapping" / "mol.csv.gz"
        logger.info(f"Loading {df_path} in memory.")
        df = pd.read_csv(df_path)

        # Subsample the dataset
        df = self._sub_sample_df(df, sample_size)

        # Load split from the OGB dataset and save them in a single CSV file
        if metadata["download_name"].startswith("pcqm4m"):
            split_name = metadata["split"]
            split_dict = torch.load(dataset_dir / "split_dict.pt")
            train_split = pd.DataFrame(split_dict["train"])
            val_split = pd.DataFrame(split_dict["valid"])
            if "test" in split_dict.keys():
                test_split = pd.DataFrame(split_dict["test"])
            else:
                test_split = pd.DataFrame(split_dict["test-dev"])

            splits = pd.concat([train_split, val_split, test_split], axis=1)  # type: ignore
            splits.columns = ["train", "val", "test"]

            # Always save the splits under the `split` sub-directory. The previous
            # logic wrote the file directly into `dataset_dir` when the directory
            # was missing, so the cache location depended on prior state.
            splits_dir = dataset_dir / "split"
            os.makedirs(splits_dir, exist_ok=True)
            splits_path = splits_dir / f"{split_name}.csv.gz"
        else:
            split_name = metadata["split"]
            train_split = pd.read_csv(dataset_dir / "split" / split_name / "train.csv.gz", header=None)  # type: ignore
            val_split = pd.read_csv(dataset_dir / "split" / split_name / "valid.csv.gz", header=None)  # type: ignore
            test_split = pd.read_csv(dataset_dir / "split" / split_name / "test.csv.gz", header=None)  # type: ignore

            splits = pd.concat([train_split, val_split, test_split], axis=1)  # type: ignore
            splits.columns = ["train", "val", "test"]
            splits_path = dataset_dir / "split" / f"{split_name}.csv.gz"

        logger.info(f"Saving splits to {splits_path}")
        splits.to_csv(splits_path, index=None)

        # Get column names: OGB columns are predictable
        if metadata["download_name"].startswith("pcqm4m"):
            mol_ids_col = df.columns[0]
            smiles_col = df.columns[-2]
            label_cols = df.columns[-1:].to_list()
        else:
            mol_ids_col = df.columns[-1]
            smiles_col = df.columns[-2]
            label_cols = df.columns[:-2].to_list()

        return df, mol_ids_col, smiles_col, label_cols, splits_path

    def _get_dataset_metadata(self, dataset_name: str) -> Dict[str, Any]:
        """
        Look up the OGB metadata row for `dataset_name`.
        Raises:
            ValueError: if the name is not in the OGB metadata table.
        """
        ogb_metadata = self._get_ogb_metadata()
        if dataset_name not in ogb_metadata.index:
            raise ValueError(f"'{dataset_name}' is not a valid dataset name.")
        return ogb_metadata.loc[dataset_name].to_dict()

    def _get_ogb_metadata(self):
        """
        Get the metadata of OGB GraphProp datasets.
        """
        with importlib.resources.open_text("ogb.graphproppred", "master.csv") as f:
            ogb_metadata = pd.read_csv(f)
        ogb_metadata = ogb_metadata.set_index(ogb_metadata.columns[0])
        ogb_metadata = ogb_metadata.T

        # Add metadata related to PCQM4M
        ogb_metadata = pd.concat([ogb_metadata, pd.DataFrame(PCQM4M_meta, index=["ogbg-lsc-pcqm4m"])])
        ogb_metadata = pd.concat([ogb_metadata, pd.DataFrame(PCQM4Mv2_meta, index=["ogbg-lsc-pcqm4mv2"])])

        # Only keep datasets of type 'mol'
        ogb_metadata = ogb_metadata[ogb_metadata["data type"] == "mol"]

        return ogb_metadata
class ADMETBenchmarkDataModule(MultitaskFromSmilesDataModule):
    """
    Wrapper to use the ADMET benchmark group from the TDC (Therapeutics Data Commons).

    !!! warning "Dependency"
        This class requires [PyTDC](https://pypi.org/project/PyTDC/) to be installed.

    !!! note "Citation"
        Huang, K., Fu, T., Gao, W., Zhao, Y., Roohani, Y., Leskovec, J., Coley, C., Xiao, C., Sun, J., & Zitnik, M. (2021).
        Therapeutics Data Commons: Machine Learning Datasets and Tasks for Drug Discovery and Development.
        Proceedings of Neural Information Processing Systems, NeurIPS Datasets and Benchmarks.

    Parameters:
        tdc_benchmark_names: This can be any subset of the benchmark names that make up the ADMET benchmarking group.
            If `None`, uses the complete benchmarking group. For the full list of options, see
            [the TDC website](https://tdcommons.ai/benchmark/admet_group/overview/) or use:

            ```python
            from tdc.utils import retrieve_benchmark_names
            retrieve_benchmark_names("admet_group")
            ```
        tdc_train_val_seed: TDC recommends a default splitting method for the train-val split. This parameter
            is used to seed that splitting method.
    """

    def __init__(
        self,
        # TDC-specific
        tdc_benchmark_names: Optional[Union[str, List[str]]] = None,
        tdc_train_val_seed: int = 0,
        # Inherited arguments from superclass
        processed_graph_data_path: Optional[Union[str, Path]] = None,
        dataloading_from: str = "ram",
        featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
        batch_size_training: int = 16,
        batch_size_inference: int = 16,
        batch_size_per_pack: Optional[int] = None,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[str] = None,
        featurization_n_jobs: int = -1,
        featurization_progress: bool = False,
        featurization_backend: str = "loky",
        collate_fn: Optional[Callable] = None,
        prepare_dict_or_graph: str = "pyg:graph",
        **kwargs,
    ):
        try:
            from tdc.benchmark_group import admet_group
            from tdc.utils import retrieve_benchmark_names
        except ImportError as error:
            # To make sure we use the exact same train-test set and preprocessing as other benchmark entries,
            # we rely on the PyTDC library.
            raise RuntimeError(
                f"To use {self.__class__.__name__}, `PyTDC` needs to be installed. "
                f"Please install it with `pip install PyTDC`"
            ) from error

        # Pick a path to save the TDC data to
        tdc_cache_dir = fs.get_cache_dir("tdc")
        tdc_cache_dir = fs.join(tdc_cache_dir, "ADMET_Benchmark")
        fs.mkdir(tdc_cache_dir, exist_ok=True)

        # Create the benchmark group object
        # NOTE (cwognum): We redirect stderr and stdout to a file since TDC uses print statements,
        # which quickly pollute the logs. Ideally, we would use `redirect_stderr(None)`, but that breaks TQDM.
        with tempfile.TemporaryFile("w") as f:
            with redirect_stderr(f):
                with redirect_stdout(f):
                    self.group = admet_group(path=tdc_cache_dir)

        # By default, use all available benchmarks in a benchmark group
        if tdc_benchmark_names is None:
            tdc_benchmark_names = retrieve_benchmark_names("admet_group")
        if isinstance(tdc_benchmark_names, str):
            tdc_benchmark_names = [tdc_benchmark_names]

        # Create the task-specific arguments
        logger.info(
            f"Preparing the TDC ADMET Benchmark Group splits for each of the {len(tdc_benchmark_names)} benchmarks."
        )
        task_specific_args = {
            t: self._get_task_specific_arguments(t, tdc_train_val_seed, tdc_cache_dir)
            for t in tdc_benchmark_names
        }

        super().__init__(
            task_specific_args=task_specific_args,
            featurization=featurization,
            processed_graph_data_path=processed_graph_data_path,
            dataloading_from=dataloading_from,
            batch_size_training=batch_size_training,
            batch_size_inference=batch_size_inference,
            batch_size_per_pack=batch_size_per_pack,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            multiprocessing_context=multiprocessing_context,
            featurization_n_jobs=featurization_n_jobs,
            featurization_progress=featurization_progress,
            featurization_backend=featurization_backend,
            collate_fn=collate_fn,
            prepare_dict_or_graph=prepare_dict_or_graph,
            **kwargs,
        )

    def _get_task_specific_arguments(self, name: str, seed: int, cache_dir: str) -> DatasetProcessingParams:
        """
        Loads the data and split from the TDC benchmark group object.

        For the train-test split, this is fixed.
        For the train-val split, this does not have to be fixed. Here we use the default splitting method
        from TDC to do the train-val split and allow the seed to be changed. This is likely to best match
        other entries in the benchmarking group.
        """
        benchmark = self.group.get(name)

        # Get the default train-val-test split
        # NOTE (cwognum): TDC prints by default to stderr, which pollutes the logs quite a bit.
        #   This context manager mutes these by temporarily writing to a file.
        #   Ideally, we would use `redirect_stderr(None)`, but that breaks TQDM.
        with tempfile.TemporaryFile("w") as f:
            with redirect_stderr(f):
                train, val = self.group.get_train_valid_split(seed, name)
        test = benchmark["test"]

        # Convert to the Graphium format
        n_val = len(val)
        n_test = len(test)
        n_train = len(train)
        max_len = max(n_train, n_val, n_test)
        total_len = n_train + n_val + n_test

        data = pd.concat([train, val, test], ignore_index=True)

        # NOTE (cwognum): We need to convert the labels to float, since we use NaNs down the line.
        #   If you remove this line, collating the labels will raise an overflow by converting a NaN to the int dtype.
        data["Y"] = data["Y"].astype(float)

        # Pad every column with NaNs up to `max_len` so the DataFrame constructor
        # receives equal-length columns — previously only `val` and `test` were
        # padded, which fails whenever the train split is not the largest one.
        split = pd.DataFrame(
            {
                "train": list(range(n_train)) + [float("nan")] * (max_len - n_train),
                "val": list(range(n_train, n_train + n_val)) + [float("nan")] * (max_len - n_val),
                "test": list(range(n_train + n_val, total_len)) + [float("nan")] * (max_len - n_test),
            }
        )

        split_path = fs.join(cache_dir, f"{name}_split.csv")
        split.to_csv(split_path, index=False)

        return DatasetProcessingParams(
            df=data,
            idx_col=None,
            smiles_col="Drug",
            label_cols=["Y"],
            splits_path=split_path,
            split_names=["train", "val", "test"],
            task_level="graph",
        )
class FakeDataModule(MultitaskFromSmilesDataModule):
    """
    A fake datamodule that generates artificial data by mimicking the true data coming
    from the provided dataset.
    It is useful to test the speed and performance of the model on a dataset without
    having to featurize it and wait for the workers to load it.
    """

    def __init__(
        self,
        task_specific_args: Dict[str, Dict[str, Any]],  # TODO: Replace this with DatasetParams
        featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
        batch_size_training: int = 16,
        batch_size_inference: int = 16,
        num_workers: int = 0,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        multiprocessing_context: Optional[str] = None,
        collate_fn: Optional[Callable] = None,
        prepare_dict_or_graph: str = "pyg:graph",
        num_mols_to_generate: int = 1000000,
        indexing_single_elem: bool = True,
        **kwargs,
    ):
        super().__init__(
            task_specific_args=task_specific_args,
            featurization=featurization,
            batch_size_training=batch_size_training,
            batch_size_inference=batch_size_inference,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            multiprocessing_context=multiprocessing_context,
            collate_fn=collate_fn,
            prepare_dict_or_graph=prepare_dict_or_graph,
            **kwargs,
        )
        # Number of molecules the fake dataset pretends to contain
        self.num_mols_to_generate = num_mols_to_generate
        # Whether every index returns the same single element
        self.indexing_single_elem = indexing_single_elem

    def generate_data(self, label_cols: List[str], smiles_col: str):
        """
        Generate a fake dataframe with a single molecule and random labels.
        Parameters:
            label_cols: Names of the label columns to generate (filled with random floats)
            smiles_col: Name of the column to store the example molecule under;
                must be either "smiles" or "cxsmiles"
        Returns:
            pd.DataFrame
        """
        num_generated_mols = 1
        # Create a dummy generated dataset - single smiles string, duplicated N times
        example_molecules = dict(
            smiles="C1N2C3C4C5OC13C2C45",
            cxsmiles="[H]C1C2=C(NC(=O)[C@@]1([H])C1=C([H])C([H])=C(C([H])([H])[H])C([H])=C1[H])C([H])=C([H])N=C2[H] |(6.4528,-1.5789,-1.2859;5.789,-0.835,-0.8455;4.8499,-0.2104,-1.5946;3.9134,0.7241,-0.934;3.9796,1.1019,0.3172;5.0405,0.6404,1.1008;5.2985,1.1457,2.1772;5.9121,-0.5519,0.613;6.9467,-0.2303,0.8014;5.677,-1.7955,1.4745;4.7751,-2.7953,1.0929;4.2336,-2.7113,0.154;4.5521,-3.9001,1.914;3.8445,-4.6636,1.5979;5.215,-4.0391,3.1392;4.9919,-5.2514,4.0126;5.1819,-5.0262,5.0671;5.6619,-6.0746,3.7296;3.966,-5.6247,3.925;6.1051,-3.0257,3.52;6.6247,-3.101,4.4725;6.3372,-1.9217,2.7029;7.0168,-1.1395,3.0281;2.8586,1.2252,-1.7853;2.1303,1.9004,-1.3493;2.8118,0.8707,-3.0956;2.0282,1.2549,-3.7434;3.716,0.0207,-3.7371;4.6658,-0.476,-3.0127;5.3755,-1.1468,-3.5021)|",
        )
        example_df_entry = {smiles_col: example_molecules[smiles_col]}
        for label in label_cols:
            example_df_entry[label] = np.random.random()
        df = pd.DataFrame([example_df_entry])
        logger.info(f"Generating fake dataset on host... \n Generating {num_generated_mols} rows in the df.")
        df = pd.concat([df] * num_generated_mols, ignore_index=True)
        return df

    def prepare_data(self):
        """Called only from a single process in distributed settings. Steps:

        - If each cache is set and exists, reload from cache and return. Otherwise,
        - For each single-task dataset:
            - Load its dataframe from a path (if provided)
            - Subsample the dataframe
            - Extract the smiles, labels from the dataframe
        - In the previous step, we were also able to get the unique smiles, which we use to compute the features
        - For each single-task dataframe and associated data (smiles, labels, etc.):
            - Filter out the data corresponding to molecules which failed featurization.
            - Create a corresponding SingletaskDataset
            - Split the SingletaskDataset according to the task-specific splits for train, val and test
        """

        """Load all single-task dataframes."""
        if self.num_mols_to_generate is None:
            num_mols = 0

        task_df = {}
        for task, args in self.task_dataset_processing_params.items():
            logger.info(f"Reading data for task '{task}'")
            # Default to the user-provided label columns; only parse them from
            # the file when no dataframe was given. (Previously `label_cols`
            # was undefined when `args.df` was provided.)
            label_cols = args.label_cols
            if args.df is None:
                # Only load the useful columns, as some datasets can be very large when loading all columns.
                label_cols = self._parse_label_cols(
                    df=None, df_path=args.df_path, label_cols=args.label_cols, smiles_col=args.smiles_col
                )
            task_df[task] = self.generate_data(label_cols=args.label_cols, smiles_col=args.smiles_col)
            if self.num_mols_to_generate is None:
                num_mols = max(num_mols, len(task_df[task]))
            # Keep a single row per task: the fake dataset duplicates it on the fly
            task_df[task] = task_df[task].iloc[0:1]
            args.label_cols = label_cols
        if self.num_mols_to_generate is None:
            self.num_mols_to_generate = num_mols
        logger.info("Done reading datasets")

        """Subsample the data frames and extract the necessary data to create SingleTaskDatasets for each task (smiles, labels, extras)."""
        task_dataset_args = {}
        for task in task_df.keys():
            task_dataset_args[task] = {}

        for task, df in task_df.items():
            logger.info(f"Prepare single-task dataset for task '{task}' with {len(df)} data points.")
            # Extract smiles, labels, extras
            args = self.task_dataset_processing_params[task]
            smiles, labels, sample_idx, extras = self._extract_smiles_labels(
                df,
                task_level=args.task_level,
                smiles_col=args.smiles_col,
                label_cols=args.label_cols,
                idx_col=args.idx_col,
                weights_col=args.weights_col,
                weights_type=args.weights_type,
            )
            # Store the relevant information for each task's dataset
            task_dataset_args[task]["smiles"] = smiles
            task_dataset_args[task]["labels"] = labels
            task_dataset_args[task]["sample_idx"] = sample_idx
            task_dataset_args[task]["extras"] = extras

        """Convert SMILES to features (graphs, fingerprints, etc.) for the unique molecules found."""
        all_smiles = []
        idx_per_task = {}
        total_len = 0
        for task, dataset_args in task_dataset_args.items():
            all_smiles.extend(dataset_args["smiles"])
            num_smiles = len(dataset_args["smiles"])
            idx_per_task[task] = (total_len, total_len + num_smiles)
            total_len += num_smiles

        # Get all unique mol ids
        all_unique_mol_ids = smiles_to_unique_mol_ids(
            all_smiles,
            n_jobs=self.featurization_n_jobs,
            featurization_batch_size=self.featurization_batch_size,
            backend=self.featurization_backend,
        )

        # Convert SMILES to features, and give each task its own slice.
        # (Previously the full feature list was assigned to the last task only,
        # which raised a KeyError for every other task in a multi-task setup.)
        features, _ = self._featurize_molecules(all_smiles)
        for task, (start, end) in idx_per_task.items():
            task_dataset_args[task]["features"] = features[start:end]

        """Filter data based on molecules which failed featurization. Create single task datasets as well."""
        self.single_task_datasets = {}
        for task, args in task_dataset_args.items():
            self.single_task_datasets[task] = Datasets.SingleTaskDataset(
                features=task_dataset_args[task]["features"],
                labels=task_dataset_args[task]["labels"],
                smiles=task_dataset_args[task]["smiles"],
                indices=task_dataset_args[task]["sample_idx"],
                unique_ids=all_unique_mol_ids[idx_per_task[task][0] : idx_per_task[task][1]],
                **task_dataset_args[task]["extras"],
            )

        """We split the data up to create train, val and test datasets"""
        self.train_singletask_datasets = {}
        self.val_singletask_datasets = {}
        self.test_singletask_datasets = {}
        for task, df in task_df.items():
            # A single element per split: the FakeDataset repeats it as needed
            self.train_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0])
            self.val_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0])
            self.test_singletask_datasets[task] = Subset(self.single_task_datasets[task], [0])

    def setup(self, stage=None):
        # TODO
        """
        Prepare the torch dataset. Called on every GPUs. Setting state here is ok.
        Parameters:
            stage (str): Either 'fit', 'test', or None.
        """
        labels_size = {}

        if stage == "fit" or stage is None:
            self.train_ds = Datasets.FakeDataset(self.train_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem)  # type: ignore
            self.val_ds = Datasets.FakeDataset(self.val_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem)  # type: ignore
            print(self.train_ds)
            print(self.val_ds)

            labels_size.update(
                self.train_ds.labels_size
            )  # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements.
            labels_size.update(self.val_ds.labels_size)

        if stage == "test" or stage is None:
            self.test_ds = Datasets.FakeDataset(self.test_singletask_datasets, num_mols=self.num_mols_to_generate, indexing_same_elem=self.indexing_single_elem)  # type: ignore
            print(self.test_ds)

            labels_size.update(self.test_ds.labels_size)

        default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None)

        if default_labels_size_dict is None:
            self.collate_fn.keywords["labels_size_dict"] = labels_size

    def get_fake_graph(self):
        """
        Low memory footprint method to get the first datapoint graph.
        The first 20 rows of the data are read in case the first ones have a
        featurization error. If all 20 first elements fail, then `None` is
        returned, otherwise the first graph that does not fail is returned.
        """
        keys = list(self.task_dataset_processing_params.keys())
        task = keys[0]
        args = self.task_dataset_processing_params[task]
        if args.df is None:
            df = self._read_csv(args.df_path, nrows=20)
        else:
            df = args.df.iloc[0:20, :]

        label_cols = self._parse_label_cols(
            df, df_path=None, label_cols=args.label_cols, smiles_col=args.smiles_col
        )

        smiles, labels, sample_idx, extras = self._extract_smiles_labels(
            df,
            task_level=args.task_level,
            smiles_col=args.smiles_col,
            label_cols=label_cols,
            idx_col=args.idx_col,
            weights_col=args.weights_col,
            weights_type=args.weights_type,
        )

        graph = None
        for s in smiles:
            graph = self.smiles_transformer(s, mask_nan=0.0)
            # Check for `None` *before* touching the graph's attributes: a failed
            # featurization returns None, and the previous attribute access
            # raised an AttributeError instead of trying the next molecule.
            if (graph is not None) and (graph.num_edges > 0) and (graph.num_nodes > 0):
                break
        return graph
================================================
FILE: graphium/data/dataset.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import os
from copy import deepcopy
from functools import lru_cache
from multiprocessing import Manager
from typing import Any, Dict, List, Optional, Tuple, Union
import fsspec
import numpy as np
import torch
from datamol import parallelized, parallelized_with_batches
from loguru import logger
from torch.utils.data.dataloader import Dataset
from torch_geometric.data import Batch, Data
from graphium.data.smiles_transform import smiles_to_unique_mol_ids
from graphium.features import GraphDict
class SingleTaskDataset(Dataset):
    def __init__(
        self,
        labels: List[Union[torch.Tensor, np.ndarray]],
        features: Optional[List[Union[Data, "GraphDict"]]] = None,
        smiles: Optional[List[str]] = None,
        indices: Optional[List[int]] = None,
        weights: Optional[Union[torch.Tensor, np.ndarray]] = None,
        unique_ids: Optional[List[str]] = None,
        mol_ids: Optional[List[str]] = None,
    ):
        r"""
        Dataset holding the data of a single task.

        Parameters:
            labels: A list of labels for the given task (one per graph)
            features: A list of graphs
            smiles: A list of smiles
            indices: A list of indices
            weights: A list of weights
            unique_ids: A list of unique ids for each molecule generated from `datamol.unique_id`
            mol_ids: A list of ids coming from the original dataset. Useful to identify the molecule in the original dataset.
        """
        numel = len(labels)

        # Every optional list must line up 1:1 with `labels`.
        for label, to_check in (
            ("features", features),
            ("indices", indices),
            ("weights", weights),
            ("unique_ids", unique_ids),
            ("mol_ids", mol_ids),
        ):
            if to_check is not None and len(to_check) != numel:
                raise ValueError(
                    f"{label} must be the same length as `labels`, got {len(to_check)} and {numel}"
                )

        self.labels = labels
        # Wrap smiles in a Manager list to avoid memory leaks with `num_workers > 0`.
        self.smiles = Manager().list(smiles) if smiles is not None else None
        self.features = features
        # Same concern for indices: a numpy array avoids the leak.
        self.indices = np.array(indices) if indices is not None else None
        self.weights = weights
        self.unique_ids = unique_ids
        self.mol_ids = mol_ids

    def __len__(self):
        r"""
        Return the size of the dataset (one entry per label).
        """
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Fetch the datapoint at the given index.

        Parameters:
            idx: the index to get the data at
        Returns:
            datum: a dictionary with whichever of the keys "features", "labels",
            "smiles", "indices", "weights", "unique_ids", "mol_ids" are available.
        """
        datum = {}
        fields = (
            ("features", self.features),
            ("labels", self.labels),
            ("smiles", self.smiles),
            ("indices", self.indices),
            ("weights", self.weights),
            ("unique_ids", self.unique_ids),
            ("mol_ids", self.mol_ids),
        )
        for key, container in fields:
            if container is not None:
                datum[key] = container[idx]
        return datum

    def __getstate__(self):
        """Serialize the class for pickling (Manager proxies aren't picklable)."""
        return {
            "labels": self.labels,
            "smiles": list(self.smiles) if self.smiles is not None else None,
            "features": self.features,
            "indices": self.indices,
            "weights": self.weights,
            "unique_ids": self.unique_ids,
            "mol_ids": self.mol_ids,
        }

    def __setstate__(self, state: dict):
        """Reload the class from pickling."""
        if state["smiles"] is not None:
            # Restore the Manager-backed list to keep the memory-leak workaround.
            state["smiles"] = Manager().list(state["smiles"])
        self.__dict__.update(state)
class MultitaskDataset(Dataset):
    # NOTE: a stray `pass` statement used to sit here; it was dead code and has been removed.
    def __init__(
        self,
        datasets: Dict[str, SingleTaskDataset],
        n_jobs=-1,
        backend: str = "loky",
        featurization_batch_size=1000,
        progress: bool = True,
        save_smiles_and_ids: bool = False,
        about: str = "",
        data_path: Optional[Union[str, os.PathLike]] = None,
        dataloading_from: str = "ram",
        data_is_cached: bool = False,
    ):
        r"""
        This class holds the information for the multitask dataset.

        Several single-task datasets can be merged to create a multi-task dataset. After merging the dictionary of single-task datasets.
        we will have a multitask dataset of the following form:
        - self.mol_ids will be a list to contain the unique molecular IDs to identify the molecules
        - self.smiles will be a list to contain the corresponding smiles for that molecular ID across all single-task datasets
        - self.labels will be a list of dictionaries where the key is the task name and the value is the label(s) for that task.
            At this point, any particular molecule will only have entries for tasks for which it has a label. Later, in the collate
            function, we fill up the missing task labels with NaNs.
        - self.features will be a list of featurized graphs corresponding to that particular unique molecule.
            However, for testing purposes we may not require features so that we can make sure that this merge function works.

        Parameters:
            datasets: A dictionary of single-task datasets
            n_jobs: Number of jobs to run in parallel
            backend: Parallelization backend
            featurization_batch_size: The batch size to use for the parallelization of the featurization
            progress: Whether to display the progress bar
            save_smiles_and_ids: Whether to save the smiles and ids for the dataset. If `False`, `mol_ids` and `smiles` are set to `None`
            about: A description of the dataset
            data_path: The location of the data if saved on disk
            dataloading_from: Whether to load the data from `"disk"` or `"ram"`
            data_is_cached: Whether the data is already cached on `"disk"`
        """
        super().__init__()
        self.n_jobs = n_jobs
        self.backend = backend
        self.featurization_batch_size = featurization_batch_size
        self.progress = progress
        self.about = about
        self.save_smiles_and_ids = save_smiles_and_ids
        self.data_path = data_path
        self.dataloading_from = dataloading_from
        logger.info(f"Dataloading from {dataloading_from.upper()}")

        if data_is_cached:
            self._load_metadata()
            if dataloading_from == "disk":
                # Features/labels stay on disk and are fetched per-index in `__getitem__`.
                self.features = None
                self.labels = None
            elif dataloading_from == "ram":
                logger.info(f"Transferring {about} from DISK to RAM...")
                self.transfer_from_disk_to_ram()
        else:
            task = next(iter(datasets))
            self.features = None
            if (len(datasets[task]) > 0) and ("features" in datasets[task][0]):
                self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets)
            else:
                self.mol_ids, self.smiles, self.labels = self.merge(datasets)
            # Set mol_ids and smiles to None to save memory as they are not needed.
            if not save_smiles_and_ids:
                self.mol_ids = None
                self.smiles = None
            self.labels_size = self.set_label_size_dict(datasets)
            self.labels_dtype = self.set_label_dtype_dict(datasets)
            self.dataset_length = len(self.labels)
            # Node/edge counts are computed lazily via the corresponding properties.
            self._num_nodes_list = None
            self._num_edges_list = None
            if self.features is not None:
                self._num_nodes_list = get_num_nodes_per_graph(self.features)
                self._num_edges_list = get_num_edges_per_graph(self.features)
def transfer_from_disk_to_ram(self, parallel_with_batches: bool = False):
"""
Function parallelizing transfer from DISK to RAM
"""
def transfer_mol_from_disk_to_ram(idx):
"""
Function transferring single mol from DISK to RAM
"""
data_dict = self.load_graph_from_index(idx)
mol_in_ram = {
"features": data_dict["graph_with_features"],
"labels": data_dict["labels"],
}
return mol_in_ram
if parallel_with_batches and self.featurization_batch_size:
data_in_ram = parallelized_with_batches(
transfer_mol_from_disk_to_ram,
range(self.dataset_length),
batch_size=self.featurization_batch_size,
n_jobs=0,
backend=self.backend,
progress=self.progress,
tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
)
else:
data_in_ram = parallelized(
transfer_mol_from_disk_to_ram,
range(self.dataset_length),
n_jobs=0,
backend=self.backend,
progress=self.progress,
tqdm_kwargs={"desc": "Transfer from DISK to RAM"},
)
self.features = [sample["features"] for sample in data_in_ram]
self.labels = [sample["labels"] for sample in data_in_ram]
def save_metadata(self, directory: str):
"""
Save everything other than features/labels
"""
attrs_to_save = [
"mol_ids",
"smiles",
"labels_size",
"labels_dtype",
"dataset_length",
"_num_nodes_list",
"_num_edges_list",
]
attrs = {attr: getattr(self, attr) for attr in attrs_to_save}
path = os.path.join(directory, "multitask_metadata.pkl")
torch.save(attrs, path, pickle_protocol=4)
def _load_metadata(self):
"""
Load everything other than features/labels
"""
attrs_to_load = [
"mol_ids",
"smiles",
"labels_size",
"labels_dtype",
"dataset_length",
"_num_nodes_list",
"_num_edges_list",
]
path = os.path.join(self.data_path, "multitask_metadata.pkl")
with fsspec.open(path, "rb") as f:
attrs = torch.load(path)
if not set(attrs_to_load).issubset(set(attrs.keys())):
raise ValueError(
f"The metadata in the cache at {self.data_path} does not contain the right information. "
f"This may be because the cache was prepared using an earlier version of Graphium. "
f"You can try deleting the cache and running the data preparation again. "
f"\nMetadata keys found: {attrs.keys()}"
f"\nMetadata keys required: {attrs_to_load}"
)
for attr, value in attrs.items():
setattr(self, attr, value)
if self.save_smiles_and_ids:
if self.smiles is None or self.mol_ids is None:
logger.warning(
f"Argument `save_smiles_and_ids` is set to {self.save_smiles_and_ids} but metadata in the cache at {self.data_path} does not contain smiles and mol_ids. "
f"This may be because `Datamodule.prepare_data(save_smiles_and_ids=False)` was run followed by `Datamodule.setup(save_smiles_and_ids=True)`. "
f"When loading from cached files, the `save_smiles_and_ids` argument of `Datamodule.setup()` is superseeded by the `Datamodule.prepare_data()`. "
)
def __len__(self):
r"""
Returns the number of molecules
"""
return self.dataset_length
@property
def num_nodes_list(self):
"""
The number of nodes per graph
"""
if self._num_nodes_list is None:
if len(self) == 0:
self._num_nodes_list = []
else:
self._num_nodes_list = get_num_nodes_per_graph(self.features)
return self._num_nodes_list
@property
def num_edges_list(self):
"""
The number of edges per graph
"""
if self._num_edges_list is None:
if len(self) == 0:
self._num_edges_list = []
else:
self._num_edges_list = get_num_edges_per_graph(self.features)
return self._num_edges_list
@property
def num_graphs_total(self):
r"""
number of graphs (molecules) in the dataset
"""
return len(self)
@property
def num_nodes_total(self):
"""Total number of nodes for all graphs"""
if len(self) == 0:
return
return sum(self.num_nodes_list)
@property
def max_num_nodes_per_graph(self):
"""Maximum number of nodes per graph"""
if len(self) == 0:
return
return max(self.num_nodes_list)
@property
def std_num_nodes_per_graph(self):
"""Standard deviation of number of nodes per graph"""
if len(self) == 0:
return
return np.std(self.num_nodes_list)
@property
def min_num_nodes_per_graph(self):
"""Minimum number of nodes per graph"""
if len(self) == 0:
return
return min(self.num_nodes_list)
@property
def mean_num_nodes_per_graph(self):
"""Average number of nodes per graph"""
if len(self) == 0:
return
return self.num_nodes_total / self.num_graphs_total
@property
def num_edges_total(self):
"""Total number of edges for all graphs"""
if len(self) == 0:
return
return sum(self.num_edges_list)
@property
def max_num_edges_per_graph(self):
"""Maximum number of edges per graph"""
if len(self) == 0:
return
return max(self.num_edges_list)
@property
def min_num_edges_per_graph(self):
"""Minimum number of edges per graph"""
if len(self) == 0:
return
return min(self.num_edges_list)
@property
def std_num_edges_per_graph(self):
"""Standard deviation of number of nodes per graph"""
if len(self) == 0:
return
return np.std(self.num_edges_list)
@property
def mean_num_edges_per_graph(self):
"""Average number of edges per graph"""
if len(self) == 0:
return
return self.num_edges_total / self.num_graphs_total
def __getitem__(self, idx):
r"""
get the data for at the specified index
Parameters:
idx: The index of the data to retrieve
Returns:
A dictionary containing the data for the specified index with keys "mol_ids", "smiles", "labels", and "features"
"""
datum = {}
if self.dataloading_from == "disk":
data_dict = self.load_graph_from_index(idx)
datum["features"] = data_dict["graph_with_features"]
datum["labels"] = data_dict["labels"]
if "smiles" in data_dict.keys():
datum["smiles"] = data_dict["smiles"]
else:
if self.mol_ids is not None:
datum["mol_ids"] = self.mol_ids[idx]
if self.smiles is not None:
datum["smiles"] = self.smiles[idx]
if self.labels is not None:
datum["labels"] = self.labels[idx]
if self.features is not None:
datum["features"] = self.features[idx]
return datum
def load_graph_from_index(self, data_idx):
r"""
load the graph (in pickle file) from the disk
Parameters:
data_idx: The index of the data to retrieve
Returns:
A dictionary containing the data for the specified index with keys "graph_with_features", "labels" and "smiles" (optional).
"""
filename = os.path.join(
self.data_path, format(data_idx // 1000, "04d"), format(data_idx, "07d") + ".pkl"
)
with fsspec.open(filename, "rb") as f:
data_dict = torch.load(f)
return data_dict
    def merge(
        self, datasets: Dict[str, SingleTaskDataset]
    ) -> Tuple[List[str], List[str], List[Dict[str, Any]], List[Any]]:
        r"""This function merges several single task datasets into a multitask dataset.

        The idea: for each of the smiles, labels, features and tasks, we create a corresponding list that concatenates these items across all tasks.
        In particular, for any index, the elements in the smiles, labels, features and task lists at that index will correspond to each other (i.e. match up).
        Over this list of all smiles (which we created by concatenating the smiles across all tasks), we compute their molecular ID using functions from Datamol.
        Once again, we will have a list of molecular IDs which is the same size as the list of smiles, labels, features and tasks.
        We then use numpy's `unique` function to find the exact list of unique molecular IDs as these will identify the molecules in our dataset. We also get the
        inverse from numpy's `unique`, which will allow us to index in addition to the list of all molecular IDs, the list of all smiles, labels, features and tasks.
        Finally, we use this inverse to construct the list of list of smiles, list of label dictionaries (indexed by task) and the list of features such that
        the indices match up. This is what is needed for the `get_item` function to work.

        Parameters:
            datasets: A dictionary of single-task datasets

        Returns:
            A tuple of (list of molecular IDs, list of smiles, list of label dictionaries, list of features).
            The features entry is only present when at least one dataset provided features.
        """
        # Get all the smiles, labels, features and tasks.
        all_lists = self._get_all_lists_ids(datasets=datasets)
        # `inv[i]` maps flat index `i` to the index of its unique molecule.
        mol_ids, inv = self._get_inv_of_mol_ids(all_mol_ids=all_lists["mol_ids"])

        # Store the smiles: each unique molecule accumulates its smiles from every task.
        smiles = [[] for _ in range(len(mol_ids))]
        for all_idx, unique_idx in enumerate(inv):
            smiles[unique_idx].append(all_lists["smiles"][all_idx])

        # Store the labels: one `Data` object per unique molecule, keyed by task name.
        labels = [Data() for _ in range(len(mol_ids))]
        for all_idx, unique_idx in enumerate(inv):
            task: str = all_lists["tasks"][all_idx]
            label = all_lists["labels"][all_idx]
            labels[unique_idx][task] = label

            # When features exist for this flat index, attach placeholder node/edge
            # tensors so the label `Data` carries the graph's shape information.
            if all_idx < len(all_lists["features"]):
                features = all_lists["features"][all_idx]
                labels[unique_idx]["x"] = torch.empty(
                    (features.num_nodes, 1)
                )  # IPU is not happy with zero-sized tensors, so use shape (features.num_nodes, 1) here
                labels[unique_idx]["edge_index"] = torch.empty((2, features.num_edges))

        # Store the features
        if len(all_lists["features"]) > 0:
            # -1 is a sentinel; every slot is overwritten by the loop below.
            features = [-1 for i in range(len(mol_ids))]
            for all_idx, unique_idx in enumerate(inv):
                features[unique_idx] = all_lists["features"][all_idx]
            return mol_ids, smiles, labels, features
        else:
            return mol_ids, smiles, labels
def _get_all_lists_ids(self, datasets: Dict[str, SingleTaskDataset]) -> Dict[str, Any]:
all_smiles = []
all_features = []
all_labels = []
all_mol_ids = []
all_tasks = []
for task, ds in datasets.items():
if len(ds) == 0:
continue
# Get data from single task dataset
ds_smiles = [ds[i]["smiles"] for i in range(len(ds))]
ds_labels = [ds[i]["labels"] for i in range(len(ds))]
if "unique_ids" in ds[0].keys():
ds_mol_ids = [ds[i]["unique_ids"] for i in range(len(ds))]
else:
ds_mol_ids = smiles_to_unique_mol_ids(
ds_smiles,
n_jobs=self.n_jobs,
featurization_batch_size=self.featurization_batch_size,
backend=self.backend,
progress=self.progress,
progress_desc=f"{task}: mol to ids",
)
if "features" in ds[0]:
ds_features = [ds[i]["features"] for i in range(len(ds))]
else:
ds_features = None
all_smiles.extend(ds_smiles)
all_labels.extend(ds_labels)
all_mol_ids.extend(ds_mol_ids)
if ds_features is not None:
all_features.extend(ds_features)
task_list = [task] * ds.__len__()
all_tasks.extend(task_list)
all_lists = {
"smiles": all_smiles,
"features": all_features,
"labels": all_labels,
"mol_ids": all_mol_ids,
"tasks": all_tasks,
}
return all_lists
    def _get_inv_of_mol_ids(self, all_mol_ids):
        """
        Deduplicate molecular IDs.

        Parameters:
            all_mol_ids: Flat sequence of per-datapoint molecule IDs (may repeat across tasks).
        Returns:
            mol_ids: The sorted unique molecule IDs.
            inv: For each entry of `all_mol_ids`, the index of its ID within `mol_ids`.
        """
        mol_ids, inv = np.unique(all_mol_ids, return_inverse=True)
        return mol_ids, inv
def _find_valid_label(self, task, ds):
r"""
For a given dataset, find a genuine label for that dataset
"""
valid_label = None
for i in range(len(ds)):
if ds[i] is not None:
valid_label = ds[i]["labels"]
break
if valid_label is None:
raise ValueError(f"Dataset for task {task} has no valid labels.")
return valid_label
def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]):
r"""
This gives the number of labels to predict for a given task.
"""
task_labels_size = {}
for task, ds in datasets.items():
if len(ds) == 0:
continue
valid_label = self._find_valid_label(task, ds)
# Assume for a fixed task, the label dimension is the same across data points
torch_label = torch.as_tensor(valid_label)
# First dimension is graph-specific
task_labels_size[task] = torch_label.size()
return task_labels_size
def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]):
r"""
Gets correct dtype for a given label
"""
task_labels_dtype = {}
for task, ds in datasets.items():
if len(ds) == 0:
continue
valid_label = self._find_valid_label(task, ds)
torch_label = torch.as_tensor(valid_label)
task_labels_dtype[task] = torch_label.dtype
return task_labels_dtype
def __repr__(self) -> str:
"""
summarizes the dataset in a string
Returns:
A string representation of the dataset.
"""
if len(self) == 0:
out_str = (
f"-------------------\n{self.__class__.__name__}\n"
+ f"\tabout = {self.about}\n"
+ f"\tnum_graphs_total = {self.num_graphs_total}\n"
+ f"-------------------\n"
)
return out_str
# Faster to compute the statistics if we unbatch first.
features = self.features
if isinstance(self.features, Batch):
self.features = self.features.to_data_list()
out_str = (
f"-------------------\n{self.__class__.__name__}\n"
+ f"\tabout = {self.about}\n"
+ f"\tnum_graphs_total = {self.num_graphs_total}\n"
+ f"\tnum_nodes_total = {self.num_nodes_total}\n"
+ f"\tmax_num_nodes_per_graph = {self.max_num_nodes_per_graph}\n"
+ f"\tmin_num_nodes_per_graph = {self.min_num_nodes_per_graph}\n"
+ f"\tstd_num_nodes_per_graph = {self.std_num_nodes_per_graph}\n"
+ f"\tmean_num_nodes_per_graph = {self.mean_num_nodes_per_graph}\n"
+ f"\tnum_edges_total = {self.num_edges_total}\n"
+ f"\tmax_num_edges_per_graph = {self.max_num_edges_per_graph}\n"
+ f"\tmin_num_edges_per_graph = {self.min_num_edges_per_graph}\n"
+ f"\tstd_num_edges_per_graph = {self.std_num_edges_per_graph}\n"
+ f"\tmean_num_edges_per_graph = {self.mean_num_edges_per_graph}\n"
+ f"-------------------\n"
)
# Restore the original features.
self.features = features
return out_str
class FakeDataset(MultitaskDataset):
    """
    A dataset to hold the fake data.
    """

    def __init__(
        self, datasets: Dict[str, SingleTaskDataset], num_mols: int = 1234, indexing_same_elem: bool = False
    ):
        """
        Parameters:
            datasets:
                A dictionary of datasets. The keys are the task names and the values are the datasets.
            num_mols:
                The number of molecules to generate. In reality, it is the same molecule,
                but `num_mols` will change the length of the dataset.
            indexing_same_elem:
                If True, the same molecule is used for all samples.
                Otherwise, a deepcopied molecule is used for each sample.
        """
        self.indexing_same_elem = indexing_same_elem
        self.num_mols = num_mols
        self.num_datasets = len(datasets)
        self.about = "FakeDatasets"
        # BUGFIX: always initialize `features` so the attribute exists even when
        # the datasets carry no features (previously the no-features path left it
        # unset and a later `self.features` access raised AttributeError).
        self.features = None
        task = next(iter(datasets))
        if "features" in datasets[task][0]:
            self.mol_ids, self.smiles, self.labels, self.features = self.merge(datasets)
            if self.indexing_same_elem is False:
                # Duplicate the single element so every index owns its own copy.
                # NOTE: arguments are now passed in the same order as `deepcopy_mol`
                # declares them (mol_ids, labels, smiles); the previous call swapped
                # smiles/labels in both the call and the unpack, which cancelled out
                # but was misleading. Net behavior is unchanged.
                self.mol_ids, self.labels, self.smiles, self.features = self.deepcopy_mol(
                    self.mol_ids, self.labels, self.smiles, self.features
                )
        else:
            self.mol_ids, self.smiles, self.labels = self.merge(datasets)
            if self.indexing_same_elem is False:
                self.mol_ids, self.labels, self.smiles, _ = self.deepcopy_mol(
                    self.mol_ids, self.labels, self.smiles
                )
        self.labels_size = self.set_label_size_dict(datasets)
        self.labels_dtype = self.set_label_dtype_dict(datasets)

    def _get_inv_of_mol_ids(self, all_mol_ids):
        """
        The generated data is a single molecule duplicated, so the inverse mapping
        simply repeats the same per-dataset index pattern `num_datasets` times.
        """
        mol_ids = np.array(all_mol_ids)
        inv = [_ for _ in range(len(mol_ids) // self.num_datasets)] * self.num_datasets
        mol_ids = np.unique(inv)
        return mol_ids, inv

    def deepcopy_mol(self, mol_ids, labels, smiles, features=None):
        """
        Create a deepcopy of the single molecule `num_mols` times.

        Args:
            mol_ids (array): The single value for the mol ID
            labels (List[Dict]): List containing one dict with the label name-value pairs
            smiles (List[List[str]]): List of list containing SMILE sting
            features (List[Data], optional): list containing Data object. Defaults to None.

        Returns:
            The deep copies of the inputs, in the same order as the parameters:
            (mol_ids, labels, smiles, features).
        """
        logger.info("Duplicating the single dataset element...")
        mol_ids = [deepcopy(mol_ids[0]) for _ in range(self.num_mols)]
        logger.info("Finished `mol_ids`")
        labels = [deepcopy(labels[0]) for _ in range(self.num_mols)]
        logger.info("Finished `labels`")
        smiles = [deepcopy(smiles[0]) for _ in range(self.num_mols)]
        logger.info("Finished `smiles`")
        if features is not None:
            features = [deepcopy(features[0]) for _ in range(self.num_mols)]
            logger.info("Finished `features`")
        return mol_ids, labels, smiles, features

    def __len__(self):
        r"""
        Returns the number of molecules
        """
        return self.num_mols

    def __getitem__(self, idx):
        r"""
        Get the data at the specified index.

        Parameters:
            idx: The index of the data to retrieve

        Returns:
            A dictionary containing the data for the specified index, with keys
            "labels" and "features" (when available).
        """
        datum = {}
        if self.indexing_same_elem is True:
            # If using a single memory location override the idx value passed
            idx = 0
        if self.labels is not None:
            datum["labels"] = self.labels[idx]
        if self.features is not None:
            datum["features"] = self.features[idx]
        return datum
def get_num_nodes_per_graph(graphs):
    r"""
    Return the node count of each graph in `graphs`, unbatching a `Batch` first.
    """
    if isinstance(graphs, Batch):
        graphs = graphs.to_data_list()
    return [graph.num_nodes for graph in graphs]
def get_num_edges_per_graph(graphs):
    r"""
    Return the edge count of each graph in `graphs`, unbatching a `Batch` first.
    """
    if isinstance(graphs, Batch):
        graphs = graphs.to_data_list()
    return [graph.num_edges for graph in graphs]
================================================
FILE: graphium/data/make_data_splits/make_train_val_test_splits_large.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple, Dict\n",
"\n",
"import random\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"from os.path import join\n",
"import datamol as dm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"BASE_PATH = \"../neurips2023/large-dataset/\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"df_L1000_VCAP.shape (15220, 984)\n",
"df_L1000_MCF7.shape (11622, 984)\n",
"df_PCBA.shape (1563664, 1332)\n",
"df_PCQM4M.shape (3810323, 31)\n"
]
}
],
"source": [
"df_L1000_VCAP = pd.read_csv(join(BASE_PATH, \"LINCS_L1000_VCAP_0-4.csv.gz\"))\n",
"print(\"df_L1000_VCAP.shape\", df_L1000_VCAP.shape)\n",
"df_L1000_MCF7 = pd.read_csv(join(BASE_PATH, \"LINCS_L1000_MCF7_0-4.csv.gz\"))\n",
"print(\"df_L1000_MCF7.shape\", df_L1000_MCF7.shape)\n",
"df_PCBA = pd.read_parquet(join(BASE_PATH, \"PCBA_1328_1564k.parquet\"))\n",
"print(\"df_PCBA.shape\", df_PCBA.shape)\n",
"df_PCQM4M = pd.read_parquet(join(BASE_PATH, \"PCQM4M_G25_N4.parquet\"))\n",
"print(\"df_PCQM4M.shape\", df_PCQM4M.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def smiles_to_unique_ids(smiles: List):\n",
" return dm.parallelized_with_batches(loop_smiles_to_unique_ids, smiles, batch_size=100, n_jobs=32, progress=True)\n",
"\n",
"def loop_smiles_to_unique_ids(smiles: List):\n",
" unique_ids = []\n",
" for s in smiles:\n",
" if not isinstance(s, str):\n",
" unique_ids.append(None)\n",
" continue\n",
" mol = dm.to_mol(s)\n",
" if mol is None:\n",
" unique_ids.append(None)\n",
" else:\n",
" unique_ids.append(dm.unique_id(mol))\n",
" return unique_ids"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a1bdefc329d14d288510253cc36867a8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/38103 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[16:03:00] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:10] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:37] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:37] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:37] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:44] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:44] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:44] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:44] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:44] WARNING: not removing hydrogen atom without neighbors\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7087b7c562ca46ce837695034e38cb57",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/152 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted'[16:03:49] for input: 'restricted'\n",
"SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f651dcb510004bb7afef0b02fb1e3b20",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/116 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n",
"[16:03:49] SMILES Parse Error: syntax error while parsing: restricted\n",
"[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ab2cca58d1f4a36ac5994b897cfe940",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/15636 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[16:03:51] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:52] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:53] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:55] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:59] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:59] WARNING: not removing hydrogen atom without neighbors\n",
"[16:03:59] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:02] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:02] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:07] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:08] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:08] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:08] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:08] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:09] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:12] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:13] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:15] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:16] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:20] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:20] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:22] WARNING: not removing hydrogen atom without neighbors\n",
"[16:04:22] WARNING: not removing hydrogen atom without neighbors\n"
]
}
],
"source": [
"unique_ids_QM = smiles_to_unique_ids(df_PCQM4M[\"ordered_smiles\"])\n",
"unique_ids_L1000_VCAP = smiles_to_unique_ids(df_L1000_VCAP[\"SMILES\"])\n",
"unique_ids_L1000_MCF7 = smiles_to_unique_ids(df_L1000_MCF7[\"SMILES\"])\n",
"unique_ids_PCBA = smiles_to_unique_ids(df_PCBA[\"SMILES\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"L1000_VCAP 726\n",
"L1000_MCF7 1023\n",
"PCBA 56512\n"
]
}
],
"source": [
"# Check the number of unique ids that intersect between unique_ids_QM and the other columns\n",
"intersection_VCAP = set(unique_ids_QM) & set(unique_ids_L1000_VCAP)\n",
"print(\"L1000_VCAP\", len(intersection_VCAP))\n",
"\n",
"intersection_MCF7 = set(unique_ids_QM) & set(unique_ids_L1000_MCF7)\n",
"print(\"L1000_MCF7\", len(intersection_MCF7))\n",
"\n",
"intersection_PCBA = set(unique_ids_QM) & set(unique_ids_PCBA)\n",
"print(\"PCBA\", len(intersection_PCBA))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"QM_in_VCAP 1443 VCAP_in_QM 748\n",
"QM_in_MCF7 2373 MCF7_in_QM 1065\n",
"QM_in_PCBA 91174 PCBA_in_QM 56512\n"
]
}
],
"source": [
"def find_indices(list_1, list_2):\n",
" intersection = set(list_1) & set(list_2)\n",
" intersection = {elem for elem in intersection if elem is not None}\n",
" is_2_in_1 = np.isin(list_2, list(intersection))\n",
" is_1_in_2 = np.isin(list_1, list(intersection))\n",
" return is_1_in_2, is_2_in_1\n",
"\n",
"QM_in_VCAP, VCAP_in_QM = find_indices(unique_ids_QM, unique_ids_L1000_VCAP)\n",
"print(\"QM_in_VCAP\", sum(QM_in_VCAP), \"VCAP_in_QM\", sum(VCAP_in_QM))\n",
"\n",
"QM_in_MCF7, MCF7_in_QM = find_indices(unique_ids_QM, unique_ids_L1000_MCF7)\n",
"print(\"QM_in_MCF7\", sum(QM_in_MCF7), \"MCF7_in_QM\", sum(MCF7_in_QM))\n",
"\n",
"QM_in_PCBA, PCBA_in_QM = find_indices(unique_ids_QM, unique_ids_PCBA)\n",
"print(\"QM_in_PCBA\", sum(QM_in_PCBA), \"PCBA_in_QM\", sum(PCBA_in_QM))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"test_seen_VCAP = np.where(VCAP_in_QM)[0].tolist()\n",
"test_seen_MCF7 = np.where(MCF7_in_QM)[0].tolist()\n",
"test_seen_PCBA = np.where(PCBA_in_QM)[0].tolist()\n",
"\n",
"train_QM_seen = np.where(QM_in_VCAP | QM_in_MCF7 | QM_in_PCBA)[0].tolist()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def make_random_splits(num_elem, val_ratio, test_ratio, seed=42, ignore_idx=None):\n",
" \"\"\"Make random splits of the data into train, validation, and test sets.\n",
"\n",
" Args:\n",
" num_elem (int): Number of elements in the dataset.\n",
" val_ratio (float): Ratio of the dataset to use for validation.\n",
" test_ratio (float): Ratio of the dataset to use for testing.\n",
" seed (int): Random seed.\n",
" ignore_idx (list): List of indices to ignore.\n",
"\n",
" Returns:\n",
" train_idx (list): List of indices for the training set.\n",
" val_idx (list): List of indices for the validation set.\n",
" test_idx (list): List of indices for the test set.\n",
" \"\"\"\n",
" # Create a list of indices\n",
" idx = list(range(num_elem))\n",
" # Remove the indices to ignore\n",
" if ignore_idx is not None:\n",
" idx = list(set(idx) - set(ignore_idx))\n",
" num_elem = len(idx)\n",
" # Shuffle the list of indices\n",
" random.seed(seed)\n",
" random.shuffle(idx)\n",
" # Compute the number of elements in each set\n",
" num_val = int(num_elem * val_ratio)\n",
" num_test = int(num_elem * test_ratio)\n",
" num_train = num_elem - num_val - num_test\n",
" # Split the list of indices into three sets\n",
" train_idx = idx[:num_train]\n",
" val_idx = idx[num_train:(num_train + num_val)]\n",
" test_idx = idx[(num_train + num_val):]\n",
" # Return the three lists of indices\n",
" return train_idx, val_idx, test_idx"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"train_VCAP, val_VCAP, test_VCAP = make_random_splits(len(df_L1000_VCAP), 0.04, 0.04, seed=42, ignore_idx=test_seen_VCAP)\n",
"train_MCF7, val_MCF7, test_MCF7 = make_random_splits(len(df_L1000_MCF7), 0.04, 0.04, seed=42, ignore_idx=test_seen_MCF7)\n",
"train_PCBA, val_PCBA, test_PCBA = make_random_splits(len(df_PCBA), 0.04, 0.04, seed=42, ignore_idx=test_seen_PCBA)\n",
"train_QM, val_QM, test_QM = make_random_splits(len(df_PCQM4M), 0.04, 0.04, seed=42, ignore_idx=train_QM_seen)\n",
"train_QM = np.concatenate((train_QM, train_QM_seen))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def make_random_splits_file(df, out_file, train_idx, val_idx, test_idx, test_seen_idx=None):\n",
" # Save the splits\n",
" if test_seen_idx is None:\n",
" splits_dict = {\"train\": train_idx, \"val\": val_idx, \"test\": test_idx}\n",
" else:\n",
" splits_dict = {\"train\": train_idx, \"val\": val_idx, \"test\": test_idx, \"test_seen\": test_seen_idx}\n",
"\n",
" # Check the splits validity\n",
" \n",
" assert len(set(train_idx).intersection(set(val_idx))) == 0\n",
" assert len(set(train_idx).intersection(set(test_idx))) == 0\n",
" assert len(set(val_idx).intersection(set(test_idx))) == 0\n",
" assert len(train_idx) > 0\n",
" assert len(val_idx) > 0\n",
" assert len(test_idx) > 0\n",
"\n",
" if test_seen_idx is None:\n",
" assert len(df) == len(train_idx) + len(val_idx) + len(test_idx), f\"{len(df)} != {len(train_idx)} + {len(val_idx)} + {len(test_idx)}\"\n",
" print(out_file, \"train\", len(train_idx), \"val\", len(val_idx), \"test\", len(test_idx))\n",
" else:\n",
" assert len(test_seen_idx) > 0\n",
" assert len(set(train_idx).intersection(set(test_seen_idx))) == 0\n",
" assert len(set(val_idx).intersection(set(test_seen_idx))) == 0\n",
" assert len(set(test_idx).intersection(set(test_seen_idx))) == 0\n",
" assert len(df) == len(train_idx) + len(val_idx) + len(test_idx) + len(test_seen_idx), f\"{len(df)} != {len(train_idx)} + {len(val_idx)} + {len(test_idx)} + {len(test_seen_idx)}\"\n",
" print(out_file, \"train\", len(train_idx), \"val\", len(val_idx), \"test\", len(test_idx), \"test_seen\", len(test_seen_idx))\n",
"\n",
" torch.save(splits_dict, out_file)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"l1000_vcap_random_splits.pt train 13316 val 578 test 578 test_seen 748\n",
"l1000_mcf7_random_splits.pt train 9713 val 422 test 422 test_seen 1065\n",
"pcba_1328_random_splits.pt train 1386580 val 60286 test 60286 test_seen 56512\n",
"pcqm4m_g25_n4_random_splits.pt train 3512805 val 148759 test 148759\n"
]
}
],
"source": [
"make_random_splits_file(df_L1000_VCAP, \"l1000_vcap_random_splits.pt\", train_VCAP, val_VCAP, test_VCAP, test_seen_VCAP)\n",
"make_random_splits_file(df_L1000_MCF7, \"l1000_mcf7_random_splits.pt\", train_MCF7, val_MCF7, test_MCF7, test_seen_MCF7)\n",
"make_random_splits_file(df_PCBA, \"pcba_1328_random_splits.pt\", train_PCBA, val_PCBA, test_PCBA, test_seen_PCBA)\n",
"make_random_splits_file(df_PCQM4M, \"pcqm4m_g25_n4_random_splits.pt\", train_QM, val_QM, test_QM)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "graphium",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: graphium/data/make_data_splits/make_train_val_test_splits_small.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import pandas as pd\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def make_random_splits(num_elem, val_ratio, test_ratio, seed=42):\n",
" \"\"\"Make random splits of the data into train, validation, and test sets.\n",
"\n",
" Args:\n",
" num_elem (int): Number of elements in the dataset.\n",
" val_ratio (float): Ratio of the dataset to use for validation.\n",
"        test_ratio (float): Ratio of the dataset to use for testing.\n",
"        seed (int): Random seed used to shuffle before splitting.\n",
"\n",
" Returns:\n",
" train_idx (list): List of indices for the training set.\n",
" val_idx (list): List of indices for the validation set.\n",
" test_idx (list): List of indices for the test set.\n",
" \"\"\"\n",
" # Create a list of indices\n",
" idx = list(range(num_elem))\n",
" # Shuffle the list of indices\n",
" random.seed(seed)\n",
" random.shuffle(idx)\n",
" # Compute the number of elements in each set\n",
" num_val = int(num_elem * val_ratio)\n",
" num_test = int(num_elem * test_ratio)\n",
" num_train = num_elem - num_val - num_test\n",
" # Split the list of indices into three sets\n",
" train_idx = idx[:num_train]\n",
" val_idx = idx[num_train:num_train + num_val]\n",
" test_idx = idx[num_train + num_val:]\n",
" # Return the three lists of indices\n",
" return train_idx, val_idx, test_idx"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def make_random_splits_from_file(in_file, out_file, val_ratio, test_ratio, seed=42):\n",
"    # Read the input CSV (smiles column only) to determine the dataset size\n",
" df = pd.read_csv(in_file, usecols=[\"smiles\"])\n",
" train_idx, val_idx, test_idx = make_random_splits(len(df), val_ratio, test_ratio, seed=seed)\n",
" # Save the splits\n",
" splits_dict = {\"train\": train_idx, \"val\": val_idx, \"test\": test_idx}\n",
"\n",
" # Check the splits validity\n",
" assert len(set(train_idx).intersection(set(val_idx))) == 0\n",
" assert len(set(train_idx).intersection(set(test_idx))) == 0\n",
" assert len(set(val_idx).intersection(set(test_idx))) == 0\n",
" assert len(train_idx) + len(val_idx) + len(test_idx) == len(df)\n",
" \n",
" torch.save(splits_dict, out_file)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"make_random_splits_from_file(\"qm9.csv.gz\", \"qm9_random_splits.pt\", 0.1, 0.1)\n",
"make_random_splits_from_file(\"Tox21-7k-12-labels.csv.gz\", \"Tox21_random_splits.pt\", 0.1, 0.1)\n",
"make_random_splits_from_file(\"ZINC12k.csv.gz\", \"ZINC12k_random_splits.pt\", 0.1, 0.1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphium",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: graphium/data/micro_ZINC/__init__.py
================================================
================================================
FILE: graphium/data/micro_ZINC/micro_ZINC.csv
================================================
SMILES,SA,logp,score
Cc1cccc(COC(=O)Nc2ccc(-c3cn[nH]c3)cc2OCCN2CCCC2)c1.O=C(O)C(F)(F)F.O=C(O)C(F)(F)F,2.896514950230738,5.87502,2.978505049769262
CCC(C)C1=C(C=CC=C1)OCCC[N]2C=CN=C2.O=C(O)C(=O)O,2.66931991795572,3.0213,0.35198008204428
CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1,2.775189478,3.7897,1.014510522
C=CCOc1ccc(C(F)(F)F)cc1C(=O)NC[C@H]1CCC[C@H]1O,3.071497898,3.161,0.089502102
COc1cc(OC)cc([C@@H](NC(=O)Cn2ccccc2=O)c2nccn2C)c1,2.84000896,1.5048,-1.33520896
CN1CCC[C@@]2(CCN(C(=O)CN3CCNC(=O)C3)C2)C1=O,3.605168457,-1.1109,-4.716068457
COc1ccc(Cl)cc1NC(=O)c1nn(C)cc1[N+](=O)[O-],2.132664673,2.2426,0.109935327
Cc1nn(-c2ccccc2)c(O)c1/C=[NH+]/Cc1ccncc1,3.117639223,0.98102,-2.136619223
C=CCN1C(=O)C(C)(C)COc2ccc(NC(=O)C3(c4ccc(OC)cc4)CCOCC3)cc21,2.744789746,4.3197,1.574910254
CON(C)C(=O)c1cnn2c(C)cc(C)nc12,2.547847184,0.97954,-1.568307184
NC(=O)[C@@H](NC(=O)CCCc1cccs1)c1ccccc1,2.444743614,2.4136,-0.031143614
COc1cccc(CN2CCC[NH+](CC(=O)Nc3ccc(F)cc3)S2(=O)=O)c1,3.546311784,0.8084,-2.737911784
Cc1ccc(-c2nc3ccc(C)c(C)c3[nH]2)nc1,2.342516783,3.55016,1.207643217
Cc1n[nH]c(SCC(=O)Nc2cccc(C#Cc3ccccn3)c2)n1,2.685016252,2.63872,-0.046296252
COC(=O)c1cncc(C(=O)Nc2cccc(COC(C)C)c2)c1,2.04419285,3.0455,1.00130715
Cc1ccc(C)c(Oc2ccc(CNC(=O)N3CCCSCC3)cn2)c1,2.369895336,4.13924,1.769344664
CC[NH2+][C@@]1(C(=O)[O-])CCC[C@H](Oc2ccc(CC)cc2)C1,4.165551538,0.6424,-3.523151538
CC(=O)N1c2ccc(S(=O)(=O)N3CCCC3)cc2C[C@H]1C(=O)NCC[NH+](C)C1CCCCC1,3.989079408,0.7123,-3.276779408
C=CCN1CC(=O)N2[C@@H](Cc3c([nH]c4ccccc34)[C@@H]2c2ccccc2OC)C1=O,3.287567863,3.0474,-0.240167863
COc1ccc(-c2noc3ncnc(N4CCC[C@H](C(=O)[O-])C4)c23)cc1,3.152874099,1.2597,-1.893174099
N#Cc1cc(NC(=O)Cc2ccccc2Cl)ccc1N1CCC(O)CC1,2.173785623,3.35398,1.180194377
Cc1ncc(-c2ccc(S(=O)(=O)NC3CCCCCC3)s2)o1,2.513723357,3.71262,1.198896643
Cc1ccc(NC(=O)c2ccc(F)cc2F)cc1S(=O)(=O)Nc1ccc(Cl)cc1,1.947448768,4.97972,3.032271232
CNC(=O)[C@@H]1CCC[NH+]1Cc1ccc(C)c(F)c1,4.391338086,0.42742,-3.963918086
CN1C(=O)N[C@@H](c2cccc([N+](=O)[O-])c2)C2=C1CN(c1ccc(F)cc1)C2=O,2.956072154,2.7309,-0.225172154
Cc1cc(CC(=O)N[C@H](c2ccc(F)cc2)C2CCC2)no1,2.665131639,3.32222,0.657088361
Cc1cc(N2CCN(CC(=O)NC3CC3)CC2)nc(-c2ccc([N+](=O)[O-])cc2)n1,2.249569987,1.76082,-0.488749987
Cc1cc(C(=O)N2CCN(C(=O)N[C@H]3CC(=O)N(C4CC4)C3)CC2)c(C)o1,2.921725033,1.12714,-1.794585033
O=c1c2c3nc4ccccc4nc3n(CCC3=CCCCC3)c2ncn1C[C@H]1CCCO1,3.257260117,4.3639,1.106639883
CC[NH+](C/C=C/c1ccc(C#N)cc1)C[C@H](C)C#N,4.36293875,1.63596,-2.72697875
COc1c(Cl)cc(C[NH2+][C@@H](Cc2c[nH]c3ccccc23)C(=O)[O-])cc1Cl,3.725368177,1.9079,-1.817468177
COc1ccccc1[C@H](C)NC(=O)C[C@H]1C[NH2+]CCO1,3.812797047,0.2247,-3.588097047
Cc1ccc(NC(=O)C(=O)NCCc2ccccc2F)c(C)c1,1.826753956,2.73994,0.913186044
CC(C)(C)OC(=O)N[C@H]1CCN(c2cc(-c3cccs3)n[nH]2)C1,3.228025092,3.2416,0.013574908
CC[C@H](C)C[NH+]1CCCC[C@@H]1C(=O)NC(C)(C)C,4.787067722,1.3846,-3.402467722
COC(=O)c1c(S(=O)(=O)NC2CC2)sc2c1CCN(Cc1ccc3ccccc3c1)C2,2.485526606,3.6869,1.201373394
Cc1cc(C)cc(NC(=O)[C@@H](Sc2nnnn2C2CC2)c2ccccc2)c1,2.708478583,4.09694,1.388461417
C=CCn1c([S-])nnc1-c1sc(NC(=O)c2cccc(NC(C)=O)c2)nc1C,2.943459024,3.01252,0.069060976
O=C(CNc1nc(C2CC2)no1)N1CCc2sccc2C1,2.781430253,2.0053,-0.776130253
Cc1ccccc1NC(=O)NCCn1c([N+](=O)[O-])cnc1C,2.258312512,2.22984,-0.028472512
C[C@H](NC=C(C#N)C#N)c1ccc(-n2cncn2)cc1,3.277555708,1.84896,-1.428595708
CN1C[C@H](C(=O)NC[C@H]2CCC(C)(C)c3ccccc32)CC1=O,3.259412472,2.4361,-0.823312472
CC(C)n1cccc1C(=O)N[C@H]1CCc2cc(N)ccc21,2.887282024,3.0685,0.181217976
CCN1C(=O)C(C#N)=C(C)/C(=C\Nc2ccc([N+](=O)[O-])cc2)C1=O,2.501067124,2.11938,-0.381687124
O=C(Nc1ccnn1Cc1cccc(Cl)c1)C1CCN(S(=O)(=O)c2ccccc2F)CC1,2.296421252,3.7633,1.466878748
COc1cc([C@@H]2N[C@H](C(=O)[O-])CS2)ccc1OC(=O)c1ccccc1,3.273511868,1.3679,-1.905611868
CCOc1ccc(C2=CCN(C(=O)c3ccccc3F)CC2)cc1,2.000346516,4.1539,2.153553484
CC(=O)N1CC[C@@H]([NH2+][C@@H](C)CSCC(C)C)C1,4.483674272,0.9483,-3.535374272
CCO[P@](=O)(CC(=O)[O-])c1ccccc1,3.510193788,0.3764,-3.133793788
O=C(CCc1nc2ccccc2c(=O)[nH]1)N1CCC[C@H]1C1CCCC1,2.774610155,3.0369,0.262289845
O=C(/C=C/SCc1ccco1)Nc1cccc(N2CCCC2=O)c1,2.584851266,3.792,1.207148734
CC(=O)c1c(C)[nH]c(C(=O)OCC(=O)N2C[C@H](C)C[C@@H](C)C2)c1C,3.068452859,2.49544,-0.573012859
Cc1ccc(C(=O)N2CCC[C@@H](C(=O)N(C)CC(=O)NC(C)C)C2)cc1,2.492863438,1.83022,-0.662643438
C[C@@H](C#N)Sc1nnc(C23CC4CC(CC(C4)C2)C3)o1,4.576896432,3.54158,-1.035316432
CC1(C)CCC[C@H]1n1c(N)[nH+]c2ccccc21,4.151966768,2.7888,-1.363166768
CCOc1ccc(C[NH+]2CCS[C@H]3COCC[C@@H]32)cc1OC,4.514122047,1.3831,-3.131022047
COc1cc(OC)c([C@@H](C)[NH2+]C[C@H](O)C2CCOCC2)cc1Cl,3.960984106,1.7691,-2.191884106
COC(C[C@@](C)(O)C[C@@H]1CCCN(S(C)(=O)=O)C1)OC,3.669896168,0.8081,-2.861796168
C=CCSc1nnc(C2CCOCC2)n1N,2.825013358,1.164,-1.661013358
O=C(Nc1c[nH]c2ccccc12)N1CCCC[C@@H]1CCO,2.674797385,2.9367,0.261902615
Cc1nc(-c2ccc(-c3noc(C4CC4)n3)cc2)cs1,2.133219774,4.04592,1.912700226
C[C@H]1CCCCN1C(=O)[C@@H](C)NC(=O)C(=O)Nc1ccnn1C(C)(C)C,3.418504414,1.4823,-1.936204414
c1cc2c(cc1N[C@@H]1CCOC3(CCC3)C1)CCC2,3.363061129,3.6889,0.325838871
O=C(Nc1ccc(CNC(=O)N2CCCCCCC2)cc1)c1ccco1,1.931563902,4.0076,2.076036098
CNC(=O)[C@H]1CCCCN1c1cccc(F)c1C(N)=S,2.82917551,1.5648,-1.26437551
O=C(NCc1nnc2n1CCC2)c1cc(-c2ccccc2)[nH]n1,2.376871447,1.5444,-0.832471447
Cc1nnc(-c2cc(CC(C)C)n(-c3cccc(C(=O)N[C@@H]4C[C@@H](C)CC(C)(C)C4)c3)n2)o1,3.548164288,5.37382,1.825655712
COc1cccc([C@@H]2CC(=O)C3=C(C2)NC(=O)C[C@H]3c2ccc(OC)c(OC)c2OC)c1,3.226440403,3.7253,0.498859597
C[NH2+]Cc1ccc(-c2cc(C)ccc2F)s1,3.128631552,2.55582,-0.572811552
CCC[C@H](NC(N)=O)C(=O)NC[C@@H]1CCCO1,2.869359841,0.1186,-2.750759841
COc1cc(F)cc(CNC(=O)[C@H]2CCCN2C(=O)Cc2ccccc2)c1,2.427176885,2.6842,0.257023115
CCc1cc(C#N)c(NC(=O)CCC(=O)N(CC)CC)s1,2.493551313,2.76928,0.275728687
C#Cc1cccc(NC(=O)C[NH+](C)[C@H](C)c2nnc(-c3ccccc3)o2)c1,3.779297449,1.9323,-1.846997449
CCOC[C@H](O)[C@](C)(CC)[NH+]1CCCC1,5.127905513,0.2312,-4.896705513
C[NH+](C)Cc1cccc(CNC(=O)NC[C@H](O)c2ccc(F)cc2)c1,3.24179101,1.003,-2.23879101
Cc1ccsc1C(=O)NCCNC(=O)c1ccco1,2.114582277,1.80932,-0.305262277
COC(=O)[C@H](C)N(Cc1ccccc1)C(=O)Cc1ccon1,2.852054716,1.8074,-1.044654716
C/C=C(/C)C(=O)NC[C@H]1C[C@@]12CCc1ccccc12,3.940919906,2.9729,-0.968019906
O=C1Nc2ccccc2Oc2cc([N+](=O)[O-])cc(Oc3ccc(Cl)cc3)c21,2.400946341,5.3985,2.997553659
CCc1onc(C)c1NC(=O)CCCC(C)(C)C,2.601600041,3.70032,1.098719959
COC(=O)C(C)(C)COc1ccc2c(c1)OC[C@@H]2[NH3+],3.576525387,0.94,-2.636525387
CC#CCCC(=O)Nc1cccc2c1C(=O)c1ccccc1C2=O,2.298959953,3.204,0.905040047
Cc1nc(C)c(S(=O)(=O)/N=C(\[O-])C[C@H]2CCCO2)s1,3.812988405,0.77654,-3.036448405
CCC(=O)N[C@H](CCSC)C(=O)Nc1ccc2[nH]c(C)cc2c1,2.732126257,3.06272,0.330593743
COCC[NH+](C)Cc1c(C)cc(C)c(C(C)=O)c1C,3.955103091,1.47556,-2.479543091
Cc1cscc1CNC(=O)N[C@@H](C)c1nc(C(=O)[O-])cs1,3.628458628,1.43692,-2.191538628
CNC(=O)NCC(=O)NC[C@H](c1cccnc1)C(C)C,2.795116489,0.8664,-1.928716489
COc1cccc(C(=O)N2CCC[C@H]2[C@H]2CCC[C@@H]2O)c1,3.002385916,2.4608,-0.541585916
COc1ccc2c(c1)[C@H]([NH2+][C@H](C)CCN1CCOCC1)CCCO2,3.789130146,1.5831,-2.206030146
COc1cccc(OCC(=O)N2CCN(c3nc4ccc(S(C)(=O)=O)cc4s3)CC2)c1,2.245396536,2.436,0.190603464
COc1cccc(-c2nc(C(=O)O[C@@H](C)CNC(C)=O)c(C)[nH]2)c1,2.86962049,2.07512,-0.79450049
CCO[C@H]1C[C@@H]1C(=O)Nc1ccc(Sc2nncs2)c(Cl)c1,3.445502237,3.7062,0.260697763
Cc1ccc(N(CC(=O)N2CCCC[C@H]2C)S(=O)(=O)c2c(C)nn(C)c2C)cc1,2.872766629,2.94166,0.068893371
COc1ccc(C(=O)Nc2nc(CC(=O)Nc3ccc(N(C)C)cc3)cs2)cc1,1.983649537,3.6512,1.667550463
O=c1c2ccccc2sn1-c1ncc(Br)s1,2.853247472,3.2712,0.417952528
Cc1ocnc1CNC(=O)N(C)Cc1ccc(OC(F)F)cc1,2.60278761,2.92602,0.32323239
COc1cc(Cl)c(C)cc1NC(=O)C(=O)NCc1ccccc1C[NH+](C)C,2.870575026,1.55642,-1.314155026
C[C@@H]([NH3+])c1cccc(Oc2cc(Br)ccc2Cl)c1,2.965149266,4.1977,1.232550734
CC(C)NS(=O)(=O)c1cccc(C(=O)NCc2ccco2)c1,1.961616674,1.8963,-0.065316674
CN(Cc1ncnn1C)C(=O)CSCc1ccccn1,2.529083407,1.1019,-1.427183407
CCOc1cc(/C=C(\C#N)C(=O)c2c[nH]c3cc(Cl)ccc23)ccc1OC,2.350767662,5.01848,2.667712338
CCOC(=O)Nc1ccc(NC(=O)N2C[C@@H](C)CC[C@H]2C)cc1C#N,3.067347555,3.77898,0.711632445
Cn1c(=O)oc2ccc(NC(=O)[C@@H]3CCN(C(=O)OC(C)(C)C)C3)cc21,2.748880773,2.327,-0.421880773
CC(C)N1CCO[C@@H]([C@H](O)c2cccs2)C1,3.356907136,1.8907,-1.466207136
CCN[C@H]1C[C@H](C)C[C@H](C)[C@@H]1[NH+]1C[C@@H](C)[C@@H](C)C1,5.501545361,1.5698,-3.931745361
CC1CCN(C(=O)N(CC[NH3+])C2CC2)CC1,3.063324,0.5446,-2.518724
COc1ccc(NC(=O)Cn2nc3c4scc(-c5ccc(F)cc5)c4ncn3c2=O)c(OC)c1,2.583698935,3.5677,0.984001065
CC(C)OC(=O)N1CCC(NC(=O)[C@H]2SCCc3ccccc32)CC1,3.087047482,3.1426,0.055552518
COc1ccc2c(c1)SC(=O)C2=O,2.401574365,1.5102,-0.891374365
COc1ccc(CCC(=O)N2CCC3(CC2)Nc2ccccc2-n2cccc23)cc1,2.913337418,4.3619,1.448562582
COc1ccccc1Cc1nnc(NC(=O)Cc2ccc(Br)cc2)s1,2.051226236,4.0812,2.029973764
CNC(=O)O/N=C(/N)c1ccc(Br)cc1,2.250317734,1.4254,-0.824917734
CC1(C)CN(C[C@@H](CS)c2ccccc2)CCO1,3.027019854,2.8108,-0.216219854
CCOC(=O)[C@H]1CCCN(C(=O)Cn2c(SC(F)F)nc3ccccc32)C1,2.799716077,3.1527,0.352983923
[NH3+][C@@H](Cc1ccccn1)c1ccc(Cl)cn1,3.324055225,1.6557,-1.668355225
CCOCCN(C)C(=O)C(=O)Nc1cccnc1-c1ccccc1,2.256607012,2.182,-0.074607012
CCn1cc([C@H](NC(=O)c2ccc(OC)c(OC)c2)c2ccc(F)cc2)c2cc[nH+]cc21,3.447988059,4.151,0.703011941
O=C([C@@H]1CC12CCOCC2)N1CC[C@@H](Oc2cccc(Cl)c2)C1,3.713785706,3.1364,-0.577385706
C[NH2+][C@]1(C(N)=O)CCC[C@H]1CCSC1CCOCC1,4.614268117,0.5061,-4.108168117
Cc1cc(NC(=O)c2ccc(-n3cncn3)nc2)ccc1N(C)C,2.267592306,2.28902,0.021427694
Cc1ccccc1CNC(=O)N1CCC([C@@H](O)c2ccc(Cl)cc2)CC1,2.524142332,4.30362,1.779477668
CCc1nc(CNCCc2cccs2)cs1,2.300594957,3.0993,0.798705043
O=C(Nc1ccc2nc(-c3cc(F)ccc3F)[nH]c2c1)c1ccc([N+](=O)[O-])cc1,2.145744096,4.6686,2.522855904
N#Cc1ccccc1NC(=O)C(=O)NC[C@H]1COC2(CCCC2)O1,3.412921889,1.29868,-2.114241889
O=C(NC1(C(=O)[O-])CC[NH2+]CC1)OCC1c2ccccc2-c2ccccc21,3.490460281,0.371,-3.119460281
Cc1ccc(C(=O)N2CCC[C@H]2CNC(=O)OC(C)(C)C)cc1,2.424440646,3.12432,0.699879354
C[NH+](C)CC(=O)NNC(=O)c1ccc(C2CCCCC2)cc1,2.819298295,0.6398,-2.179498295
C[C@@H]1CCCC[C@@H]1OCC(=O)N(C)Cc1cccc(O)c1,2.912373702,2.9459,0.033526298
C#CCNS(=O)(=O)c1ccc(C(=O)N[C@@H]2CCO[C@@]3(CCOC3)C2)cc1,3.898168643,0.666,-3.232168643
Cc1ccc(C2(CN3CCN(S(=O)(=O)N(C)C)CC3)CCCC2)cc1,2.439754533,2.23082,-0.208934533
Cc1sc2ncn(CCC(=O)NCC(N)=O)c(=O)c2c1C,2.347207183,0.06644,-2.280767183
COc1ccc(CCNC(=O)c2cc(C(C)C)nc3c2c(C)nn3C)cc1OC,2.282558575,3.38982,1.107261425
O=[N+]([O-])c1cnc(Br)s1,3.179818392,1.8138,-1.366018392
CC(C)(C)CC(=O)N1CCN(C(=O)c2ccn3ccsc23)CC1,2.712888331,2.7214,0.008511669
C/[NH+]=C(/NCCCc1cccc(Br)c1)NC1CC1,2.991361127,0.7897,-2.201661127
O=C(CCC1CCCC1)N1CCC(COc2ccccc2C(=O)N2CCCC2)CC1,2.19573155,4.5104,2.31466845
CC(C)(C)c1n[nH]cc1/C=C1\SC(NC2CCCCC2)=NC1=O,3.076214803,3.5998,0.523585197
CN(CC(F)(F)F)C(=O)CNC(=O)N1CCO[C@H](C#N)C1,3.449633505,-0.05892,-3.508553505
CC[C@@H](NC(=O)NCc1ccc(C#N)cc1)c1ccc(C)c(F)c1,2.506404532,3.9563,1.449895468
CC(C)c1noc(CCNC(=O)/C=C/c2cn(C)c3ccccc23)n1,2.489753078,3.0568,0.567046922
CC(C)=C1Oc2c3c(cc(C)c2C1=O)OCN(Cc1ccccc1)C3,2.840364996,4.21612,1.375755004
O=C1/C(=N/O)c2ccc(Cl)cc2N1Cc1ccccc1,2.0410745,3.0651,1.0240255
c1ccc(CC2(Cc3ccc4c(c3)CCC4)C[NH2+]C2)cc1,2.844717318,2.5239,-0.320817318
CC[C@H](C)NC(=O)N[C@H](c1ccccc1)C(F)(F)F,2.819638554,3.3877,0.568061446
NC(=O)COc1ccc2ccccc2c1C[NH2+]C1CCCCCCC1,2.95041288,2.8802,-0.07021288
O=C(COc1ccc(Cl)cc1[N+](=O)[O-])N1CCc2cc([N+](=O)[O-])ccc21,2.251658336,3.1245,0.872841664
O=C(Cc1ccc(F)cc1F)NC1CCN(c2cccc(F)c2)CC1,2.033335093,3.4316,1.398264907
COC(=O)c1ccc(C[NH+](C2CC2)[C@@H](C)c2ccco2)cc1F,4.038401414,2.5138,-1.524601414
Cc1ccc([C@@H]2C[NH2+]CC[C@@H]2c2c(F)cccc2F)cc1,3.825954589,3.10772,-0.718234589
Cc1ccc(-c2ccc(=O)n(CC(=O)NCc3ccco3)n2)c(C)c1,2.180504107,2.43654,0.256035893
CCNC(=O)NC(=O)CSc1nnc(N2C[C@@H](C)C[C@@H](C)C2)n1Cc1ccco1,3.320744633,2.3395,-0.981244633
CC(C)(C)NC(=O)N1CCC(CNC(=O)C(=O)Nc2ccccc2[N+](=O)[O-])CC1,2.351664532,1.8696,-0.482064532
Cc1ccc(-n2cnnn2)cc1NC(=O)c1cnn(Cc2ccccc2)c1,2.221673771,2.46782,0.246146229
O=C(Nc1ccc(F)cc1OCC(F)F)c1cccnc1N1CCCC1,2.273178554,3.7171,1.443921446
CCCNC(=O)CCNC(=O)CN1CCC([NH3+])CC1,2.682391179,-1.2748,-3.957191179
COCCNC(=O)COc1cc2c(c3oc(=O)c4c(c13)CCC4)CCC(C)(C)O2,2.819768111,2.5267,-0.293068111
CC(=O)N[C@@H](CC(=O)Nc1cnn(C)c1)c1cccs1,2.760825185,1.6876,-1.073225185
O=C(c1cccc(S(=O)(=O)[N-]c2ccccc2F)c1)N1CCC[C@@H]1c1nc2ccccc2[nH]1,3.413113671,5.0734,1.660286329
C[C@H](CNC(=O)c1[nH]nc(C2CC2)c1Cl)Oc1cccc(F)c1,3.06692193,3.2769,0.20997807
COc1ccccc1C(=O)/C(C#N)=C\c1cnc(-c2ccccn2)s1,2.549035852,4.00358,1.454544148
O=C(NCc1ccco1)[C@H]1CCCN1C(=O)C[NH+]1CCc2sccc2C1,4.132206287,0.5895,-3.542706287
CC[C@@H]1CN(C(=O)c2ccc3c(c2)OCO3)CC[NH2+]1,3.69238642,0.2131,-3.47928642
Cc1cc(C(=O)Nc2ccc(F)cc2C)on1,1.934252507,2.68284,0.748587493
C[C@@H](N[C@H](C[C@@H]1CCOC1)c1ccccc1)c1ccc([N+](=O)[O-])cc1,3.168218916,4.4133,1.245081084
O=C(NNS(=O)(=O)c1ccc(F)cc1)c1cc2c(s1)CCCC2,2.219578826,2.3893,0.169721174
Cc1cc(C#N)ccc1Oc1ccc([N+](=O)[O-])cc1C#N,2.231562188,3.43888,1.207317812
COc1ccc(-c2nc(CNC(=O)c3ccccc3)n[nH]2)cc1,1.980697063,2.4103,0.429602937
Cc1cccc(OCC(=O)N/N=C/c2cc(Br)c(O)c(Br)c2O)c1,2.453904832,3.46032,1.006415168
CC1=Nc2ccccc2Oc2c1c(=O)n(-c1ccccc1)c1ccccc21,2.458237938,5.2371,2.778862062
COC[C@@](C)(O)CNC(=O)[C@@H]1CC(=O)N(c2ccc(F)cc2)C1,3.005123049,0.6922,-2.312923049
Fc1ccccc1[C@@H]([NH2+]Cc1ccoc1)C1CCCC1,3.69757193,3.4136,-0.28397193
CC1CCN(C(=O)CN2C(=O)N[C@](C)(c3ccc(Cl)cc3Cl)C2=O)CC1,2.732754879,3.0189,0.286145121
COc1cccc(NC(=O)NCc2ccc3c(c2)N(C(=O)c2cccc(C)c2)CC3)c1,2.09420016,4.52822,2.43401984
O=C(/C=C/c1ccccc1)NC(=S)Nc1cc([N+](=O)[O-])ccc1F,2.105372184,3.2603,1.154927816
CCCn1nccc1NC(=O)c1nnn(-c2ccc(C)cc2)c1C,2.315214931,2.74294,0.427725069
COc1cccc(-c2nc(CC(=O)N[C@@H]3CCS(=O)(=O)C3)cs2)c1,2.722418907,1.6645,-1.057918907
COc1cc(NC(=O)c2cnn(-c3ccccc3)n2)cc(OC)c1OC,2.151091028,2.5454,0.394308972
COc1ccc2c(c1)cc(C(=O)N[C@H](C)Cc1cnccn1)n2C,2.807843567,2.3379,-0.469943567
CC(C)COCCNC(=O)NNc1ccccc1Cl,2.165718029,2.6387,0.472981971
CCC[C@@H](C)[NH2+]C[C@H](C)Oc1cccnc1C,4.126301903,1.90932,-2.216981903
Cc1ccc(C)c(OCCSc2ccc(N)c(C)c2)c1,2.060838323,4.36516,2.304321677
C[C@H](CC#N)N(C)C[C@@H]1CSc2ccccc21,3.590360435,3.10988,-0.480480435
CCN[C@@]1(C#N)CC[C@@H](Oc2cccc(F)c2F)C1,3.52564354,2.76798,-0.75766354
Cc1cc(C[NH2+][C@@H](CO)C(C)C)c(C)n1-c1ccccc1,3.596987275,2.17444,-1.422547275
CC1CCC(NC(=O)C(=O)Nc2cccc(-c3nncn3C)c2)CC1,2.347143544,2.1155,-0.231643544
CC(=O)Nc1cccc(C(=O)/C=C/c2ccco2)c1,2.003848944,3.1341,1.130251056
CCCNC(=O)c1cccc(NC(=O)C(=O)N2C[C@H]3CC=CC[C@@H]3C2)c1,3.148539376,2.1895,-0.959039376
Cc1nc(C)c(C(=O)N2CCN(C(=O)Cc3ccc(Cl)cc3)CC2)o1,2.226794753,2.47194,0.245145247
CCc1nc(CN/C(=[NH+]/C)N2CC[NH+](CC(C)C)CC2)cs1,4.614668767,-1.282,-5.896668767
CN(C)c1ccc(C(=O)Nc2c3c(nn2C)CCC3)cc1,2.254443099,2.2271,-0.027343099
O=C1[C@H](Nc2ccccc2O[C@H]2CCOC2)CCCN1Cc1ccccc1,2.979200776,3.4574,0.478199224
N#Cc1cccc(CNC(=O)CCOc2ccc(C=O)cc2)c1,1.99088313,2.45608,0.46519687
COc1ccc(CCn2cc(C(=O)[O-])c(=O)[nH]c2=O)cc1,2.453088713,-0.8486,-3.301688713
CN(CCc1ccccc1)C(=O)C(=O)Nc1ccc(Oc2ccncc2)cc1,2.089518656,3.5135,1.423981344
Cc1cc(C)cc(NC(=O)[C@H]2CCCN(S(=O)(=O)c3cn(C)cn3)C2)c1,2.830417176,2.07634,-0.754077176
Cc1nc(CNC(=O)[C@@H]2CCCCN2C(=O)OC(C)(C)C)sc1C,2.787253271,3.16574,0.378486729
CCSC1=NC(=O)[C@H]2C(=N1)NC1=C(C(=O)CCC1)[C@H]2c1cccnc1,4.10432105,2.4395,-1.66482105
COCC[C@H](C)C(=O)Nc1ccc(Oc2cc[nH+]c3ccccc23)cc1,3.383061791,4.0573,0.674238209
COC(=O)c1sccc1S(=O)(=O)N1CCC[C@@H](NC(C)=O)C1,2.758371043,0.8239,-1.934471043
CC(=O)O[C@H]1CC[C@]2(C)[C@@H]3CC[C@@]4(C)[C@@H](C[C@@H]5CCCC[C@@]54C(C)=O)[C@@H]3C[C@@H](O)[C@@]2(O)C1,4.807691125,4.422,-0.385691125
CCNS(=O)(=O)[C@H]1CCN(CCSc2ccccc2)C1,2.893618866,1.7923,-1.101318866
CC(C)OCCCNC(=O)N1CCC[NH+](CC2CCCCC2)CC1,3.869396076,1.682,-2.187396076
C[C@]1(CCc2ccccc2)NC(=O)N(CC(=O)N2CCc3sccc3[C@H]2c2cccs2)C1=O,3.330561001,4.227,0.896438999
Oc1n[nH]c(-c2ccccc2)c1/C=N\c1cccc(Cl)c1,2.575508438,4.1863,1.610791562
C[C@@](NC(=O)c1scnc1C1CC1)(C(N)=O)c1ccccc1,3.152593137,2.151,-1.001593137
Cn1cc(C[NH2+]C2CN(C(=O)OC(C)(C)C)C2)c(C(C)(C)C)n1,3.531149004,1.4003,-2.130849004
O=C(Cn1ncc2c([nH]c3ccccc32)c1=O)N1CC[C@@H](c2ccccc2)C1,2.801434255,2.8939,0.092465745
COc1cccn(CC(=O)N2CCC[C@H]2c2nccs2)c1=O,2.926470384,1.6771,-1.249370384
Cc1ccc(S(=O)(=O)N2CCCSCC2)c(C)c1,2.172121888,2.43104,0.258918112
CC(C)CSc1ccc(N)c(N)n1,2.677260409,1.9941,-0.683160409
Cc1ccccc1O[C@@H](C)CC[C@H](C)[NH+]1CCCCC1,4.183468302,2.99982,-1.183648302
CC1=CC(=O)C(C)=C(C)C1=O,2.554710776,1.4209,-1.133810776
COC(=O)c1c(NC(=O)c2cc3cc(Br)ccc3o2)sc2c1CCC2,2.326967163,4.7844,2.457432837
O=C(c1cc(=O)c2ccccc2o1)N1CCN(c2ccc(-c3noc(C(F)(F)F)n3)c[nH+]2)CC1,3.316210434,2.6383,-0.677910434
CCN1C(=O)c2cc(NC(=O)c3cccc(OC)c3OC)ccc2Oc2ccccc21,2.188834997,4.7285,2.539665003
CCC[C@H](C)C(=O)NCc1c(O)ccc2c1CCCC2,2.79815046,3.3234,0.52524954
Cc1c(CCC(=O)NC2CC2)c(=O)n2nc(-c3ccccc3)nc2n1Cc1ccccc1,2.39384803,3.12592,0.73207197
CC(C)(C)C(=O)N1N=C(c2cccc(NS(C)(=O)=O)c2)C[C@H]1c1cccs1,3.033061881,3.8434,0.810338119
COc1ccc(/C=C/C(=O)Nc2ccc(Cl)c(-c3nc4ncccc4o3)c2)cc1,2.282822064,5.2037,2.920877936
Cc1cc(C(=O)Nc2ccc(C[NH+]3CCCC3)cc2)c2cnn(Cc3cccs3)c2n1,3.308540728,3.28052,-0.028020728
C=CCn1c(CC)nnc1Sc1nc2c([nH]1)c(=O)n(C)c(=O)n2C,2.9507363,0.4514,-2.4993363
CCN(C(=O)[C@H](C)c1cccc(F)c1)c1ccc(C#N)c(Cl)c1,2.787640646,4.50738,1.719739354
COC(=O)CSc1nnc(-c2ccc(OC)cc2OC)o1,2.12901572,2.0189,-0.11011572
C[C@@H](NC(=O)CCc1cccc(F)c1)c1ccc2c(c1)CCC(=O)N2,2.559093635,3.5204,0.961306365
COC(=O)[C@H](NC(=O)c1cc(C)on1)c1ccc(F)c(F)c1,2.709930274,1.90532,-0.804610274
O=C([O-])c1ccccc1NCc1ccccc1[N+](=O)[O-],2.300858927,1.5704,-0.730458927
Cc1cc(=O)c(C(=O)NCCc2ccccc2F)nn1-c1ccccc1F,2.138447824,2.79162,0.653172176
CC(C)[C@H](NC(=O)c1ccccc1)C(=O)N[C@H](C)C(N)=O,2.47185794,0.431,-2.04085794
Cc1ccccc1OC1CN(c2ncccc2C(=O)N2C[C@@H]3CC[C@H]2C3)C1,3.852979496,3.28212,-0.570859496
CCc1nc(CN2CCC[C@H]([NH2+]CC3CCN(C(C)=O)CC3)C2)no1,3.812233153,0.4183,-3.393933153
CCN(CC)C(=O)CSc1c(C)cnc2c1c(=O)n(C)c(=O)n2C,2.632741217,0.90112,-1.731621217
CCOc1cc2c(cc1OCC)C[NH+](Cn1nc3n(c1=S)CCCCC3)CC2,4.018751641,2.53659,-1.482161641
O=C(CNCC(F)(F)F)N(Cc1cccc(O)c1)C1CC1,2.389515685,2.0351,-0.354415685
Cc1cc(CN2CCn3c(nnc3-c3ccccc3)C2)ccc1-n1cncn1,2.418767322,2.85002,0.431252678
C[C@@H]1C[C@@H](C)CN(S(=O)(=O)c2cccc(C(=O)Nc3nc(C4CC4)cs3)c2)C1,3.076699104,3.9394,0.862700896
COC(=O)c1c(NC(=O)CCC(=O)[O-])sc2c1CCCCCC2,2.562728956,1.6623,-0.900428956
S=C(Nc1cccnc1)N(Cc1ccccc1)C[C@H]1CCCO1,2.556251215,3.4596,0.903348785
C[NH+]1CC[C@@H](N2CC3(CCC2=O)CC[NH+](Cc2c[nH]cn2)CC3)C1,6.165717612,-1.5158,-7.681517612
CC[C@@](C)([C@@H]([NH2+]C)C1C[C@@H](C)C[C@@H](C)C1)[NH+]1CCCCC1,5.7683801,1.858,-3.9103801
Cn1cc(NC(=O)N2CCN(c3ccc([N+](=O)[O-])cc3)CC2)c2ccccc2c1=O,2.24714082,2.8008,0.55365918
CC(C)C[C@H](NC(=O)c1cc(-c2cccn2C)n[nH]1)c1ncnn1C,3.319192988,2.0609,-1.258292988
CCNC(=O)CN(C1CCCCC1)S(=O)(=O)c1ccccc1,2.045861074,2.1461,0.100238926
CCSc1cc(C(=O)NC[C@H](C)Nc2ccccc2)ccn1,2.637650604,3.424,0.786349396
Cc1noc(C(C)C)c1C(=O)N1CCN(C[C@H](C)O)CC1,2.859488392,1.24502,-1.614468392
CSc1ccc(Cl)c(C(=O)Nc2cc(C(=O)N(C)C)ccc2Cl)c1,2.013567479,4.6694,2.655832521
CC(=O)c1ccc(OC[C@@H](O)CN2CCn3c(nn(C)c3=O)C2)cc1,3.001332769,0.0399,-2.961432769
Cc1ccc([C@H]2C[NH+]=C(N)N2Cc2ccccc2)c(C)c1,3.297715153,1.25564,-2.042075153
Cc1nc(SCC(=O)c2ccc3[nH]c(=O)[nH]c3c2)nc(C)c1C,2.466697225,2.54646,0.079762775
Cn1cc(C(=O)N[C@H]2CCC[C@H]3OCC[C@H]23)c(-c2ccccc2)n1,3.352877627,2.7745,-0.578377627
CC[NH2+]C[C@]1(c2ccc(F)cc2Cl)CCC[C@@H]1C,4.165842394,3.1202,-1.045642394
CCN(CCNC(=O)N[C@@H](C)c1ncc(C)s1)c1cccc(C)c1,2.919644026,3.64664,0.726995974
COCC(=O)Nc1ccc(-c2noc([C@@H]3CCCN3Cc3[nH]cc[nH+]3)n2)cn1,3.85152994,1.1958,-2.65572994
Cc1cc(NC(=O)[C@H]2COc3ccccc3O2)c2ncccc2c1,2.627262029,3.32172,0.694457971
CC[C@H]1CC(=O)N(C[C@@H]([NH3+])c2ccc(OC(C)C)cc2)C1,3.641910716,2.0153,-1.626610716
Cc1ccc(C(=O)N(C)CCC#N)cc1,1.82915826,1.9807,0.15154174
CN(O)[C@H]1CC[C@H](c2ccc(C(C)(C)C)cc2)C1,3.088427273,3.9412,0.852772727
Cc1cc(C)n(-c2nc([O-])cc(C(C)C)n2)n1,3.071945177,1.47614,-1.595805177
CO[C@@H](C)c1nc(C[NH+](C)Cc2cccc(C)c2C)cs1,4.259931405,2.68224,-1.577691405
CCc1ncc(CN(C)C(=O)c2c[nH]nc2-c2ccc(OC)cc2)s1,2.753719275,3.3764,0.622680725
Cc1ccc2[nH]c3c(=O)n(CC(=O)N(C)c4ccc(C)c(C)c4)ncc3c2c1,2.543058432,3.46606,0.923001568
CC[NH2+]C[C@H]1CCO[C@@H]1c1ccc(F)cc1F,4.211152742,1.6257,-2.585452742
Cn1nc(C(C)(C)C)cc1C(=O)Nc1ccc(C(N)=O)cc1,2.062408364,2.0688,0.006391636
O=C(Nc1cccnc1)c1ccc(NC(=O)c2ncn(-c3ccccc3)n2)cc1,2.128494967,3.1669,1.038405033
O=c1cc(C2CCC2)nc(-c2cccc(C[NH+]3CC(O)C3)c2)[nH]1,3.57334498,0.4638,-3.10954498
C[C@@H](OC[C@H]1CCCO1)C(=O)Oc1cc(Cl)ccc1Cl,3.004023247,3.4829,0.478876753
COc1ccccc1C1(C(=O)Nc2ccc3c(c2)N(C(=O)C2CC2)CC3)CC1,2.34415068,3.6646,1.32044932
O=C([O-])[C@@H]1Cc2c([nH]c3ccccc23)[C@H](C(=O)[O-])[NH2+]1,4.686340492,-2.8031,-7.489440492
O=C1NC2(CCCC2)C(=O)N1Cc1ccc(Br)s1,2.905374606,2.8752,-0.030174606
CC(C)N1CCO[C@H](c2nnn[n-]2)C1,4.545882887,-0.3895,-4.935382887
Cc1ccc2c(c1)C(=O)N([C@H](C)C(=O)NCCCC(C)C)C2=O,2.543137633,2.53192,-0.011217633
CCN(CC)S(=O)(=O)N[C@H](c1ccc(F)cc1)c1nccn1C,2.880328558,1.8248,-1.055528558
CCN(CC)C(=O)C1CCN(c2cccc(Cl)c2)CC1,1.925061949,3.4248,1.499738051
O=C(NCC(=O)N1CCN(c2ccc(F)cc2)CC1)N[C@@H]1CCS(=O)(=O)C1,2.699990976,-0.0394,-2.739390976
CCn1c(SCC(N)=O)nnc1[C@H](C)Oc1cccc(Cl)c1,2.695897046,2.6688,-0.027097046
COC(C)(C)C(=O)Cc1cccc2ccccc12,2.090800657,3.3764,1.285599343
CC1CCN(C(=O)c2ccccc2NCC(=O)N[C@H](C)C(C)C)CC1,2.513852635,3.1313,0.617447365
COCC[C@@H](C)C(=O)Nc1ccc2c(C)n[nH]c2c1,2.813255353,2.48242,-0.330835353
CC[C@H]([NH3+])[C@H](Sc1cc2ccccc2[nH]1)c1cccnc1,3.893539633,3.4168,-0.476739633
N#C[C@@H](c1ccc(Cl)cc1)c1ccc([C@H](O)c2ccccc2)cc1,2.754553872,5.07718,2.322626128
CCS(=O)(=O)c1ccccc1NCc1c[nH+]cn1C1CC1,3.319676997,2.0428,-1.276876997
Cc1ccc(OC[C@@H](O)C[NH+](C)Cc2cc(C)on2)c(C)c1,3.945702345,1.05446,-2.891242345
COc1cccc2c1OC[C@@H](NC(=O)Nc1cncc3ccccc13)C2,2.816092277,3.3686,0.552507723
Cc1cc([C@H]2CCCN2CCc2ccc([N+](=O)[O-])cc2)no1,2.770692362,3.27082,0.500127638
COc1ccccc1-n1nc(C(=O)NC2CCCCCC2)c2ccccc2c1=O,2.082487998,3.8469,1.764412002
O=C(Cn1ccc(=O)[nH]c1=O)NC[C@H]1CCOC1,2.849533882,-1.3107,-4.160233882
CC(=O)N1C(C)(C)C=C(CO)C1(C)C,3.451082344,1.3244,-2.126682344
Cc1cnc(CNCc2c[nH+]cn2C2CC2)cn1,3.652209805,1.02532,-2.626889805
Cc1cc(C(=O)N2CCc3cc(Cl)cc(Cl)c3C2)[nH]n1,2.669253158,3.22342,0.554166842
CCOC(=O)[C@H](CC)N1CCC(C(=O)N2CCCCCC2)CC1,2.5693103,2.4427,-0.1266103
CC(C)[C@H](C)NC(=O)C(=O)Nc1cc(Br)ccc1-n1cccn1,2.804510976,2.734,-0.070510976
O=C(CSc1ccc(F)cc1)Nc1nnc(-c2ccccn2)o1,2.191540921,3.0015,0.809959079
C/C=C/c1ccc(OCC(=O)Nc2ccc3[nH]c(=O)[nH]c3c2)c(OC)c1,2.253707868,2.9154,0.661692132
CN(CC[NH+](C)C)C(=O)c1cc(F)c(Cl)cc1Cl,3.312684493,1.349,-1.963684493
O=[N+]([O-])c1cnccc1NCc1cccc(F)c1,2.013819951,2.741,0.727180049
Cc1ccc(F)c(NC(=O)c2cccc(CS(C)(=O)=O)c2)c1,1.842203357,2.93102,1.088816643
Cc1nc(N2CCN(c3ccccc3O)CC2)nc2cc3c(cc12)CCC3,2.344492661,3.45912,1.114627339
CC(C)c1ccsc1C(=O)N1CCN(c2ncccc2C#N)CC1,2.467553166,3.10058,0.633026834
Cc1ccc(S(=O)(=O)N2CCC[C@@H](c3nc(-c4ccccc4Br)no3)C2)cc1,2.679753116,4.37582,1.696066884
CCc1ccc2nc(C)cc(C(=O)N[C@H]3CCC(=O)N(C)C3)c2c1,2.726969629,2.45622,-0.270749629
CS(=O)(=O)[N-]c1ccc(C2=NN/C(=N\c3ccccc3)SC2)cc1,3.686623439,3.3796,-0.307023439
COc1ccc(Br)cc1C[NH2+][C@H]1CCCC1(C)C,3.803916035,3.0998,-0.704116035
O=C1CCC[C@@H]1CC[S@@](=O)c1cc(Cl)ccc1Cl,3.495879625,3.8603,0.364420375
C[C@@H]([NH2+]Cc1ccn[nH]1)c1ccc(Cl)s1,4.249743882,1.9492,-2.300543882
NC(=O)[C@@H]1CCCN1c1nc(-c2cccnc2)nc2scc(-c3ccccc3)c12,2.881834056,3.8744,0.992565944
COc1ccccc1CNC(=O)c1nn(C)c(=O)c2ccccc12,1.894141175,1.8721,-0.022041175
Cc1ccc(F)cc1S(=O)(=O)NCCCCn1ccccc1=O,2.142833523,2.05452,-0.088313523
Cc1cc(C)nc(N(C)CC(C)(C)O)n1,2.621362705,1.30054,-1.320822705
COC(=O)c1sc2ccc(NC(=O)NC(C)C)cc2c1OC,2.204937973,3.2264,1.021462027
Oc1c(F)ccc2cc[nH]c12,2.532800833,2.0126,-0.520200833
Cc1cccc(-c2nc(C[NH+]3CCC[C@@H](Cn4cncn4)C3)no2)c1,4.25447131,1.13162,-3.12285131
O[C@H]1Cc2ccccc2[C@H]1[NH2+]Cc1nccn1Cc1ccccc1,3.674403729,1.6531,-2.021303729
Cc1cccc(S(=O)(=O)[C@H](C)C(=O)Nc2cc(F)ccc2F)c1,2.496263936,3.07412,0.577856064
O=C1C([O-])=C(O)C(=O)c2ccccc21,2.704172084,0.1955,-2.508672084
CC(C)(C)CC(=O)Nc1cc(F)c(F)cc1Cl,2.036293505,3.9929,1.956606495
COc1ccc(C(=O)NC(=S)Nc2ccc3oc(-c4ccc(O)c(Br)c4)nc3c2)cc1,2.285489018,5.0983,2.812810982
COC(=O)C1(NCc2cnc3ccccn23)CCCC1,2.69067309,1.9097,-0.78097309
COC(=O)c1cc(-c2oc3ccc(Br)cc3c(=O)c2C)cc([N+](=O)[O-])c1,2.387240107,4.22572,1.838479893
Fc1ccc(CNC2=NC=NC3=NC=N[C@H]32)cc1,3.602406741,1.1647,-2.437706741
OC[C@@H](CCc1cccs1)c1cccc(F)c1,2.630519417,3.5959,0.965380583
C[C@H]1CCc2c(C(=O)NCc3ccc(S(N)(=O)=O)cc3)csc2C1,2.683701432,2.4503,-0.233401432
COC(=O)NCc1ccc(NCc2ccc(Br)c(C)c2)cc1,1.910394325,4.22562,2.315225675
CNS(=O)(=O)CC(=O)N[C@H](C)c1ccc(SC(C)C)cc1,2.831783505,1.9135,-0.918283505
c1cc(N2CCCCCC2)ccc1C[NH2+]Cc1nnc2n1CCC2,3.039500112,1.8683,-1.171200112
CC[C@@H](C)NC(=O)[C@@H](C)S(=O)(=O)Cc1ncc(Cl)n1C,3.496099959,1.2915,-2.204599959
CC(=O)c1cccc(OC(=O)/C=C/c2cn(CCC#N)c3ccccc23)c1,2.377963524,4.37638,1.998416476
CC(C)c1ccccc1NC(=O)C1CCN(S(C)(=O)=O)CC1,1.946704896,2.4201,0.473395104
C[NH+](C)[C@H]1CC[C@@H](NC(=O)N2CCN(c3nnnn3-c3ccccc3)CC2)C1,3.830633555,-0.4405,-4.271133555
CCn1cc(C(=O)C(=O)N(C)c2c(Cl)cccc2Cl)cn1,2.653209612,3.0555,0.402290388
CC(C)CCN1N[C@H](C2=NC(=O)C3=C4CCCC[C@@H]4SC3=N2)c2ccccc2C1=O,4.330683963,4.0574,-0.273283963
Cc1ccnc2nc(C(=O)N(CCC[NH+](C)C)Cc3ccccc3)nn12,3.171964209,0.60972,-2.562244209
CC[C@H](NC(=O)c1ccc([N+](=O)[O-])cc1F)c1ccc(OC)cc1,2.384004406,3.6236,1.239595594
CCc1ccccc1NC(=O)N[C@H](C)C(=O)N1CCCC[C@@H]1C,2.748723049,3.16,0.411276951
O=C(CS(=O)(=O)c1ccc2ccccc2n1)N1CCCC1,2.204420544,1.6309,-0.573520544
Cc1nc(-c2ccc(S(=O)(=O)N[C@H](C)C[NH+]3C[C@@H](C)C[C@@H](C)C3)cc2)co1,4.544819199,1.87762,-2.667199199
CCOc1ccccc1NC(=O)C[C@@H]1CSc2nc(C(C)(C)C)cc(=O)n21,3.009533648,3.6151,0.605566352
C[C@H]([NH2+]Cc1ccc(F)c(CO)c1)c1ccc(-c2cccs2)cc1,3.417228851,3.8711,0.453871149
C[C@@H](Oc1ccccc1Cl)C(=O)Nc1cccc(I)c1,2.252775185,4.3506,2.097824815
COc1ccc(CCNC(=O)c2c[nH]nc2-c2ccccc2)cc1,2.153124233,3.0578,0.904675767
C[NH+]1CCN(CC(=O)NC[C@@H]2COc3ccccc3O2)CC1,3.69680263,-1.2271,-4.92390263
COc1ccc2c(c1)N(C(=O)c1cccc([N+](=O)[O-])c1)CCO2,2.081577406,2.6426,0.561022594
CCc1nc(C(=O)Nc2ccc(N(C)C(=O)OC)cc2)nn1-c1c(Cl)cccc1Cl,2.515200631,4.5914,2.076199369
O=C(Nc1sc2c(c1C(=O)Nc1ccccc1)CCC2)c1ccccc1Cl,1.933911169,5.3948,3.460888831
Cc1ccc(Sc2ccc(=O)n(-c3ccc(C(=O)[O-])cc3)n2)cc1,2.544596831,2.05562,-0.488976831
O=C1CCCN1CC#CC[NH+]1CCCC1,4.434869482,-0.7091,-5.143969482
COc1cc2c(cc1C[NH+]1C[C@@H]3CCC[C@H](C1)C3(O)c1ccccn1)CCC2,5.299426023,2.2815,-3.017926023
CC(C)CN(C)CN1C(=O)C(=O)N(C23CC4CC(CC(C4)C2)C3)C1=O,4.036679938,2.2913,-1.745379938
C[C@@H](O)c1ccc(N(C)[C@H](C)c2ccccc2F)cc1O,3.081308678,3.782,0.700691322
COc1ccc(NC(=O)N[C@@H]2CCO[C@@H]2C)c(Br)c1,3.049111174,2.7566,-0.292511174
Cn1c(C[NH+]2CCC(CS(N)(=O)=O)CC2)nnc1-c1ccccc1,3.609551356,-0.4345,-4.044051356
CCc1cccc(NC(=O)Cc2nc(Cc3ccc(F)cc3)n[nH]2)c1,2.201228159,3.2782,1.076971841
Cc1cc(O)c(-c2cc(C(=O)Nc3cccnc3)n[nH]2)cc1C,2.368833974,3.04644,0.677606026
Cc1cnc(CN2CCO[C@H](CN(C)c3cccn[nH+]3)C2)o1,4.141147342,0.52932,-3.611827342
CC[NH2+][C@@H](Cc1ccc(Br)s1)[C@@H]1CSCCO1,4.741884708,2.137,-2.604884708
O=C(NC(=S)Nc1ccc(O)cc1)c1ccccc1[N+](=O)[O-],1.926072916,2.4272,0.501127084
CC(=O)Nc1cccc(NC(=O)C(=O)N(C)CCc2cccs2)c1C,2.321622474,2.65452,0.332897526
Clc1cccc(CC[NH2+]Cc2cc[nH]c2)c1,3.452394629,1.9742,-1.478194629
O=C(CSC(=S)N1CCCC1)N1CCc2ccc([N+](=O)[O-])cc21,2.513903323,2.5978,0.083896677
COc1ccc(N2C[C@@H](C(=O)NN3CCCCC3)CC2=O)c(OC)c1,2.730423213,1.5738,-1.156623213
O=C(c1cccc(O)c1)N1CCN(C/C=C/c2ccccc2[N+](=O)[O-])CC1,2.224870018,2.7716,0.546729982
COc1cc(-c2cccc(CN3CCOCC3)c2)ncn1,2.083956027,1.9844,-0.099556027
Cc1csc(C(=O)[O-])c1NC(=O)c1cccc([N+](=O)[O-])c1,2.643339977,1.58052,-1.062819977
CCC(CC)C(=O)Oc1ccc(S(=O)(=O)N(C)C)cc1,2.090592718,2.2785,0.187907282
CN(C)S(=O)(=O)c1ccc(C(=O)N2CCC(Oc3nc4c(F)cccc4s3)CC2)cc1,2.426432086,3.3693,0.942867914
COC(=O)Cc1ccc(OCC(=O)N2CCCC2)cc1,1.744125512,1.4033,-0.340825512
CC[C@H](Oc1ccc2c(c1)CCC2)C(=O)N1CCCCCC1,2.545510082,3.7353,1.189789918
CO[C@H](C(=O)Nc1ccc(C(=O)NCC(C)C)cc1)c1ccccc1,2.252969911,3.3986,1.145630089
O=C1C[C@@H](CNC(=O)Nc2ccc3c(c2)COC3)c2ccccc2N1,2.934345456,2.9643,0.029954544
C=CCN(CC(=O)[O-])C(=O)CSc1cccs1,3.162627302,0.6047,-2.557927302
CCOC(=O)C1CCN(CN2C(=O)NC(c3ccccc3)(c3ccccc3)C2=O)CC1,2.379611396,2.7146,0.334988604
CC(C)([C@H](O)[C@H]1CCOC2(CCCC2)C1)[NH+]1CCCCC1,5.244483813,1.9341,-3.310383813
Nc1cc(-c2nc3cc(-n4cnnc4)ccc3o2)ccc1Cl,2.535858453,3.3111,0.775241547
C[NH+](C)[C@@H]1CC[C@H](NC(=O)C(C)(C)CCCc2ccccc2)C1,3.853817039,2.2173,-1.636517039
CCOc1ccc([C@H](C)NC(=O)c2ccc(F)cc2F)cc1,2.159335886,3.8545,1.695164114
COC(=O)C(C)(C)[C@@H]1CCC[NH+](Cc2nnc(C)n2C2CC2)C1,4.613226345,0.91552,-3.697706345
CCc1cnc(CN(C)C(=O)c2ccc(SC(F)F)cc2)s1,2.537210369,4.2924,1.755189631
C[NH2+]Cc1cc(=O)[nH]c(-c2ccc(F)cn2)n1,3.604829082,-0.3358,-3.940629082
C[C@@H](NC(=O)COc1ccccc1F)c1ccccc1,2.002354089,3.0819,1.079545911
CC[C@@H](C)Oc1cccc(C(=O)Nc2cccc(C(=O)N3CC[NH+](C)CC3)c2)c1,3.343060863,2.0867,-1.256360863
COC(=O)c1cccc(CNC(=O)Nc2cc(F)ccc2OC)c1,1.778129704,2.9426,1.164470296
NC(=O)c1cncc(C#CC2(NC(=O)C3CCCC3)CCCCC2)c1,2.735607005,2.5413,-0.194307005
O=C(CN1CC2(CCOCC2)Oc2ccccc2C1=O)NCCN1CCCC1=O,3.166655355,0.809,-2.357655355
CC[C@H]1CCC[C@@H](C(=O)[C@@](C)(CC)N2CCOCC2)C1,3.690512322,3.2728,-0.417712322
COc1ccc(OC)c(-n2nnnc2S[C@@H](C)c2nc3ccccc3n2C(F)F)c1,3.036983258,4.2776,1.240616742
CC(=O)NN1C(N)=C(C#N)[C@@H](c2ccc(Cl)cc2)C2=C1CC(C)(C)CC2=O,3.23336918,3.12738,-0.10598918
Cc1cc(-c2cncnc2C2CCN(C(=O)c3cn(C)nc3C)CC2)on1,2.788500188,2.50174,-0.286760188
C[C@H](NC(=O)Cc1cc2ccccc2[nH]c1=O)c1ccc2ccccc2c1,2.464508914,4.1012,1.636691086
Cc1ccc(CC(=O)N2CCC[C@H](C)C2)cn1,2.432306191,2.19102,-0.241286191
CCn1/c(=N/C(=O)Cn2cccn2)[nH]c2ccc(Br)cc21,3.05958439,2.0758,-0.98378439
Cc1cc(C(=O)N2CCN(Cc3cnc4ncccn34)CC2)c(C)o1,2.619295467,1.89714,-0.722155467
COc1ccccc1CNC[C@@H](c1ccc(C)o1)N1CCOCC1,2.661936754,2.75972,0.097783246
C#CCN(Cc1ccccc1Cl)C1CCCC1,2.276202846,3.7178,1.441597154
CCn1cc(Cn2cc(NC(=O)C3CCN(S(C)(=O)=O)CC3)cn2)cn1,2.468456254,0.7579,-1.710556254
CC[C@@H](C)NC(=O)[C@@H](C)NCc1cccc(C(N)=O)c1,2.653152917,1.1783,-1.474852917
NC(=O)C1CC[NH+](C/C=C/c2ccccc2)CC1,3.578814538,0.48,-3.098814538
c1ccc(-c2cc(CSc3ncnc4sccc34)on2)cc1,2.382726154,4.6386,2.255873846
C[C@H](NC(=O)N(C)C)c1cccc(C#CCCO)c1,2.827051719,1.7527,-1.074351719
CC(=O)CCc1ccc(O[C@H](C)C(=O)N[C@H]2CCCC[C@@H]2C)cc1,2.92425479,3.6704,0.74614521
O=Cc1c(C2CC2)oc2ccc(OCc3ccc(Cl)nc3)cc12,2.624109703,4.7501,2.125990297
CN(Cc1cscn1)c1ccc(S(=O)(=O)C(C)(C)C)nc1,2.900280867,2.7467,-0.153580867
CC(C)[NH+]1CCN(c2ccc(C(=O)N3C[C@H](C)OCC3(C)C)cc2)CC1,4.047651791,1.4394,-2.608251791
O=C(NCCC[NH+]1CCCCC1)N1CCO[C@H](c2ccccc2)C1,3.814447251,1.2284,-2.586047251
C[C@H](C(=O)N(C)CCC#N)[NH+](C)Cc1ccccc1,3.961452338,0.46188,-3.499572338
O=C(/C=C/c1ccc(Cl)cc1)Nc1nnc([C@H]2CCCO2)s1,2.806479793,3.6949,0.888420207
CC[C@@](C)([C@H](NC)c1cccc(OC)c1)[NH+]1CCCCC1,4.2931591,2.1932,-2.0999591
Cn1c(=O)c2[nH]c(NCCCCl)nc2n(C)c1=O,2.736699884,0.0011,-2.735599884
Cc1csc(C2(NC(=O)CCCc3nnc(-c4ccccc4)o3)CCCC2)n1,2.643283154,4.40992,1.766636846
CCSc1ccc(Cl)cc1C(=O)Nc1cccnc1C,2.132314147,4.40772,2.275405853
C[C@@H]1CCCN(C(=O)N[C@@H]2CCCCC[C@@H]2[NH3+])C1,3.76711659,1.3711,-2.39601659
O=C(CN1CCCOC1=O)N1CCC[C@@H]([NH+]2CCCCCC2)C1,4.143119308,0.2786,-3.864519308
Cc1cc(SCCCSCC#N)nc2ccccc12,2.540944048,4.2822,1.741255952
O=C(CCNc1ccccc1[N+](=O)[O-])OC1CCC1,2.06472993,2.4925,0.42777007
CCCSC1=NC(=O)[C@H]2C(=N1)NC(=O)C[C@H]2c1cc(Br)ccc1OC,3.954159567,3.1153,-0.838859567
C=CCN(Cc1cccc(C#N)c1)C[C@H](O)c1cccc(F)c1,2.763633613,3.41898,0.655346387
O=C([O-])c1cccc2c1ccn2Cc1cscn1,2.841347083,1.5096,-1.331747083
C[C@@H]1CCCCN1C(=O)c1ccc(NS(=O)(=O)c2ccc3c(c2)=[NH+]C(=O)[NH+]=3)cc1,3.780275122,-1.964,-5.744275122
CC[C@@H](Cl)CCc1cc(C)cc(F)c1,2.71715977,4.08412,1.36696023
[NH3+][C@H](Cc1cccc(F)c1F)c1cc(Cl)sc1Cl,3.588641574,3.8588,0.270158426
Cc1cnc(N(C)CC2CCCC2)c([N+](=O)[O-])c1,2.421682195,2.92462,0.502937805
Cc1cccc2c(CCC(=O)N3CCC[C@@H](n4cncn4)C3)c[nH]c12,2.996277595,2.86412,-0.132157595
Cc1noc(C)c1[C@@H]1CCCN1C(=O)CN1CCNC1=O,3.158593888,0.98014,-2.178453888
CCCCCCSc1nc2ncccn2n1,2.464580554,2.7967,0.332119446
Cc1ccc(OC(=O)C2CCN(C(=O)c3ccc(F)cc3)CC2)c(C)c1,1.905636655,3.90034,1.994703345
CC[C@]1(C2CC2)CC(=O)NC(=O)[C@H]1Cc1ccccc1,3.395847148,2.6982,-0.697647148
O=C(CSc1ncc2ccccn12)N1CCc2sccc2C1,2.587505606,3.0728,0.485294394
CC(=O)Nc1ccc(NC(=O)c2ccccc2C(=O)c2ccc(F)cc2)cc1Cl,1.8770974,4.9208,3.0437026
Cc1ccc(Cl)cc1NC(=O)N1CC[C@@H](N(C)C(=O)OC(C)(C)C)C1,2.673373595,4.12152,1.448146405
Nc1nc(CC(=O)N2CCC[C@@H](c3[nH]ncc3-c3ccccc3)C2)cc(=O)[nH]1,3.16264224,1.6909,-1.47174224
C[C@H](NS(=O)(=O)c1ccccc1)C(=O)N1C[C@H](C)c2ccccc21,2.879521476,2.5037,-0.375821476
Cc1cccc2cc(C(=O)NC[C@H](C)C(=O)[O-])oc12,3.101639588,0.85702,-2.244619588
CNC(=O)c1cccc(OC(=O)c2oc3ccc(OC)cc3c2C)c1,2.081208379,3.32862,1.247411621
CC[C@@H](NC(=O)Nc1cc(COC)ncn1)c1ccc(C)c(F)c1,2.843322787,3.34332,0.499997213
CCc1nn(C)cc1NC(=O)N1CCN(C(=O)[C@H]2CCCO2)CC1,3.02663316,0.8376,-2.18903316
CC(C)Nc1nnc(SCc2cn3cc(Cl)ccc3n2)s1,2.651955928,3.9518,1.299844072
CC(=O)Nc1cccc(N[C@H](C)C(=O)NC[C@H](C)c2ccccc2)c1C,2.750741611,3.67372,0.922978389
CC(C)(CCC(=O)OC(C)(C)C)NC(=O)c1ccc(Br)o1,2.503298166,3.6724,1.169101834
CCOC(=O)c1c(NC(=O)c2cc3cccc(OC)c3oc2=O)sc(C)c1C,2.294956795,3.90894,1.613983205
CC1(C)C[C@H]([NH2+]C(C)(C)C(N)=O)C(C)(C)O1,4.587654623,0.1598,-4.427854623
Fc1cccc(C[NH2+][C@H]2CCC[C@@H](C(F)(F)F)C2)c1,4.065181131,3.0102,-1.054981131
COc1ccc(-c2nc(S(=O)(=O)c3ccccc3)c(NCc3ccc4c(c3)OCO4)o2)cc1,2.369919411,4.5238,2.153880589
CCN(C(=O)c1sc(NC(=O)c2ccco2)cc1C)[C@@H](C)C(C)C,3.067582502,4.40842,1.340837498
CC(=O)c1ccc(N2CCN(CC(=O)Nc3ccc(F)cc3)CC2)cc1,1.745600799,2.789,1.043399201
CC(C)(C)c1ccccc1NC(=O)CNc1cc2c(cc1Cl)NC(=O)CO2,2.379469962,4.019,1.639530038
COc1ccc(NC(=S)NC(=O)c2ccccc2OC(C)C)cc1Cl,1.95587331,4.2626,2.30672669
COCCn1ccc2cc(NC(=O)c3cccs3)ccc21,2.079198728,3.6015,1.522301272
CCOc1ccc(C(=S)NC[C@@H]2CCCO2)cc1,2.528851633,2.5294,0.000548367
CCCc1c(C(=O)NC2CCN(C(=O)OCC)CC2)cnn1-c1ccccc1,2.229874275,3.1755,0.945625725
C#Cc1cccc(NC(=O)NCCC[NH+]2CCCCCC2)c1,3.486046434,1.6384,-1.847646434
CCc1ccc([C@@H](O)[C@@]2(C)CCCO2)cc1,3.291365597,2.8515,-0.439865597
COC(=O)[C@@]1(F)CCN(C(=O)Cc2cccc(OC)c2)C1,2.848120431,1.3513,-1.496820431
CNC(=O)CC1CCN(C(=O)c2nc(-c3ccccc3)oc2C2CC2)CC1,2.470756132,3.2073,0.736543868
O=C1N[C@H](CO)C(=O)N2CCN(Cc3cnn(-c4ccccc4Cl)c3)C[C@H]12,3.293818928,0.0292,-3.264618928
CNc1nc(C)cc(-c2cccc3cnccc23)n1,2.234845867,3.04192,0.807074133
CCn1cc(-c2nnc(SCc3ccccc3Cl)n2C)cn1,2.291699181,3.6442,1.352500819
O=C(CSc1nc(-c2ccco2)nc2ccccc12)Nc1nccs1,2.346398845,4.0771,1.730701155
O=C1CCc2cc(S(=O)(=O)Oc3cccc(C(F)(F)F)c3)ccc2N1,2.341990367,3.3578,1.015809633
CC[C@@H](C)c1noc([C@H]2CCCN(C(=O)Cc3cccs3)C2)n1,3.228844727,3.5933,0.364455273
CCCOc1ccc(Br)cc1C[NH+]1CCC(c2noc(C)n2)CC1,3.614952536,2.89182,-0.723132536
CCNc1cc[nH+]cc1S(=O)(=O)[N-]c1ccc(C)[nH+]c1C,5.185505424,1.75754,-3.427965424
Cc1cc([C@H]([NH3+])C2CCCCCC2)sc1Br,3.776193085,4.07242,0.296226915
CC1(C)CCC[C@H]1NS(=O)(=O)c1ccc(N)cc1Cl,3.001833462,2.7792,-0.222633462
Cc1ccc(Cl)c(-n2nnc(C(=O)Nc3ccccc3F)c2C)c1,2.064234641,3.92894,1.864705359
CCOC(=O)[C@H]1CCCN(C(=S)Nc2ccccc2F)C1,2.457567813,2.7976,0.340032187
C=CC[C@H](C)[C@@H](C)[NH2+][C@@H](C)CS(C)(=O)=O,4.867243714,0.5836,-4.283643714
COc1cc([C@@H]2CC(O)=Nc3c(C(=O)[O-])cn(-c4ccc(Cl)cc4)c32)cc(OC)c1O,3.620591927,3.3406,-0.279991927
CCC([NH3+])(CC)C(=O)N[C@@H]1C[C@H]2C[C@@H]1[C@H]1CCC[C@@H]12,5.358506374,1.728,-3.630506374
O=C(CNC(=O)N1CCC(OCc2ccccc2F)CC1)N1CCCCC1,2.212233062,2.5288,0.316566938
COC(=O)c1ccc(Oc2ncnc(Oc3cccc4cccnc34)c2[N+](=O)[O-])cc1,2.403730151,4.3042,1.900469849
CC[NH2+]C[C@@H](O)CN1C(=O)CCCC1=O,4.011495477,-1.5303,-5.541795477
[NH3+][C@H]1C=C[C@@H](C(=O)NC2(CC(=O)[O-])CCCCC2)C1,4.700839766,-0.8679,-5.568739766
COc1cccc([C@H]2C(C(=O)c3ccc(Cl)cc3)=C([O-])C(=O)N2c2cc(C)on2)c1,3.126908422,3.23022,0.103311578
CC(C)CC[NH2+]Cc1cc(F)ccc1F,3.319744001,2.0743,-1.245444001
CC(C)C[C@H](NC(N)=O)C(=O)NCCc1ccc(Cl)cc1Cl,2.501281448,2.7351,0.233818552
Cc1nc(C(=O)N(C)Cc2ccc(C(F)(F)F)cc2)nn1-c1ccc(F)cc1,2.28443192,4.00582,1.72138808
CCCc1nnc(NC(=O)c2c(Cl)ccc(Cl)c2Cl)s1,2.26792167,4.7031,2.43517833
O=C([O-])c1ccc[nH+]c1NC[C@H](c1ccccc1)N1CCOCC1,3.813179172,0.3496,-3.463579172
COCCOc1ccc(CNC(=O)N2CCC[C@H](C)CC2)cn1,2.587574312,2.4384,-0.149174312
Cc1nccc2c1=C[C@@H](C(=O)N[C@@H](C)c1ccccc1)C(=O)[NH+]=2,4.094326362,-1.09548,-5.189806362
COc1ccccc1NC(=O)c1ccc(NC(=O)Cn2nc(C(=O)[O-])c3ccccc3c2=O)cc1,2.34156721,1.6596,-0.68196721
O=C(NC[C@@H]1COc2ccccc2C1)c1n[nH]c(=O)c2ccccc12,2.702355993,1.9042,-0.798155993
CC1(C)CCC[C@@H]1[NH2+][C@@H]1CCCC1(C)C,4.747288252,2.7072,-2.040088252
CC(C)n1nccc1C(=O)OC[C@H]1CN(Cc2ccccc2)CCO1,2.811890846,2.5218,-0.290090846
CCc1ccc([C@@H](Br)c2ccc3c(c2)C[C@@H](C)O3)o1,3.294719399,4.6497,1.354980601
Cc1ccc(S(=O)(=O)N2CCCCC2)cc1C(=O)N(C)CC[NH+](C)C,2.941929517,0.38612,-2.555809517
CC[C@@H]1CCc2c(sc(=O)n2CN2CC[NH+](C)C[C@@H]2c2ccccc2)C1,4.511671939,1.9538,-2.557871939
C=CCC[C@H](O)c1ccc(F)cn1,3.059250513,2.2203,-0.838950513
CC(C)(C)NC(=S)N/N=C/c1cccnc1,2.423634839,1.6781,-0.745534839
CC1CCN(C(=O)c2ccc(CSc3nc4ccccc4[nH]3)cc2)CC1,2.05068455,4.7273,2.67661545
Cc1c(C)n(-c2ccc(Br)cn2)c2nc[n+](Cc3ccncc3)c(N)c12,2.971114521,3.11284,0.141725479
COc1ccc([C@H]2Nn3c(C)nnc3S[C@@H]2C(=O)N2CCCC2)cc1,3.304234512,1.97662,-1.327614512
Cc1ncc(CC(=O)N2CCc3c(c(C(=O)[O-])nn3CC3CC3)C2)c(=O)[nH]1,3.159885346,-0.82428,-3.984165346
COCc1nc2n(n1)CCC[C@@H]2NC(=O)Nc1cnccc1C,3.196257529,1.78452,-1.411737529
Cc1ccc(NC(=O)C[C@H](c2cccc(C)c2)n2cccc2)cc1,2.522819487,4.72314,2.200320513
CC1CCN(C(=O)CN2CCN(c3cc(C#N)ccn3)CC2)CC1,2.242431466,1.33378,-0.908651466
N#Cc1ccnc(N2CCC(CCC(=O)NCc3cccc(Cl)c3)CC2)c1,2.250365015,3.91968,1.669314985
COCC[C@@H](C)C(=O)Nc1cccc(Oc2ccncc2)c1,2.484806648,3.485,1.000193352
O=C(Cn1nc(C(=O)[O-])c2ccccc2c1=O)Nc1ccc2c(c1)C(=O)c1ccccc1C2=O,2.536245035,1.1741,-1.362145035
CC(C)(C)c1csc(CNC(=O)N[C@@H]2CCc3c(O)cccc32)n1,3.019480995,3.6329,0.613419005
CCn1cncc1[C@@H]1OCC[C@H]1C[NH2+]C(C)(C)C,4.646004856,1.3425,-3.303504856
O=C(CN1C(=O)N[C@]2(CCCc3sccc32)C1=O)N[C@H](c1ccccc1)c1cccs1,3.709928512,3.7988,0.088871488
Cc1ccc(C(=O)N[C@H](CC[NH+](C)C)c2ccc(Cl)cc2)s1,3.43628494,2.71562,-0.72066494
C/C=C/C[S@@](=O)Cc1nc(-c2cccs2)oc1C,3.517488633,3.53632,0.018831367
Nc1nonc1-c1noc(COc2ccc(Br)cc2)n1,2.399600676,2.0433,-0.356300676
COC[C@H]1CCCN(C(=O)N[C@@H]2C[C@](C)(OC)C2(C)C)C1,3.874154216,2.258,-1.616154216
C[NH+](C)[C@H]1CC[C@H](NC(=O)N(Cc2c(F)cccc2Cl)C2CC2)C1,3.994602094,2.2187,-1.775902094
CC(=O)/N=C1\S[C@@H]2CS(=O)(=O)C[C@@H]2N1c1ccc(Br)cc1Cl,3.526675094,2.7238,-0.802875094
Cc1c(Cl)cccc1S(=O)(=O)N1CCC(C(N)=O)CC1,2.029710667,1.53442,-0.495290667
CN(Cc1ncnn1C)C(=O)[C@@H]1CC(=O)N(CC(F)(F)F)C1,3.223841813,0.1843,-3.039541813
CC(=O)/C=C(/NNC(=O)c1cccc([N+](=O)[O-])c1)c1ccccc1,2.269320979,2.4593,0.189979021
CC(C)c1cc(C(=O)NC[C@@H]2CCC[NH+](C(C)C)C2)n(C)n1,4.366273556,0.9766,-3.389673556
CC[C@](C)([NH3+])CO[C@H]1CCOC2(CCCCC2)C1,4.663770551,2.2955,-2.368270551
Cc1ccc(-n2nc(C(=O)NC[C@H]3CCCO3)c(=O)n(Cc3cccc(F)c3)c2=O)cc1Cl,2.883985307,2.45222,-0.431765307
C[C@H](Cc1cccs1)N(C)C(=O)NCC1CC1,2.835861306,2.7305,-0.105361306
COc1ccc(C(=O)N[C@@H](Cc2ccccc2)c2nc3ccccc3[nH]2)cc1[N+](=O)[O-],2.584541779,4.1935,1.608958221
CCCn1c(N)c(Br)c(=O)n(CC(=O)NC)c1=O,2.537467799,-0.4893,-3.026767799
Cc1nc(/C=C/C(=O)Nc2nc(-c3ccccc3)ns2)cs1,2.462739944,3.62192,1.159180056
Cc1ccccc1OCCCC(=O)NCc1ccccc1,1.537882208,3.47042,1.932537792
CC[C@H](C)OCc1ccc(C(=O)NN)cn1,2.8132977,1.0002,-1.8130977
CCCc1noc(-c2cc(S(=O)(=O)NC)ccc2OC)n1,2.261944947,1.6058,-0.656144947
Cc1cc(C)n2cc(CN3C[C@@H](C)C[C@H]([NH3+])C3)nc2n1,4.026418302,0.79844,-3.227978302
Cc1noc(C)c1[C@@H](C)NC(=O)NCCCC[NH+]1C[C@@H](C)C[C@H](C)C1,4.640605204,1.99264,-2.647965204
C[C@@H](NC(=O)C(C)(C)C)C(=O)Nc1ccc(F)c(C(=O)N(C)C)c1,2.557460441,2.0168,-0.540660441
O=c1oc2cccnc2n1CC[NH+](Cc1ccccc1F)C1CCCCC1,3.845392421,2.5464,-1.298992421
COc1cccc(CCC(=O)N2CCC(OCc3ccccc3F)CC2)c1,2.033454364,3.9747,1.941245636
CCn1c(N/N=C/c2ccc(N(C)C)cc2)nc2c1c(=O)[nH]c(=O)n2C,2.59704872,0.9552,-1.64184872
Cc1cc2n(n1)CCC(=O)N2[C@@H](C)C(=O)NCc1ccccn1,3.100623436,1.02812,-2.072503436
Cc1ccc(CC(=O)NC[C@@H]2COc3ccccc3O2)cc1,2.34855287,2.49372,0.14516713
CC[C@H]1COCCN1Cc1c[nH]nc1-c1ccccc1F,3.202033955,2.8266,-0.375433955
CC[C@H](C)N1C(=O)c2ccccc2N[C@@H]1c1ccc(Cl)c([N+](=O)[O-])c1,3.14156896,4.6132,1.47163104
CS(=O)(=O)c1ccc(NC(=O)Cn2c(-c3ccccc3)cc3ccccc32)cc1,1.961669185,4.3505,2.388830815
CS[C@@H]1CC[C@@H](NC(=O)/C=C/c2ccc(OCc3cccnc3)cc2)C1,3.120059729,4.0741,0.954040271
CS[C@H](CO)[C@@H](C)NC(=O)NCc1cc(=O)[nH]c2ccccc12,3.361951895,1.4397,-1.922251895
O=C(NCC1CC1)C1([NH2+]Cc2cnn(-c3ccccc3)c2)CCCC1,3.257340383,1.7747,-1.482640383
CC(C)[C@@H]1CN(C(=O)NC[C@H]2CC[NH+](C3CC3)C2)CCS1,5.10628001,0.8366,-4.26968001
[NH3+]C[C@H](CC(=O)NC1CC1)N1CCn2cnnc2C1,3.791785348,-1.6271,-5.418885348
O=C(NN1Cc2ccccc2C1)c1cnn(-c2ccccc2)n1,2.463468958,1.9279,-0.535568958
COc1ccc(S(=O)(=O)[N-]c2ccc(Cl)c(Cl)c2)cc1[N+](=O)[O-],3.09463754,4.3043,1.20966246
O=C(NCC[C@@H]1C[C@H]2CC[C@@H]1C2)C(=O)Nc1nc(-c2ccccc2)cs1,3.907753927,3.6911,-0.216653927
CCCS(=O)(=O)c1ncc(Cl)c(C(=O)Nc2sc3c(c2C(=O)OC)CCC3)n1,2.655303848,2.9028,0.247496152
CC(C)(C)c1csc(CNc2cc(F)cc(F)c2)n1,2.393327346,4.3309,1.937572654
CCOc1ccc(OCC(=O)N2CCN(c3nccn3-c3cccc(Cl)c3)CC2)cc1,2.239043704,3.652,1.412956296
CNC(=O)[C@@H]1C[C@H]([NH3+])CN1Cc1ccc(-n2cccn2)cc1C,3.745433108,0.11152,-3.633913108
C[C@@H](Cc1nc2ccccc2s1)NCc1ncn(C)n1,3.139160511,2.1456,-0.993560511
CCC(CC)n1ccc(Cn2nnc(N)c2C(C)(C)C)n1,3.056509207,2.7637,-0.292809207
CCC[C@H]1CN(C(=O)[C@H]2C[C@@H]2c2ccccc2C)CCO1,3.194344378,3.12602,-0.068324378
C[C@H]1CCCC[C@@H]1OCCCOc1ccccc1C[NH3+],3.348375729,2.7927,-0.555675729
O=c1oc(-c2ccccc2)nn1CN1CCN(c2ccc(F)cn2)CC1,2.311079692,1.8171,-0.493979692
O=C(CCc1ccco1)NCc1cnn(-c2ccccc2)c1,2.047571641,2.7143,0.666728359
COc1ccc2ccccc2c1/C=C1/N=C(c2cccc([N+](=O)[O-])c2)OC1=O,2.34548966,4.1011,1.75561034
O=C([O-])[C@H]1CCO[C@H]1Cc1ccccc1,3.392413353,0.3841,-3.008313353
O=C([O-])C1[NH+]=c2ccccc2=[NH+]1,5.228431607,-5.8234,-11.05183161
C=CCN1C(=O)N=C(N)[C@@H]1c1ccc(OC(C)C)cc1,3.176426978,2.4937,-0.682726978
Cc1n[nH]c(SCC(=O)NC2(C)Cc3ccccc3C2)n1,2.992188236,1.87892,-1.113268236
C[C@H](NC(=O)NC[C@@H]1C[NH+](C)CCN1C)c1ccc(C(C)(C)C)cc1,4.117589731,1.173,-2.944589731
C=CCN1CCN(S(C)(=O)=O)[C@@H](C(=O)NCc2ccccc2)C1,2.679932748,0.4346,-2.245332748
O=S(=O)(c1cccn2c(SCc3c(F)cccc3Cl)nnc12)N1CCCC1,2.468769119,3.5986,1.129830881
CN[C@@H](c1c[nH+]ccc1N)[C@@H]1CN2CCC[C@@H]2CO1,4.894342078,0.2066,-4.687742078
C[C@@H](Sc1nc2ccccc2n1C)C(=O)N(C)Cc1ccccc1,2.431248545,3.7125,1.281251455
C[C@H]1CCCC[C@@H]1[NH+](C)Cc1cc(F)cc(C#CC[NH3+])c1,5.04524194,1.0125,-4.03274194
O=C(Nc1ccc2c(c1)CCCN2S(=O)(=O)c1ccccc1)c1cc2ccccc2o1,2.17677755,4.8266,2.64982245
C[C@@H]1CCCC[C@H]1[NH2+][C@@H]1CCN(C2CCCCC2)C1=O,4.263772335,2.0621,-2.201672335
CCc1nn(C)c2c1nc(N)n2[C@H](C)c1ccccc1F,3.276272156,2.6628,-0.613472156
CC(C)CNC(=O)CNC(=O)C(=O)Nc1ccc2ccn(C)c2c1,2.263583154,1.0052,-1.258383154
Cc1nc(-c2nnc(NC(=O)c3ccc(S(=O)(=O)N4CCC[C@H](C)C4)cc3)o2)cs1,2.887521357,3.17442,0.286898643
O=C(NCCc1ccccc1)[C@H]1CCCN(C(=O)c2cccc(Cl)c2)C1,2.231072868,3.5511,1.320027132
CC1(C)C[C@H]([NH2+]C[C@@H]2CCC[C@@H]2NS(C)(=O)=O)C(C)(C)O1,4.794950037,0.6138,-4.181150037
Cc1ccc2cccc(NC(=O)[C@H](C)SCc3nc4scc(-c5ccccc5)c4c(=O)[nH]3)c2n1,2.981012578,5.76862,2.787607422
CN(C(=O)CS(=O)(=O)c1nnc(N2CCCC2)s1)C1CC1,2.737794284,0.5328,-2.204994284
CC(C)c1ocnc1C[S@@](=O)[C@@H](C)c1ccc(F)c(F)c1,3.919654427,4.0861,0.166445573
C/C(Cn1nnc2ccccc21)=N\NC(=O)c1ccc(Br)cc1,2.193402011,2.9997,0.806297989
CCOc1ccc(-n2ccnc(SCC(=O)Nc3ccccc3Cl)c2=O)cc1,2.118518314,4.0154,1.896881686
Cc1ccc(OC(F)F)c(CN2CCN(c3nc(C)ns3)CC2)c1,2.451202487,3.07854,0.627337513
Cc1cc2ncc([C@H](C)NC3CC[NH+](Cc4ccccn4)CC3)c(C)n2n1,4.106032583,1.63924,-2.466792583
CC[NH2+][C@@]1(C(=O)[O-])CC[C@@H](n2cnc3ccccc32)C1,4.410980043,-0.1667,-4.577680043
CCCN1C(=O)CCc2cc(NC(=O)N(CC)CC(C)(C)O)ccc21,2.532985399,3.0005,0.467514601
COC(C)(C)c1noc(-c2cc(Br)cn2C(C)C)n1,3.212215459,3.763,0.550784541
Cc1nc(NC[C@H](c2ccc(F)cc2)[NH+]2CCCC2)cc(-c2cc[nH]n2)n1,4.361277329,2.14612,-2.215157329
C[C@@H](O)CCC(=O)Nc1cccc(NC(N)=O)c1,2.34605947,1.2767,-1.06935947
CC(C)(C)c1csc([C@H]2CN(Cc3ccc(O)cc3O)CCO2)n1,3.143730594,3.4253,0.281569406
Cc1ccc(-c2c[n+](-c3cccc(C(F)(F)F)c3)c3n2CCCCC3)cc1,2.729851836,5.48542,2.755568164
O=c1[nH]c(CCl)nc2ccc(O)cc12,2.419924373,1.3675,-1.052424373
COc1cc([C@H](C)N[C@@H]2CCCc3nc(C)sc32)ccc1F,3.302915814,4.32742,1.024504186
Cc1ccc([N-]S(=O)(=O)[C@@H]2CCC[NH2+]C2)c(C)[nH+]1,6.031753882,0.17834,-5.853413882
Cc1ccccc1-n1nc2c(c1NC(=O)c1cc([N+](=O)[O-])ccc1Cl)CS(=O)(=O)C2,2.635545395,3.42302,0.787474605
Cc1ccc(NC(=O)C[C@@H]2C(=O)Nc3nc(-c4ccccc4)nn32)cc1C,2.991716945,3.08394,0.092223055
CC(C)[C@H]1CC[C@@H](C)C[C@H]1[NH2+]CC1(CS(C)(=O)=O)CC1,4.753834564,1.8354,-2.918434564
O=C(NCCc1nnc(-c2ccccc2)o1)c1ccc2c(c1)C(=O)NC2=O,2.221256139,1.5927,-0.628556139
Cc1ncc(CC(=O)N2CCC(=O)N(CC3CC3)[C@@H](C(C)C)C2)c(=O)[nH]1,3.25717113,1.11632,-2.14085113
CC[NH2+][C@@]1(C(=O)[O-])CCC[C@@H]1CCN1CCS(=O)(=O)CC1,4.746850504,-2.021,-6.767850504
Cc1ccccc1NN/C=C1/C(=O)NC(=O)c2ccccc21,2.204343137,2.22262,0.018276863
CCN(CCC(=O)NC1CC1)C(=O)c1cnc(C)cn1,2.306046425,0.91582,-1.390226425
Cc1cccc([C@H]2CCC[NH2+]2)c1,4.094583094,1.39332,-2.701263094
CS(=O)(=O)N1CCc2cc(C(=O)Nc3cc(Br)ccc3Br)ccc21,2.207749587,3.786,1.578250413
CC(C)Cn1ncc(C(=O)N(C)[C@@H](C)Cc2ccsc2)c1C(F)F,3.336934928,4.2414,0.904465072
Cc1ccc(F)c(NC(=O)C(=O)NCCC[NH+]2CCCCCC2)c1,3.312323302,1.03782,-2.274503302
Cc1cc(NC(=O)[C@H](C)NS(=O)(=O)c2cccc(Cl)c2)ccc1N1CCCC1=O,2.630654339,3.08072,0.450065661
O=C(Cn1nc(-c2ccccc2F)ccc1=O)NC[C@@H]1CCCO1,2.613433534,1.3446,-1.268833534
CC(C)COC(=O)N1CCC([NH2+]CCc2nnc3ccccn23)CC1,3.358366621,1.0922,-2.266166621
CC(=O)N1CC[C@@H]([NH2+][C@H](C)c2c[nH]c3ccc([N+](=O)[O-])cc23)C1,4.043087472,1.3213,-2.721787472
Cc1cnc(C(=O)N2CCC(c3[nH]ncc3-c3ccncc3)CC2)cn1,2.56977868,2.58992,0.02014132
CCCNC(=O)CN1CCC([NH2+][C@H](C)c2ccc(F)cn2)CC1,3.633696098,0.8357,-2.797996098
C1CCC([NH2+]CC[C@H]2C[NH2+]CCO2)CC1,5.438513113,-0.7652,-6.203713113
CCOC(=O)C1(C)CCN(C(=O)CSc2ccccc2Cl)CC1,2.317976552,3.6239,1.305923448
CCN(CC)C(=O)[C@@H]1CCC[NH+]1Cc1cnc(Cl)s1,4.76589579,1.2122,-3.55369579
CCOc1ccccc1N/C(=C\C(=O)OC)C(=O)NC,2.268032439,1.3001,-0.967932439
CNC(=O)COc1ccccc1-c1noc(-c2ccc(S(=O)(=O)N(C)C)o2)n1,2.535585059,1.3717,-1.163885059
COC(=O)CSc1cc(Cl)ccc1Cl,1.950704472,3.2585,1.307795528
COC(=O)[C@@H]1C(=O)C2=C(C[C@@H]1C)Nc1ccccc1N[C@@H]2c1ccc(Cl)c(F)c1,3.528144568,4.71,1.181855432
COCC[C@@H](C)C(=O)Nc1ccccc1S(=O)(=O)C(F)F,2.730644768,2.294,-0.436644768
CCN(c1ccc(C(=O)NCc2ccc([S@](C)=O)cc2)cc1)C(C)C,2.670970361,3.5887,0.917729639
CC[C@@H]1CCCC[C@H]1[C@H]([NH3+])c1c(F)cccc1F,3.827628952,3.4642,-0.363428952
Cc1cccc(C(=O)NN2Cc3ccccc3C2)c1F,2.30064616,2.79472,0.49407384
CCn1c(=O)n(CC(=O)Nc2cc(OC)ccc2OC)c(=O)c2ccccc21,2.001306361,1.839,-0.162306361
Cc1cc(=O)c(C(=O)N2CCC[C@H](Cn3cc(CC4CCCCC4)nn3)C2)c[nH]1,3.262503707,2.95002,-0.312483707
Cc1c(-c2ccnc(N3c4ccccc4C[C@H]3C)n2)nnc2nc(-c3ccco3)nn12,3.333401872,3.62742,0.294018128
Cc1ccc(NC(=O)c2cnn(-c3ccccc3)c2C(F)(F)F)cc1S(=O)(=O)N(C)C,2.297134473,3.70212,1.404985527
CC[C@@H](CSC)N(C)C(=O)Nc1ccc(C(=O)N(C)C)c(C)c1,2.971901104,3.30212,0.330218896
COc1cc2c(cc1NC(=O)c1cc(OC)c(OC)c(OC)c1)CCCC2,2.044842359,3.8521,1.807257641
COc1cc(C)c(Br)cc1[C@@H](Cl)[C@H]1CCO[C@@H]1C,3.81993083,4.47102,0.65108917
Cc1cnc(SCC[NH+]2CCOCC2)nc1[C@@H]1CCCN(C(=O)/C=C/c2cccnc2)C1,4.141789854,1.60662,-2.535169854
Cc1ccc(NC(=O)CCn2nnc3cc(C(=O)NCc4cccc(Cl)c4)ccc32)cc1,2.111270605,4.35192,2.240649395
Cc1ccc(SCCC(=O)NC[C@@H]2CN3CCCC[C@@H]3CO2)cc1,3.075083992,2.84672,-0.228363992
Cn1cc(CNC(=O)c2cc(-c3ccccc3)nc3ccc(Br)cc23)cn1,2.113376849,4.3278,2.214423151
Cc1cccn2c(=O)c(C(=O)Nc3ccc([S@](C)=O)c(F)c3)cnc12,2.962959658,2.13172,-0.831239658
Cc1nonc1C(=O)NCc1csc(=O)[nH]1,3.075637706,0.05782,-3.017817706
COc1cc2[nH]c(C(=O)N3CCN(c4ccc(F)cc4)CC3)cc2c(OC)c1OC,2.277064802,3.2952,1.018135198
O=C([O-])[C@H]1C[C@@H]2CCCC[C@@H]2N1C(=O)[C@@H]1CCS(=O)(=O)C1,4.106939223,-0.6693,-4.776239223
CCn1c([C@@H](C)[NH2+]CC2([C@H](O)c3ccccc3)CC2)nc2ccccc21,3.791835426,3.1944,-0.597435426
CCSC[C@H](C)N(C)C(=O)CCc1c[nH]c2ccccc12,2.930423354,3.7005,0.770076646
COc1ccc2c(C)c(CCC(=O)Nc3c(C(C)C)cccc3C(C)C)c(=O)oc2c1,2.335432391,5.92812,3.592687609
O=C(NNC(=O)c1cccc(NC(=O)c2ccco2)c1)c1ccccc1,1.806007951,2.6067,0.800692049
Cc1ncc([C@H](C)NC[C@H]2CN3CCCC[C@@H]3CO2)c(C)n1,3.635850378,1.99734,-1.638510378
CC[C@H]1CCCC[C@H]1n1cc(C(=O)[O-])c(=O)[nH]c1=O,3.611392278,0.0414,-3.569992278
N#C[C@H](C(=O)c1cccc(N2CCCC2=O)c1)c1ccc(C(F)(F)F)cn1,2.994926879,3.71728,0.722353121
CCOc1cccc([C@@H](C)[NH2+][C@H](C)Cn2cnc(C)c2C)c1,3.8512171,2.61174,-1.2394771
O=C(OCc1nnsc1Cl)c1cccc(N2CCCS2(=O)=O)c1,2.690322495,2.0884,-0.601922495
Cc1cc(NC(=O)c2cc(Cl)ccc2Cl)n(C(C)C)n1,2.145072206,4.33152,2.186447794
Cc1ccc(OC[C@@H](O)C[NH2+]C[C@@]2(C)CCCO2)cc1,3.972963039,0.86722,-3.105743039
CSC(C)(C)CNC(=O)c1ccc(S(=O)(=O)NC2CC2)cc1,2.376249042,1.9987,-0.377549042
Cc1ccccc1-n1nnnc1[C@@H](c1ccccc1)N(C)Cc1cccc([N+](=O)[O-])c1,2.833420949,4.10032,1.266899051
COc1cc(C)c(NC(=O)N2CC[C@@H](C)C[C@@H]2C)cc1OC,2.80715971,3.66452,0.85736029
CCC[NH2+]CC/C=C(/C)c1ccc(C)c(C)c1,3.433729102,3.07024,-0.363489102
CS(=O)(=O)c1nc(C(=O)Nc2ccc3c(c2)Cc2ccccc2-3)c2ccccn12,2.440662697,3.5613,1.120637303
O=C(N[C@@H](c1ccc2c(c1)OCCO2)C1CC1)NC1(c2noc(C3CC3)n2)CCCC1,3.346508313,3.938,0.591491687
CC(=O)Oc1ccccc1-c1nc2s/c(=C\c3cc(C)c(OC(C)=O)c(C)c3)c(=O)n2n1,2.667587844,2.83314,0.165552156
O=C(NCC(=O)N1CCN(C(=O)c2ccccc2)CC1)c1ccc(Br)o1,2.03066504,1.7565,-0.27416504
O=C(NC1CCCC1)[C@@H]1CCCN1C(=O)Nc1cccc(Cl)c1,2.401089697,3.3951,0.994010303
Cc1c(C(=O)N2CCC3(CC2)C(=O)NCCC[NH+]3C)cnc2ccnn12,4.613407565,-0.95288,-5.566287565
CC[C@@H](C)CN(C)c1ccccc1C[NH2+]C1CC1,3.802154805,2.3947,-1.407454805
O=C(Nc1ccnn1Cc1cccc(Cl)c1Cl)c1cccc(S(=O)(=O)NC2CCCC2)c1,2.387381354,4.7114,2.324018646
COc1ccc(C(=O)Nc2ccccc2C(=O)Nc2ccccn2)cc1,1.688012874,3.5948,1.906787126
CC(=O)Nc1cccc(-n2cc3c(c2-c2ccccc2Cl)c(=O)n(C)c(=O)n3C)c1,2.439185247,3.3067,0.867514753
Cc1ccc2nc(C)c(C(=O)NCC(C)(C)c3ccncc3)cc2c1,2.306965961,3.95424,1.647274039
CSc1ccsc1C(=O)NC1(c2cccc(Cl)c2)CC1,2.584660576,4.5425,1.957839424
O=C(OCc1nnc(-c2ccccc2Cl)o1)C1CCOCC1,2.256423598,2.8598,0.603376402
COc1ccc(-c2csc(NC(=O)CSc3nc4ccccc4[nH]3)n2)cc1OC,2.179890746,4.4344,2.254509254
CCc1cc([C@@H]2CCC[NH2+]2)cc(CC)c1Cl,4.274119239,2.8631,-1.411019239
Cc1nc2sc3c(NC[C@H]4CCCO4)ncnc3c2c2c1COC(C)(C)C2,3.430342051,3.99012,0.559777949
Cc1cncc(NC(=O)NCCc2cccc(F)c2)c1,1.947457496,2.89332,0.945862504
CN(C)c1ccc(NC(=O)[C@@H]2CCCN2C(=O)C2CC2)cn1,2.57703349,1.4871,-1.08993349
O=C1O[C@H](CCN2CCOCC2)CN1Cc1ccccc1,2.614376716,1.7297,-0.884676716
O=C(Cn1c(Cl)nc2ccccc21)c1ccc2c(c1)OCCO2,2.226699819,3.3438,1.117100181
CC(=O)N[C@H](CC(=O)Nc1nc[nH]n1)c1ccc(C)cc1,3.10270471,1.31912,-1.78358471
CCC[C@@]1(CC)NC(=O)NC1=O,3.343567709,0.7747,-2.568867709
CC(C)[C@@H](O)c1ccn(CCC(C)(C)C)c1,3.326211986,3.6137,0.287488014
O=C(CNS(=O)(=O)c1ccccc1Cl)NCc1ccco1,2.015111156,1.5277,-0.487411156
CCSC1=C(C#N)[C@H](CC(C)C)C2=C(CCCC2=O)N1,3.509414391,3.74728,0.237865609
C[C@H]1CC[C@H](NC(=O)NC[C@](C)(O)c2ccsc2)[C@H](C)C1,3.828507393,3.0795,-0.749007393
O=C(Cc1ccc(-n2cccn2)cc1)Nc1cccc(CN2CCOC2=O)c1,2.246151822,3.0057,0.759548178
COCCSc1ccccc1C(=O)N1C[C@H](C)O[C@@H](C)C1,2.919174745,2.6745,-0.244674745
CN[C@@H]1[C@@H](C[NH+](C)CCO)C(C)(C)OC1(C)C,5.466499996,-0.715,-6.181499996
CC(C)COC1CCN(C(=O)/C=C/c2cccc3cccnc23)CC1,2.385662775,3.9116,1.525937225
Cc1cc(-n2c(C)cc(C(=O)CN3C(=O)NC(c4ccccc4)(c4ccccc4)C3=O)c2C)no1,2.659919657,4.06886,1.408940343
CCC[NH2+][C@@H](CC[C@H]1CCCO1)[C@@H]1C[NH+]2CCN1CC2,6.424619424,-1.1297,-7.554319424
CC[C@@H](C)NC(=O)[C@H](NC(=O)c1cccc(C)c1)C1CCN(C(=O)c2ccccc2)CC1,2.852569634,3.56052,0.707950366
Cc1ccc([C@H]2CSCCN2C(=O)N[C@H]2CC[C@H]([NH+](C)C)C2)cc1,4.245038493,1.86012,-2.384918493
C[C@H](NC(=O)c1ccnn1C)C12CC3CC(CC(C3)C1)C2,4.194819005,2.7548,-1.440019005
Cn1ncnc1CCNC(=O)N1CCC[C@H]1Cc1ccccc1,2.744029526,1.7743,-0.969729526
O=C([O-])COc1cc(-c2ccncc2)nc(N2CCN(c3ccccc3F)CC2)n1,2.663513821,1.133,-1.530513821
N#C[C@H]1CNCCN1C(=O)c1ccsc1,3.291177871,0.68568,-2.605497871
CC[NH+]1CC[C@H](N(C)CC(=O)OC(C)(C)C)[C@H](C)C1,4.803468008,0.5731,-4.230368008
O=C(Cc1ccc2c(c1)OCCCO2)NCCNC(=O)[C@H]1C=c2ccccc2=[NH+]1,3.610596228,-1.8141,-5.424696228
C[C@@H]1CCN(C(=O)NCCn2c(=O)[nH]c3ccccc32)C[C@@H]1O,3.115989084,0.7419,-2.374089084
CCC1(CC)CC(=O)N(CCOc2ccccc2)C1=O,2.279792993,2.6307,0.350907007
CSc1cccc(NC(=O)[C@@H]2CCCN2C(=O)c2cccs2)c1,2.519958928,3.7133,1.193341072
CCOC(=O)c1ccc(N2C(=O)c3oc4ccc(F)cc4c(=O)c3[C@@H]2c2ccc(SC)cc2)cc1,2.869982169,5.5805,2.710517831
Cc1ccsc1CN[C@@H]1COC[C@@H]1O,3.438656169,0.90582,-2.532836169
CO[C@H](C)C(=O)Nc1cccc(NC(=O)N(C)[C@@H](C)c2cccnc2)c1,2.920062765,3.2799,0.359837235
CN(C(=O)c1ccc2c(c1)CCO2)C1CCC(=O)CC1,2.371567617,2.2052,-0.166367617
CC(C)c1csc(/C(C#N)=C\c2csc(-c3ccc(F)cc3)n2)n1,2.706482804,5.59328,2.886797196
COc1ccc(CC(=O)N2CCO[C@@H](COCc3cccc(C)c3)C2)cc1,2.523511281,2.99032,0.466808719
O=C(COc1ccc(Cl)cc1)N[C@@H](c1ccccc1)c1cccs1,2.260327084,4.6861,2.425772916
CCOC(=O)[C@H]1C(=O)C(=O)N(c2ccccc2)[C@@H]1c1ccc([N+](=O)[O-])cc1,2.954844007,2.4311,-0.523744007
C[NH+](C)CCCOc1ccc(NC(=O)[C@H]2CCC(=O)O2)cc1,3.481344542,0.2441,-3.237244542
CC(C)(C)OC(=O)N(CCNC(=O)N1CCC([NH+]2CCCC2)CC1)C1CC1,3.689186885,1.2386,-2.450586885
CC1=C[C@@H](c2ccc(S(=O)(=O)Nc3ccc(Cl)cc3C)o2)N=N1,3.54172818,4.45292,0.91119182
Cc1nn(-c2ccccc2)c(COC(=O)C23C[C@@H]4C[C@@H](CC(O)(C4)C2)C3)c1C#N,4.885278242,3.4269,-1.458378242
CC(C)C(=O)Nc1cccc(NC(=O)C(=O)NCC(C)(C)c2ccncc2)c1,2.297021131,2.7086,0.411578869
Cc1cc(C)n2nc(C)c(C(=O)O[C@@H](C)[C@@H]3CCCO3)c2n1,3.472079899,2.37886,-1.093219899
CCCN[C@]1(C#N)CC[C@@H](Sc2nc(C)c(C)c(C)n2)C1,3.865569071,3.30844,-0.557129071
CCOC(=O)C1(c2nc([C@H]3CC[C@H](C)C3)no2)CCCC1,3.496369973,3.3481,-0.148269973
CCCNC(=O)c1ccc2c(c1)c(C)c(C)n2C,1.967924688,2.93494,0.967015312
CCCOc1ccc(NCC(=O)Nc2ccc(F)cc2)cc1,1.632813467,3.6651,2.032286533
S=c1n(CN2CCSCC2)nc(N2CCCCC2)n1-c1ccccc1,2.583555394,3.39989,0.816334606
Cc1cc(C(F)F)nn1CC(=O)N1CC(c2ccncc2)C1,2.61704593,2.15012,-0.46692593
O=[N+]([O-])c1ccc(Cl)c(S(=O)(=O)/N=C([O-])/C=C/C2CC2)c1,3.166163675,1.6619,-1.504263675
CC[C@@H](C)n1ncc(C(=O)N2CCO[C@H](C#N)C2)c1C1CC1,3.768596419,2.09608,-1.672516419
O=C(NC(C1CC1)C1CC1)N1CC[C@@H](Oc2ccc(C(F)(F)F)cn2)C1,3.144221022,3.4517,0.307478978
C[C@H]1CCCC[C@@H]1OCCN(C)Cc1nc(=O)c2ccccc2[nH]1,3.075784642,2.9502,-0.125584642
COc1ccc(Cl)cc1S(=O)(=O)N1CCN(C(=O)c2cc3ccccc3oc2=O)CC1,2.197795481,2.6017,0.403904519
C[C@@H]1CCN(c2ncc([N+](=O)[O-])cc2Cl)[C@H](C)C1,3.208323663,3.268,0.059676337
Cc1nc(N2CCN(C(=O)c3cccc(F)c3)CC2)cc(-n2cnc(C)c2C)n1,2.473499443,2.68906,0.215560557
COC(=O)[C@H]1CCCCN1C[C@H](O)COc1ccc(Cl)cc1Cl,2.853293962,2.7606,-0.092693962
COC[C@](C)(CCO)[NH2+]Cc1ncccc1C,4.061519898,0.24092,-3.820599898
O=C(NC1CCCCC1)C1CCN(C(=O)c2cc3cc(Cl)cnc3[nH]2)CC1,2.420386909,3.5174,1.097013091
CC(C)(C)N(CCC(=O)[O-])C(=O)N1CC[NH+]2CCC[C@H]2C1,5.296421228,-1.2902,-6.586621228
COC(=O)c1cc(NC(=O)NCc2ncc[nH]2)ccc1C,2.150937232,1.82642,-0.324517232
CN(Cc1ccc(C#N)cc1)C(=O)NCCN(C)c1ccccc1,2.099813573,2.83608,0.736266427
CCCNC(=O)C1CCN(C(=O)c2cccc(NS(=O)(=O)c3ccc4c(c3)CCCC4)c2)CC1,2.239853257,3.7446,1.504746743
O[C@H]1Cc2ccccc2[C@H]1[NH2+]Cc1cnc(-c2ccsc2)s1,4.127840834,2.5933,-1.534540834
N#Cc1cc([N+](=O)[O-])ccc1Oc1ccc(/C=N\NC(=O)c2ccccc2-n2cccc2)cc1,2.503753463,4.81338,2.309626537
C[C@H]1C[C@H](C)CN(C(=O)[C@@H](C)S(=O)(=O)c2nc3ccccc3[nH]2)C1,3.481529855,2.2296,-1.251929855
CCN(C(=O)CN1CC[NH+](CC2CC2)CC1)c1cccc2ccccc12,3.297787028,1.8032,-1.494587028
Cc1ccc(Nc2nc3c(c(=O)[nH]2)CCCC3)c(F)c1,2.46036033,2.83982,0.37945967
CCC[C@H]1[C@H](C)CCCN1C(=O)NCCCCSC,3.291429317,3.7398,0.448370683
Cc1noc(C)c1CN1C(=O)N(c2cccc(Cl)c2)c2ncccc2S1(=O)=O,2.790302811,3.80244,1.012137189
CCC[NH2+][C@@H](C)c1ccc(SCC(C)C)cc1,3.752956887,3.4691,-0.283856887
C[C@@H](C(=O)NCCc1ccc(F)cc1)[NH+](C)Cc1ccc(Cl)nc1,3.667307699,1.6362,-2.031107699
CC(C)(C)N1C[C@@H](C(=O)n2ccnc2)CC1=O,3.241710455,1.1703,-2.071410455
CN(c1ccc(C(=O)Nc2ccccc2C(=O)N2CCCC2)cc1)S(C)(=O)=O,1.980031213,2.5707,0.590668787
Nc1ccc(=O)n(CCN2CC[NH+]3CCC[C@@H]3C2)c1,4.656713618,-1.2066,-5.863313618
Cc1cc(NC(=O)C(=O)N2CCN(c3ccc(Br)cn3)CC2)no1,2.379687103,1.42782,-0.951867103
Cc1ccc(C(=O)N2CCN(C(=O)CN3CCOCC3)CC2)s1,2.130681032,0.67312,-1.457561032
COc1ccc(C(=O)NCc2ccc(C[NH+]3CCC[C@H](C)C3)cc2)cc1[N+](=O)[O-],3.752516678,2.3482,-1.404316678
C[NH+](Cc1ccc(OC(F)(F)F)cc1)C[C@H]1CCOC1,4.01571473,1.6364,-2.37931473
COC(=O)c1ccc(C(=O)N2CCC(Oc3nc4c(C)ccc(C)c4s3)CC2)cc1,2.358502992,4.38334,2.024837008
CC[C@@H](C)[C@@H]([NH2+]C)C1([NH+](C)C)CCC(C)CC1,5.323609858,0.6877,-4.635909858
CC[C@H](C)NC(=O)[C@H](C)NC(=O)Nc1ccc2c(c1)COC2,2.971019369,2.1415,-0.829519369
Cc1nn(CC(=O)Nc2ccc(N3CCOCC3)cc2)c(=O)c(C#N)c1C,2.242124545,1.20712,-1.035004545
O=C([O-])CCCS(=O)(=O)Nc1ccc(Cl)cc1F,2.651310618,0.7509,-1.900410618
CCOC(=O)c1ccccc1[N-]S(=O)(=O)c1cccc(F)c1,2.877115287,3.3965,0.519384713
CC(C)(C)C(=O)Nc1ccc(NC(=O)c2ccccc2)nc1,1.82913395,3.3185,1.48936605
Cc1ccc(/C=C/C(=O)N(C)Cc2sccc2C)o1,2.502480003,3.62974,1.127259997
CC(C)c1noc(CN(C)C(=O)c2csc3nc(-c4cccc(F)c4)cn23)n1,2.730284857,3.9805,1.250215143
CCC[C@H](CCCc1c([O-])[nH+]c(N)[nH]c1=O)C(=O)[O-],4.655077806,-1.6663,-6.321377806
Cc1ccsc1-c1nc(C)c(C)nc1N,2.624802162,2.71256,0.087757838
Cn1nnc2cc(C(=O)NCC(C)(C)c3ccncc3)ccc21,2.42817859,2.0709,-0.35727859
O=C(c1cccc(C2=NO[C@@H]([C@@H]3N=NC4=C3CCCC4)N2)c1)N1CCCC1,3.976111738,3.1926,-0.783511738
COC[C@H](O)CN(C)C(=O)c1cscc1C,3.166815768,1.13582,-2.030995768
COc1ccc2c(c1)C(N1CCN(C(=O)c3ccc(F)cc3)CC1)=Nc1ccccc1O2,2.297544373,4.4763,2.178755627
NS(=O)(=O)N1CCC(C(=O)NCCNC(=O)c2ccco2)CC1,2.251835206,-0.9589,-3.210735206
CCC(CC)C(C)(C)C[C@@H](C)[NH2+]C,4.576779185,2.4206,-2.156179185
CCc1ccc2ncc(C#N)c(Nc3cc(OC)ccc3OC)c2c1,2.189201832,4.42968,2.240478168
C[C@@H]1CCCC[C@H]1NC(=O)N(C)Cc1nccs1,3.124031281,2.8632,-0.260831281
CCc1ccc(OCC(=O)N(C(C)C)C2CCOCC2)cc1,2.18664966,3.0438,0.85715034
COC1CCN(C(=O)c2ccc(N)c(C)c2)CC1,1.975924657,1.82822,-0.147704657
CCOC(=O)c1cc(S(=O)(=O)N2CCN(c3ccc(C(C)=O)cc3)CC2)cn1C,2.234633575,1.9153,-0.319333575
O=C(c1occc1Br)N1CCS[C@H](c2ccccc2)C1,3.097373773,3.9724,0.875026227
Clc1cc(N2CCC[C@H](OCc3cccnc3)C2)nc2[nH]ccc12,3.074239667,3.7969,0.722660333
CC(C)[NH2+]Cc1ccc([N+](=O)[O-])c(OC/C(Cl)=C/Cl)c1,3.652626052,2.7644,-0.888226052
Cc1cc(C)nc(Nc2nc(CC(=O)Nc3ccc(-n4cnnn4)cc3)cs2)n1,2.491890746,2.45044,-0.041450746
Cc1c2c(cc3c1O/C(=C\c1c(F)cccc1Cl)C3=O)CN(CCCN1CCOCC1)CO2,2.96219473,4.27792,1.31572527
CSc1ncc(CNc2cc(-c3ccccc3)nc3ccnn23)n1C,2.639763986,3.4638,0.824036014
CC[C@@H](C)c1ccc(S(=O)(=O)[N-]c2cc(F)ccc2F)cc1,3.448403351,4.8724,1.423996649
COCc1cccc(NC2=[NH+]CCCCC2)c1,3.269881966,1.298,-1.971881966
C[C@@H](Oc1ccccc1F)C(=O)NNC(=O)Nc1ccccn1,2.508059175,1.8409,-0.667159175
Cc1ccc(-n2c(C)cc(/C=N\NC(=O)CN3CCOCC3)c2C)cc1[N+](=O)[O-],2.45769911,2.09306,-0.36463911
CCN1C(=O)Cc2cc(NC(=O)c3ccc(Cl)c([N+](=O)[O-])c3)ccc21,2.169784347,3.4095,1.239715653
Cc1cc(I)ccc1NC(=O)C[NH+](C)[C@@H]1CCc2ccccc21,3.83053583,2.74032,-1.09021583
CCn1cc(NC(=O)[C@H]2CSCN2C(=O)CC(C)(C)C)ccc1=O,3.085154336,2.1444,-0.940754336
N#Cc1c(Cl)nsc1N[C@@H]1CCc2cc(Cl)ccc21,3.258787912,4.42098,1.162192088
Cc1cc(C)c(NC(=O)CNC(=O)[C@@H](NC(=O)c2ccco2)C(C)C)c(C)c1,2.611430454,2.71416,0.102729546
C[C@@H]1CCC[C@@H](N(C)C(=O)NCc2cccc(Cn3cccn3)c2)C1,3.054786435,3.6515,0.596713565
C[C@H](NC(=O)CN(C)Cc1nc2ccccc2c(=O)[nH]1)c1ccccc1,2.556272893,2.2323,-0.323972893
Cc1ccc(Oc2ncnc(Nc3ncccc3C)c2[N+](=O)[O-])cc1C,2.404361839,4.24096,1.836598161
O=C(CCn1cncn1)Nc1nc(-c2ccc3ccccc3c2)cs1,2.191495935,3.5836,1.392104065
Cc1cc(C)c(NC(=O)[C@H](C)N(C)OCc2ccccc2)c(Cl)c1,2.804898113,4.34744,1.542541887
Cc1ncc([C@@H](C)[NH2+][C@H](C)c2ccc3c(c2)CC(=O)N3)s1,4.219180742,2.33172,-1.887460742
CCc1ccc(C(=O)OCC(=O)N2CCCC2)o1,2.044629001,1.6212,-0.423429001
O=C(NCCCn1ccc2ccccc21)C1CCN(C(=O)/C=C/c2ccccc2)CC1,2.216937347,4.0996,1.882662653
CC/[NH+]=C(/NCc1nnc2n1CCCCC2)N[C@@H]1C[C@@H]1C,4.195178435,-0.4514,-4.646578435
CC[C@H]1CCCN1c1nc(C(F)(F)F)ccc1C(N)=S,3.041634251,3.1134,0.071765749
Cc1ccc(C)c(NC(=O)[C@@H](C)[S@@](=O)Cc2nccn2C(F)F)c1,3.483063654,3.17094,-0.312123654
C[C@@H]1CCCC[NH+]1CCNC(=O)CCc1ccc2c(ccn2C)c1,4.085937819,1.6844,-2.401537819
Cc1ccc(OCCOc2ccccc2)c(C(=O)[O-])n1,2.306296362,1.21132,-1.094976362
Cc1csc([C@@H](C#N)C(=O)CSc2nnc(-c3ccccc3F)n2C2CCCCC2)n1,3.136642384,5.3229,2.186257616
O=C(CCl)N1CCCC2(CCCCC2)C1,2.906406136,2.7981,-0.108306136
N#CCN1CCN(CC(=O)NCCCC2CCCCC2)CC1,2.20926879,1.60428,-0.60498879
COc1ccc(F)cc1CSCc1nc(CC(C)C)no1,2.372806593,3.8492,1.476393407
[NH3+][C@@H]1CCCC[C@@H]1/N=C/[C@@H]1C(=O)N=C(S)N(c2ccccc2)C1=O,4.498557498,1.0857,-3.412857498
CNS(=O)(=O)c1cc(NC(=O)Cc2cccc(F)c2F)ccc1C,2.073399905,2.36252,0.289120095
N#CCCN(CCC(F)(F)F)C(=O)c1ccc2c(c1)CCCC2,2.435420061,3.87368,1.438259939
CCCC(=O)N1CCC(n2nc(C(=O)OCC)cc2C2CC2)CC1,2.485511105,2.9008,0.415288895
CC(=O)Nc1c(N)cc2c3c(cccc13)C(=O)c1ccccc1-2,2.268598471,3.5918,1.323201529
CCN(Cc1cccs1)CN1C(=O)N[C@](CC)(c2ccccc2)C1=O,2.892918554,3.3848,0.491881446
C[C@@H](O)[C@H]1CCOC2(CCC2)C1,4.084328607,1.7165,-2.367828607
O=C1CCCN1C[C@H](NC(=O)N1CCN(c2nccs2)CC1)c1ccccc1,2.838055644,2.3384,-0.499655644
C=CCc1cc(N)ccc1O[C@H]1CCOC2(CCCC2)C1,3.696392259,3.8679,0.171507741
COC[C@@H]1CN(C(=O)C(=O)Nc2ccccc2-c2ccccc2)CCO1,2.611122099,2.1659,-0.445222099
O=C(CCc1ccc(Cl)cc1)NCCc1nc2cc(F)ccc2s1,2.030294342,4.3803,2.350005658
Cc1cn2c(=O)c(C(=O)N3CCC[C@H](n4cccn4)C3)cnc2s1,3.246150591,1.73822,-1.507930591
Brc1cc(C[NH+]2CCC[C@@H](c3cn[nH]c3)C2)cc2c1OCCCO2,4.62014089,2.296,-2.32414089
COc1ccccc1OCC(=O)NCC1(c2cccs2)CCCCC1,2.253408876,4.1538,1.900391124
CC[C@@H](C)Oc1cccc(C(=O)Nc2cccc(-c3nn4cnnc4s3)c2)c1,2.865814889,4.2824,1.416585111
Cc1ccc(CN[C@@H](C(=O)NC2CC2)c2ccccc2)s1,2.532502845,3.16602,0.633517155
C[C@@H](NC(=O)NNC(=O)[C@@H]1CCCO1)c1ccc(Cl)cc1,2.857988826,1.9104,-0.947588826
CC[C@H](O)Cn1nc(Cn2cncn2)nc1Cc1cccc(Cl)c1,3.119028851,1.933,-1.186028851
CC(C)[C@H](Sc1nnc(-c2ccco2)n1C)C(=O)Nc1ccc(Cl)cc1,2.778211591,4.4839,1.705688409
CO[C@H](C)c1noc(CNc2ccc(Sc3ccccc3)cc2C)n1,2.950654307,4.84872,1.898065693
C[C@@H](c1ccccc1)N1C[C@@H](C(=O)NCCCN2CC[NH+](C)CC2)CC1=O,3.848030117,-0.0673,-3.915330117
CCCC[C@@H](C(=O)OC)[C@H](O)c1ccc(Br)cc1F,3.086125432,3.601,0.514874568
O=C(NCC1(O)CCOCC1)c1ccc(CSC(F)F)o1,2.852911417,2.0067,-0.846211417
C[C@](O)(CNC(=O)c1cccs1)CN1CCOCC1,2.825985348,0.5611,-2.264885348
CCOc1ccccc1NC[C@@H]1COc2ccccc21,2.555787576,3.6734,1.117612424
CCN1C=C(C[NH+]2CCC(C(=O)Nc3ccc([C@@H]4C=c5ccccc5=[NH+]4)cc3)CC2)[C@@H](C)N1,5.152966629,-0.7319,-5.884866629
CC(C)(C)n1nccc1NC(=O)C(=O)N1CCC[C@@H](c2cn[nH]c2)C1,3.491526175,1.7059,-1.785626175
Cc1ccc(-n2ncc3c2ncn2nc(CCc4ccccc4)nc32)cc1Cl,2.444150134,4.21022,1.766069866
CC(=O)N1CCN(Cc2cc(NC[C@H](O)c3ccsc3)cc[nH+]2)CC1,3.795660076,1.3718,-2.423860076
CCC[C@H]1CN(C(=O)Nc2ccc(CC)c(F)c2)CCO1,2.654457773,3.4209,0.766442227
CC[NH+]1CCN([C@H](C)CNC(=O)COc2ccccc2C#N)CC1,3.789513477,-0.33782,-4.127333477
CCOc1ccc(N)c2ccccc12,1.62683674,2.8207,1.19386326
CC[C@@](C)([C@@H](N)c1cccc(Cl)c1)[NH+](C)C,4.121757026,1.653,-2.468757026
CC1CCC(NC(=O)C2=CC(=O)CS2)CC1,2.872345398,1.8811,-0.991245398
COc1cccc([C@@H](CO)NC(=O)NCc2scnc2C)c1,2.77534479,1.99292,-0.78242479
C[C@@H](CCO)C[NH2+]Cc1ncc(-c2ccccc2)s1,3.718444882,1.892,-1.826444882
C[C@@H](c1ccccc1)[C@@H]1[NH2+][C@@H](C(=O)[O-])CS1,4.769985386,-0.4551,-5.225085386
O=C(NCCc1ccc(S(=O)(=O)NC(=O)c2ccc(Cl)cc2)cc1)c1ccc(Cl)cc1,1.894832086,4.0846,2.189767914
Cc1nc(-c2c[nH]c(C(=O)NC(C)(C)C)c2)cs1,2.51899596,2.97492,0.45592404
CCSc1nnc(NC(=O)c2sc3cccc(F)c3c2C)s1,2.398658841,4.56462,2.165961159
Cn1c(=O)c(C2=NN[C@H](c3cccs3)C2)c(O)c2ccccc21,3.233543543,2.7443,-0.489243543
Cc1noc(C)c1CCC(=O)N1CCC2(CC1)Nc1ccccc1S(=O)(=O)N2,3.384083936,1.94674,-1.437343936
COCc1ccccc1NC(=O)NCc1ccc(C)cc1N(C)C,2.104700542,3.52912,1.424419458
Cc1cc2c3c([nH]c2cc1S(=O)(=O)N[C@H](C)Cc1ccco1)CC(C)(C)CC3=O,3.268707608,4.13392,0.865212392
CCOc1ccc([C@H](C)NC(=O)C(=O)Nc2ccc(Cl)cc2C)cc1,2.322029223,3.86302,1.540990777
CO[C@H](CNC(=O)C(=O)Nc1cccnc1-n1cccn1)c1ccccc1,2.837716015,1.7097,-1.128016015
CCc1ccc([C@H]([NH3+])[C@@H](C)Oc2ccccc2C)cc1,3.183736434,3.30792,0.124183566
Cc1cccc([C@H]2CCCCC[NH+]2C2CCN(C(N)=O)CC2)c1,4.215146126,2.03812,-2.177026126
CCC[C@@H](NC(=O)c1ccc2[nH]c3c(c2c1)CCCC3)C(=O)OC,2.682829088,3.1182,0.435370912
C[C@H]1CN(C(C)(C)CNC(=O)[C@@H]2C[C@@H]2c2ccc(F)cc2)C[C@@H](C)O1,3.550507363,2.9332,-0.617307363
CC(C)CN(C[C@H](O)c1ccc(F)cc1)C(=O)Nc1cccnc1Cl,2.717118187,4.0976,1.380481813
CCCCOc1ccc(Cl)cc1,1.295802195,3.5189,2.223097805
CS(=O)(=O)NC1CCC([NH2+][C@@H]2C[C@H]3C[C@@H]2[C@H]2CCC[C@@H]23)CC1,5.544107265,1.2349,-4.309207265
COC[C@@H](C)CNC(=O)[C@@H]1C=C[C@@H]([NH3+])C1,4.562448454,-0.4283,-4.990748454
Cc1ccc(C[C@@H](C)[NH2+][C@H](C)c2ccc(C)c(F)c2)s1,3.86904984,3.75964,-0.10940984
COCCN1C(=O)[C@@H]2[C@@H](Cc3ccc(O)c(O)c3)N[C@@]3(C(=O)Nc4ccc(F)cc43)[C@@H]2C1=O,4.296201915,0.8464,-3.449801915
Cc1ccc(C(=O)C2=C([O-])C(=O)N(CCCN3CCOCC3)[C@@H]2c2cccs2)cc1,3.18358086,2.15942,-1.02416086
Cc1cc([C@H](O)[C@@H]2CCO[C@]3(CCOC3)C2)ccc1F,4.177699502,2.75322,-1.424479502
C[C@@H]1CCC[C@H](N(C)C(=O)c2cccc(NC(=O)C(C)(C)C)c2)C1,2.889836908,4.3219,1.432063092
COC(=O)[C@@H]1c2ccsc2CCN1S(=O)(=O)c1ccc(F)cc1,2.801380825,2.3483,-0.453080825
CC[NH+](CC)C[C@@H](C)NC(=O)CC1C[C@@H]2CC[C@H](C1)[NH2+]2,6.514367429,-0.6897,-7.204067429
COC(=O)c1c(N)sc2c1CCS2,3.06293975,1.7651,-1.29783975
CC[C@H](C(=O)N1CCCSC[C@H]1C[NH+](C)C)c1ccccc1,3.940231943,1.6588,-2.281431943
Cc1occc1C(=O)N/N=C/c1cc(Br)cc([N+](=O)[O-])c1[O-],3.024713357,2.09622,-0.928493357
CCOc1ccccc1O[C@H]1CCC[C@H]1NC(=O)c1ccc(OC)nc1,2.886927227,3.2188,0.331872773
Cc1cc(C)c(NC(=O)NCc2ccc(N3CCOCC3)cc2)c(C)c1,1.995046745,3.77016,1.775113255
COc1ccc(C[C@@H](C#N)C(=O)[O-])cc1OC,2.962785706,0.13598,-2.826805706
CC(C)NC(=O)Nc1nc2c(s1)CN(C(=O)c1ccc(C(F)(F)F)cc1)CC2,2.379930821,3.8903,1.510369179
Cc1cccnc1CSc1nc2cc(S(=O)(=O)N(C)C)ccc2o1,2.451381369,3.07382,0.622438631
Cc1ccccc1[C@@H](C)NC(=O)C(=O)Nc1ccc(N(C)C)c(C)c1,2.528571728,3.18534,0.656768272
CC(C)(C)OC(=O)N1CC(=CC(=O)NCC#N)C1,2.850974182,0.80328,-2.047694182
O=C1c2cccnc2C(=O)N1Cc1ccc(F)cc1S(=O)(=O)N1CCCCCC1,2.385687811,2.5816,0.195912189
CCCN(CC(=O)NC)C(=O)[C@@H]1CC[C@H]2CCCC[C@@H]2[NH2+]1,4.358803608,0.2556,-4.103203608
CS(=O)(=O)/N=c1/[nH]c(CC(=O)Nc2c(Cl)cccc2Cl)cs1,2.978326434,2.4245,-0.553826434
C=CCn1c([S-])nc2cc(C(=O)N(C)CC(=O)Nc3ccccc3C(F)(F)F)ccc2c1=O,2.867110276,3.2178,0.350689724
C[C@@H](c1ccc(S(C)(=O)=O)cc1)N(C)C(=O)c1cccc(Br)c1,2.415458229,3.6858,1.270341771
Cc1c(C(=O)NNC(=O)NC[C@H]2CCCO2)sc2ccc(F)cc12,2.847237911,2.47182,-0.375417911
CCOc1ccc(O[C@@H](C)C(=O)NNC(=O)c2[nH]c3ccccc3c2Br)cc1,2.67368861,3.5576,0.88391139
Cc1ccc([C@@H](C)[NH2+][C@H](C)c2nc3c(s2)CCCC3)s1,4.200584091,3.77742,-0.423164091
O=C(CN1C(=O)NC2(CCCCCC2)C1=O)NCc1cccs1,2.725502984,2.0091,-0.716402984
CC[C@H](NC(=O)CCc1cc(F)ccc1F)C(=O)OC,2.471739379,1.9652,-0.506539379
O=C(/C=C/c1c(F)cccc1F)N1CCN(c2ncccc2F)CC1,2.325210066,2.8609,0.535689934
CC(C)C[C@H]1NC(=O)[C@H]2CN(C(=O)c3ccc4ccccc4c3)CCN2C1=O,3.012230743,2.0373,-0.974930743
CCOc1cc(/C=C2/C(=N)N3N=CSC3=NC2=O)ccc1OC(=O)c1ccccc1Cl,2.837941116,4.20687,1.368928884
Cc1oc(-c2ccco2)nc1CC(=O)NC[C@H](C)Oc1ccccc1,2.739838922,3.36922,0.629381078
CCC[C@H]1C[NH2+]CC[C@@H]1c1cnn(-c2ccccc2)c1,4.019041705,2.3393,-1.679741705
O=C(c1c[nH]c(=O)cn1)N(CC1CC[NH+](C2CCCC2)CC1)C[C@H]1CCCO1,4.283016619,0.6286,-3.654416619
C[C@@H]1C[C@@H]1c1ccc(/C=C/C(=O)N2CCC[C@H](O)C2)o1,3.44240408,2.3995,-1.04290408
C[C@H](NC(=O)NC[C@@](C)(O)c1cccs1)c1cccc(C#N)c1,3.261078337,2.88768,-0.373398337
O=C(Nc1ccnn1C1CC[NH+](Cc2ccccc2F)CC1)c1cccnc1,3.31736458,2.0895,-1.22786458
CC(C)[C@@H](CNC(=O)NC[C@@]1(C)CCCO1)C(=O)[O-],3.979685613,-0.1232,-4.102885613
Cc1sc(NC(=O)C2CC2)c(C(=O)NCc2n[nH]c(=S)n2C(C)C)c1C,2.734164435,3.47843,0.744265565
Cc1nc(-n2cccc2)sc1C(=O)N1CCN(c2ccccc2Cl)CC1,2.289127951,3.85802,1.568892049
Cc1cc(NC(=O)N[C@H]2CCCN(c3ccccc3F)C2=O)no1,2.730594246,2.43922,-0.291374246
COc1ccc(Cl)c(NC(=O)[C@@H]2CC(O)=Nc3nncn32)c1,3.229683996,2.1116,-1.118083996
O=C(Nc1ccc(Cl)cc1)C1CCN(c2ccc(-n3cccc3)cn2)CC1,2.15321418,4.3808,2.22758582
CCc1c(N)cnn1[C@@]1(C)CCS(=O)(=O)C1,3.768605035,0.5614,-3.207205035
COc1ccc(OC)c(N2C(=O)/C(=C/c3ccc(O)cc3)SC2=S)c1,2.199361566,3.8152,1.615838434
NC(=O)c1cccc(CNC(=O)COc2ccc(F)cc2Cl)c1,1.814883042,2.2732,0.458316958
COC[C@H](Nc1nccc(OC)n1)c1ccc(F)c(F)c1,2.847808751,2.563,-0.284808751
COC(=O)[C@@]1(NC(=O)CCCc2cc(Cl)sc2Cl)CCOC1,3.349771447,2.8259,-0.523871447
C[C@H]1C[C@H](C)CN(S(=O)(=O)CCn2cnc3ccccc32)C1,3.092222018,2.344,-0.748222018
CCCC[C@@H]1CCC[C@@H]1NC(=O)NC1CC[NH+](CC(=O)N(C)C)CC1,4.320234072,0.78,-3.540234072
Cc1cc(C)cc(N(C)CC(=O)N(C)[C@@H]2CCS(=O)(=O)C2)c1,2.950781578,1.38514,-1.565641578
Cc1ccc(/C=C/C(=O)Nc2sc(C)c(C)c2C#N)o1,2.5068509,3.78994,1.2830891
Cn1nnc2c(=O)n(CC(=O)NC3CC3)cnc21,2.465292041,-1.1964,-3.661692041
COc1cc([C@@H](C)N[C@H]2CC[C@H]([NH+](C)C)C2)ccc1OC(F)F,4.124893168,2.0128,-2.112093168
CC(C)(C)OC(=O)NC1CCN(C(=O)C[NH3+])CC1,2.72243658,-0.256,-2.97843658
CC[C@](C)([C@@H]([NH2+]C)c1cc(Cl)cnc1N)N1CCOCC1,4.283133262,1.0524,-3.230733262
Cc1cnc([C@@H](NC(=O)N(C)[C@H](c2ccccc2)C(C)C)C2CC2)s1,3.46224187,4.94132,1.47907813
COc1ccc([C@H]2CC(=O)C3=C(C)Nc4ccccc4N[C@H]3C2)c(OC)c1OC,3.44999679,4.3391,0.88910321
CC(=O)N(C)C1CCC(C)(C)CC1,2.486653325,2.4335,-0.053153325
CNC(=O)c1cccc(NC(=O)Cc2ccccc2Cl)c1,1.657934364,2.8808,1.222865636
COc1cc(Br)c(O[C@H]2CCCCO2)cc1C,2.706881191,3.67152,0.964638809
O=C(C[C@@H]([NH2+]C[C@H]1CCCO1)C(=O)[O-])Nc1cccc(Cl)c1,4.003377104,-0.4705,-4.473877104
COc1ccc2cc(-c3csc(NC(=O)COc4ccccc4)n3)oc2c1,2.100462086,4.5824,2.481937914
C[C@H](Sc1nnc(-c2ccccc2)n1C)C(=O)N1CCC[C@@H](C(N)=O)C1,2.909074093,1.6866,-1.222474093
CCC[C@@H]1C[C@@H]1NC(=O)C(=O)Nc1cccc(-c2nnc(C)o2)c1,3.176625808,2.28832,-0.888305808
Cc1cc(C(=O)COC(=O)c2ccc(I)cc2)c(C)n1C[C@H]1CCCO1,2.718210326,3.92824,1.210029674
Cc1c(Cl)cccc1NC(=O)C(=O)NCc1ncn(C)n1,2.463583143,1.03182,-1.431763143
Cc1ccc([C@@]2(C(C)(C)C)CCC[NH2+]2)cc1,3.995007141,2.59362,-1.401387141
CN(C)C(=O)c1ccc(NC[C@@H]2CC[C@@H](C(=O)[O-])O2)nn1,3.770737241,-1.1122,-4.882937241
c1cncc([C@H](c2nc(-c3ccc4nc[nH]c4c3)no2)N2CCOCC2)c1,3.116341853,2.4295,-0.686841853
C[C@@H]1[C@H](C(=O)[O-])CCN1C(=O)N1CCc2ccccc21,3.598947972,0.6294,-2.969547972
Cn1cc(CCCOC(=O)c2cc3ccccc3c(=O)[nH]2)cn1,2.31302229,2.0512,-0.26182229
CCOc1ccc(NC(=O)Cn2cnc3sccc3c2=O)cc1,2.011111102,2.4954,0.484288898
O=C(Nc1ccccc1O)C1CC[NH+](CC2CCCCC2)CC1,3.239610037,2.2059,-1.033710037
Cc1ccc(CCNC(=O)c2cc(NC(=O)C3CCCC3)n(C)n2)o1,2.383830001,2.42272,0.038889999
O=C(Nc1cccnc1N1CCCC1)c1ccccc1,1.689281897,2.9341,1.244818103
COc1ccc(N2C(=O)NC(=O)/C(=C/Nc3ccc(F)cc3)C2=O)cc1,2.107915685,2.4131,0.305184315
C[C@@H]([NH2+]Cc1ccc(Cl)c(Cl)c1)c1ccc2c(c1)NC(=O)CO2,3.579272193,3.1489,-0.430372193
Cc1cccc(C)c1NC(=O)CN(C)C(=O)c1cc(C)n(-c2ccccc2)c1C,2.10301896,4.42168,2.31866104
CCc1ccsc1CNC(=O)Nc1ccc(NC(C)=O)c(OC)c1,2.285167592,3.5992,1.314032408
CC/C(C)=N\NC(=O)Cc1ccc(OC)cc1,1.94757722,2.1398,0.19222278
Cc1n[nH]c2cc(NC(=O)c3c(O)cc(F)cc3F)ccc12,2.437484013,3.10742,0.669935987
Cc1ccc(S(=O)(=O)NC[C@H](c2ccc(N(C)C)cc2)[NH+](C)C)c(C)c1,3.427929923,1.53354,-1.894389923
CCC(O)(CC)C(=O)NN(C(=O)c1csc(C)c1)c1ccccc1,2.822409727,3.28562,0.463210273
CC(C)CCNc1nc2c(c(=O)n(Cc3ccccc3Cl)c(=O)n2C)n1C,2.454505588,2.5934,0.138894412
COc1cc(C(=O)NC[C@H](C2CC2)[NH+](C)C)cc(OC)c1C,3.532251982,0.66512,-2.867131982
Cc1cccc2cc([C@H]3NC(=O)c4c(sc5c4CCCC5)N3)c(Cl)nc12,3.223424868,4.99102,1.767595132
Cc1[nH]cnc1C[NH+]1CCCCCCC1,4.107993259,1.06712,-3.040873259
CCC1CCC(C[NH3+])([NH+](CC)[C@@H](C)CC)CC1,5.245877951,1.2706,-3.975277951
N#Cc1ccc([C@@H]([NH3+])CO)cc1F,3.721661942,-0.02732,-3.748981942
CC1(C)SC[C@H](C(=O)N2CCC(OCc3cccnc3)CC2)NC1=O,3.235181826,1.5994,-1.635781826
O=C(C[NH+]1CCC[C@H]1c1nc2ccccc2s1)Nc1cc2[nH]c(=O)[nH]c2cc1Br,4.250457538,2.5869,-1.663557538
CS(=O)(=O)NC[C@H]1CCCCN1C(=O)Nc1cc(C(F)(F)F)ccc1N1CCCC1,2.897130764,3.2412,0.344069236
C[C@H](CC(=O)[C@@H](C#N)c1nc2cc(Cl)ccc2s1)NC(=O)C1CCCC1,3.325606232,4.21098,0.885373768
C=CC[NH+](Cc1cc(F)ccc1OC)[C@H]1CCS(=O)(=O)C1,4.383766153,0.5923,-3.791466153
CCc1cccc(OC2CN(C(=O)[C@H]3CCCCC[NH+]3C)C2)c1,3.985533936,1.2959,-2.689633936
C[NH+]1CCN(C(=O)[C@@H](O)c2ccccc2)CC1,3.840745614,-0.9231,-4.763845614
COc1ccc(C(=O)N2CCC[C@H](c3nc(O)c4nnn(Cc5ccc(F)cc5)c4n3)C2)cc1OC,3.02645316,3.1513,0.12484684
NC(=O)CCOc1ccccc1NC(=O)c1ccc(Cl)c(Cl)c1,1.802330638,3.4999,1.697569362
C[C@@H](O)c1ccc(-c2ccccc2)cc1,1.840468032,3.4069,1.566431968
C[NH2+]C1CCC(N(C)C(=O)c2sccc2Br)CC1,3.545652041,2.087,-1.458652041
C[C@H]1CCC[NH+]1CC(=O)Nc1ccc(F)c(N2CCCS2(=O)=O)c1,4.340830087,0.3713,-3.969530087
CC[C@@H](C)NC(=O)[C@@H](C)NC(=O)C#Cc1ccccc1,2.993557533,1.4575,-1.536057533
c1ccc2sc(C3CC[NH+](C4CCSCC4)CC3)nc2c1,3.761832631,2.9542,-0.807632631
C=C1C[C@@]23CC[C@@H]4[C@@](C)(CCC[C@@]4(C)C(=O)[O-])[C@H]2CC[C@]1(O)C3,5.990082746,2.8203,-3.169782746
Nc1c(NCCN2CCOCC2)ncnc1Nc1ccc(Br)cn1,2.469548408,1.704,-0.765548408
Cn1c(COC(=O)CCc2ncc(-c3ccccc3F)o2)cc(=O)n(C)c1=O,2.549319379,1.5541,-0.995219379
COc1ccc(Br)cc1/C=C/C(=O)NC[C@@H]1CCCO1,2.604302267,2.7661,0.161797733
CCC[C@H](C)NC(=O)[C@@H](C)Nc1cccc(C)c1,2.577453333,3.10022,0.522766667
C[C@H]1C[C@H](C)CN(C(=O)COC(=O)c2cncc(Br)c2)C1,2.968923183,2.5054,-0.463523183
Nc1cc(-c2cccc(Cl)c2)nn1-c1ccccc1F,2.017455876,3.914,1.896544124
COc1ccc(NC(=O)C(=O)NC2CC2)cc1C,1.795305536,1.22072,-0.574585536
CC(C)N(Cc1cccnc1)C(=O)Nc1cnn(C[C@H]2CCCO2)c1,2.91680623,2.8996,-0.01720623
C[C@@H](CNS(=O)(=O)c1ccc2c(c1)NC(=O)CO2)[NH+](C)C1CC1,4.101762609,-0.6386,-4.740362609
COc1ccc(NC(=O)N2CCC[C@@H](c3nnc(C(=O)Nc4cccc(F)c4)s3)C2)cc1,2.731919328,4.3496,1.617680672
COC1CC[NH+](C[C@@H]2CCC[C@H](C)C2)CC1,4.704832754,1.5064,-3.198432754
CSc1ccc([C@H](Br)Cc2ccccc2)cc1,2.313993515,5.0872,2.773206485
C[C@@]1(c2ccccc2F)NC(=O)N(CC(=O)NCC(F)(F)F)C1=O,2.754553183,1.2712,-1.483353183
O=c1[nH]c(Cn2c(-c3ccccc3)noc2=O)nc2sc3c(c12)CCCC3,2.594235212,2.7284,0.134164788
Cc1cncc(C(=O)NCCC[NH+]2C[C@H](C)C[C@@H](C)C2)c1,4.271486349,1.07072,-3.200766349
CCc1ccc(NC(=O)[C@@H](C)S(=O)(=O)Cc2nncn2C)cc1,2.848858229,1.3195,-1.529358229
COc1ccccc1-c1ccccc1C[C@H]1CN(C(=O)c2ccccc2)CCN(C)C1=O,2.69381381,4.1353,1.44148619
Cc1nsc(N)c1-c1nc2cnccc2n1C1CC1,2.945782277,2.78032,-0.165462277
Cc1ccsc1CN(CC[NH+](C)C)C(=O)c1ccc2c(c1)nnn2C,3.351786841,1.12512,-2.226666841
Cc1cccnc1CCNC(=O)c1ccc([C@H]2CCC[NH+]2C(C)C)s1,4.198988791,2.55222,-1.646768791
CCc1ccsc1CNC(=O)C1(c2ccccc2F)CC1,2.52839512,3.7976,1.26920488
COc1cccc([C@@H](C)NC(=O)COC(=O)c2cncc(Br)c2)c1,2.462586621,2.8869,0.424313379
CS(=O)(=O)N(CC(=O)N1CCOCC1)c1ccc(F)c(Cl)c1,2.227684896,1.1039,-1.123784896
COc1ccc(N2C(=O)CC[C@H](C(=O)Nc3ccc(-n4ccnc4)cc3)[C@H]2c2ccccc2OC)cc1,3.144845929,5.0125,1.867654071
CC[C@@H]1C(=O)N2CCC[C@H]2C(=O)N1CC1CCC1,3.231498394,1.3983,-1.833198394
Cn1ccc(C(=O)N2C[C@@H]3CC[C@H](C2)[NH+](Cc2ccccc2)C3)n1,5.015766837,0.7396,-4.276166837
NC1=NC(=O)N(Cc2ccccc2)[C@@H]1c1ccccc1Br,2.854435153,3.4832,0.628764847
CC(C)[NH+]1CCCN(C(=O)c2cc(C3CC3)n(C(C)(C)C)n2)CC1,4.137344869,1.6547,-2.482644869
Cc1cc(CNC(=O)[C@@]2(C)CCC[NH2+]2)cc(C)c1F,3.955747577,1.17464,-2.781107577
C[C@@H]1CC(=O)N(c2ccc(NCc3cccnc3)c([N+](=O)[O-])c2)C1=O,2.76947101,2.5013,-0.26817101
CC1CCC2(CC1)NC(=O)N(NC(=O)CNC(=O)c1cccs1)C2=O,3.018403098,1.0098,-2.008603098
CCOc1ccc(C(C)=O)cc1Cn1c(C(F)(F)F)nc2ccccc21,2.185433336,4.7047,2.519266664
Cc1ccccc1N1C(=O)c2ccc(C(=O)OCC(=O)NC[C@H]3CCCO3)cc2C1=O,2.664389642,2.24762,-0.416769642
COC1CCC(CC(=O)Nc2[nH+]cccc2[O-])CC1,3.890243815,1.1081,-2.782143815
Nc1ccc(F)cc1NC(=O)CN1CCc2ccccc21,1.963255159,2.4091,0.445844841
COc1c(Br)cc(Cl)cc1S(=O)(=O)[N-]c1cc(F)ccc1F,3.247974726,4.7834,1.535425274
CCOc1ccc(NC(=O)N(C)Cc2ncnn2C)cc1C,2.274019952,2.18612,-0.087899952
C[C@H]1OCC[C@@H]1C(=O)Nc1ccc(Cl)c(C(=O)[O-])c1,3.500226302,1.067,-2.433226302
O=C(NCc1cccnc1)C(=O)N1CCCC1,1.957724353,0.3202,-1.637524353
C#CCOc1ccc2ccccc2c1CN1CCN(CC#N)CC1,2.405771046,2.49298,0.087208954
CCOc1ncccc1-c1ncnc2scc(C)c12,2.507808944,3.46042,0.952611056
CC(=O)NC1(c2ccc(F)cc2)CCN(C(=O)C2CCCCC2)CC1,2.31470064,3.3598,1.04509936
Cc1ccc(NC(=O)C(=O)NC[C@@H](c2ccc3c(c2)CCN3C)[NH+]2CCCCC2)c([N+](=O)[O-])c1,3.902352433,1.76032,-2.142032433
CCc1ccc(Oc2ccc3nnnn3n2)c(N)c1,2.589085867,1.4562,-1.132885867
Cc1cc(OCC(=O)OC(C)C)cc2c1C(=O)/C(=C/c1cc(Br)cc3c1OCOC3)O2,2.988519088,4.57062,1.582100912
COC(=O)[C@H]1[C@@H](NC(=S)NC[C@@H]2C=CN=N2)S[C@@H](C)[C@H]1C,5.238676386,1.6853,-3.553376386
C[C@H]1C[C@@H](C)CN(C(=O)CSc2ccccc2S(C)(=O)=O)C1,3.009957402,2.6867,-0.323257402
CC1(C)COc2cscc2OC1,3.146666448,2.5455,-0.601166448
Cc1cc(C)c(NC(=O)CNC(=O)CCc2nc3ccccc3c(=O)[nH]2)c(C)c1,2.263700819,2.53586,0.272159181
COC[C@@H](C)NC(=O)Nc1ccc(F)cc1F,2.338775673,2.1212,-0.217575673
CCOc1ccc(N[C@H](C)c2ccc(F)cc2F)cc1CO,2.516868547,4.0289,1.512031453
CCCn1nnnc1CN1C(=O)c2ccc([N+](=O)[O-])cc2C1=O,2.398476792,0.7875,-1.610976792
CC(C)CC(C)(C)C(=O)Nc1ccc(F)c(C(=O)N(C)C)c1,2.287139552,3.5383,1.251160448
Cn1cc(/C=C/C(=O)c2cccc(F)c2)c(=O)n(C)c1=O,2.300985321,1.1192,-1.181785321
O=C(NCc1ccnc(OC2CCC2)c1)c1[nH]ncc1Br,2.55409362,2.4285,-0.12559362
Cc1ccc([C@@H](CN2CCOCC2)NC(=O)COc2ccc(F)cc2)cc1,2.421749151,2.70262,0.280870849
================================================
FILE: graphium/data/multilevel_utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import pandas as pd
import ast
import numpy as np
from typing import List
import itertools
import math
def extract_labels(df: pd.DataFrame, task_level: str, label_cols: List[str]):
    """Extracts labels in ``label_cols`` from dataframe ``df`` for a given task_level.

    Each cell may hold a scalar, a list/array of values, or a string encoding a
    python literal (e.g. ``"[1.0, 2.0]"``). Cells are normalized to numpy arrays,
    then the columns of each row are merged side-by-side, padding with NaN where
    columns differ in length.

    Args:
        df: Dataframe to extract the labels from.
        task_level: Level of the task (e.g. ``"graph"``, ``"node"``). For
            ``"graph"`` the per-row arrays are concatenated into one 2D array.
        label_cols: Names of the columns of ``df`` holding the labels.

    Returns:
        For ``task_level == "graph"``, a single ``np.ndarray``; otherwise a list
        of per-row arrays of shape ``(label_length, num_label_cols)``.

    Raises:
        ValueError: If a cell is not a str, int, float, list, or ``np.ndarray``.
    """

    def unpack(graph_data):
        # Normalize one cell to a numpy array.
        # Strings must be parsed BEFORE numeric coercion: `pd.to_numeric` with
        # errors="coerce" would silently turn "[1.0, 2.0]" into a scalar NaN,
        # making the literal_eval path unreachable (previous bug).
        if isinstance(graph_data, str):
            try:
                graph_data = ast.literal_eval(graph_data)
            except (ValueError, SyntaxError):
                # Not a python literal; coerce (yields NaN for junk strings).
                graph_data = pd.to_numeric(graph_data, errors="coerce")
        else:
            graph_data = pd.to_numeric(graph_data, errors="coerce")
        # `np.number` covers numpy scalars (np.int64 is NOT a subclass of int).
        if isinstance(graph_data, (int, float, np.number)):
            return np.array([graph_data])
        elif isinstance(graph_data, (list, tuple)):
            return np.array(graph_data)
        elif isinstance(graph_data, np.ndarray):
            if len(graph_data.shape) == 0:
                # 0-d array -> 1-element array so zip_longest can iterate it.
                graph_data = np.expand_dims(graph_data, 0)
            if graph_data.shape[0] == 0:
                # Empty labels become a single NaN placeholder.
                # TODO: Warning
                graph_data = np.array([np.nan])
            return graph_data
        else:
            raise ValueError(
                f"Graph data should be one of str, float, int, list, np.ndarray, got {type(graph_data)}"
            )

    def unpack_column(data: pd.Series):
        # Apply the cell normalization element-wise to one column.
        return data.apply(unpack)

    def merge_columns(data: pd.Series):
        # Combine the per-column arrays of one row into (length, num_cols),
        # padding shorter columns with NaN.
        data = data.to_list()
        # Scalar NaN (missing cell) becomes a 1-element NaN array.
        data = [np.array([np.nan]) if not isinstance(d, np.ndarray) and math.isnan(d) else d for d in data]
        padded_data = itertools.zip_longest(*data, fillvalue=np.nan)
        data = np.stack(list(padded_data), 1).T
        return data

    unpacked_df: pd.DataFrame = df[label_cols].apply(unpack_column)
    output = unpacked_df.apply(merge_columns, axis="columns").to_list()
    if task_level == "graph":
        return np.concatenate(output)
    return output
================================================
FILE: graphium/data/multitask/__init__.py
================================================
================================================
FILE: graphium/data/multitask/tiny_ZINC_SA.csv
================================================
SMILES,SA
Cc1cccc(COC(=O)Nc2ccc(-c3cn[nH]c3)cc2OCCN2CCCC2)c1.O=C(O)C(F)(F)F.O=C(O)C(F)(F)F,2.896514950230738
CCC(C)C1=C(C=CC=C1)OCCC[N]2C=CN=C2.O=C(O)C(=O)O,2.66931991795572
CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1,2.775189478
C=CCOc1ccc(C(F)(F)F)cc1C(=O)NC[C@H]1CCC[C@H]1O,3.071497898
COc1cc(OC)cc([C@@H](NC(=O)Cn2ccccc2=O)c2nccn2C)c1,2.84000896
CN1CCC[C@@]2(CCN(C(=O)CN3CCNC(=O)C3)C2)C1=O,3.605168457
COc1ccc(Cl)cc1NC(=O)c1nn(C)cc1[N+](=O)[O-],2.132664673
Cc1nn(-c2ccccc2)c(O)c1/C=[NH+]/Cc1ccncc1,3.117639223
C=CCN1C(=O)C(C)(C)COc2ccc(NC(=O)C3(c4ccc(OC)cc4)CCOCC3)cc21,2.744789746
CON(C)C(=O)c1cnn2c(C)cc(C)nc12,2.547847184
NC(=O)[C@@H](NC(=O)CCCc1cccs1)c1ccccc1,2.444743614
COc1cccc(CN2CCC[NH+](CC(=O)Nc3ccc(F)cc3)S2(=O)=O)c1,3.546311784
Cc1ccc(-c2nc3ccc(C)c(C)c3[nH]2)nc1,2.342516783
Cc1n[nH]c(SCC(=O)Nc2cccc(C#Cc3ccccn3)c2)n1,2.685016252
COC(=O)c1cncc(C(=O)Nc2cccc(COC(C)C)c2)c1,2.04419285
Cc1ccc(C)c(Oc2ccc(CNC(=O)N3CCCSCC3)cn2)c1,2.369895336
CC[NH2+][C@@]1(C(=O)[O-])CCC[C@H](Oc2ccc(CC)cc2)C1,4.165551538
CC(=O)N1c2ccc(S(=O)(=O)N3CCCC3)cc2C[C@H]1C(=O)NCC[NH+](C)C1CCCCC1,3.989079408
C=CCN1CC(=O)N2[C@@H](Cc3c([nH]c4ccccc34)[C@@H]2c2ccccc2OC)C1=O,3.287567863
COc1ccc(-c2noc3ncnc(N4CCC[C@H](C(=O)[O-])C4)c23)cc1,3.152874099
N#Cc1cc(NC(=O)Cc2ccccc2Cl)ccc1N1CCC(O)CC1,2.173785623
Cc1ncc(-c2ccc(S(=O)(=O)NC3CCCCCC3)s2)o1,2.513723357
Cc1ccc(NC(=O)c2ccc(F)cc2F)cc1S(=O)(=O)Nc1ccc(Cl)cc1,1.947448768
CNC(=O)[C@@H]1CCC[NH+]1Cc1ccc(C)c(F)c1,4.391338086
CN1C(=O)N[C@@H](c2cccc([N+](=O)[O-])c2)C2=C1CN(c1ccc(F)cc1)C2=O,2.956072154
Cc1cc(CC(=O)N[C@H](c2ccc(F)cc2)C2CCC2)no1,2.665131639
Cc1cc(N2CCN(CC(=O)NC3CC3)CC2)nc(-c2ccc([N+](=O)[O-])cc2)n1,2.249569987
Cc1cc(C(=O)N2CCN(C(=O)N[C@H]3CC(=O)N(C4CC4)C3)CC2)c(C)o1,2.921725033
O=c1c2c3nc4ccccc4nc3n(CCC3=CCCCC3)c2ncn1C[C@H]1CCCO1,3.257260117
CC[NH+](C/C=C/c1ccc(C#N)cc1)C[C@H](C)C#N,4.36293875
COc1c(Cl)cc(C[NH2+][C@@H](Cc2c[nH]c3ccccc23)C(=O)[O-])cc1Cl,3.725368177
COc1ccccc1[C@H](C)NC(=O)C[C@H]1C[NH2+]CCO1,3.812797047
Cc1ccc(NC(=O)C(=O)NCCc2ccccc2F)c(C)c1,1.826753956
CC(C)(C)OC(=O)N[C@H]1CCN(c2cc(-c3cccs3)n[nH]2)C1,3.228025092
CC[C@H](C)C[NH+]1CCCC[C@@H]1C(=O)NC(C)(C)C,4.787067722
COC(=O)c1c(S(=O)(=O)NC2CC2)sc2c1CCN(Cc1ccc3ccccc3c1)C2,2.485526606
Cc1cc(C)cc(NC(=O)[C@@H](Sc2nnnn2C2CC2)c2ccccc2)c1,2.708478583
C=CCn1c([S-])nnc1-c1sc(NC(=O)c2cccc(NC(C)=O)c2)nc1C,2.943459024
O=C(CNc1nc(C2CC2)no1)N1CCc2sccc2C1,2.781430253
Cc1ccccc1NC(=O)NCCn1c([N+](=O)[O-])cnc1C,2.258312512
C[C@H](NC=C(C#N)C#N)c1ccc(-n2cncn2)cc1,3.277555708
CN1C[C@H](C(=O)NC[C@H]2CCC(C)(C)c3ccccc32)CC1=O,3.259412472
CC(C)n1cccc1C(=O)N[C@H]1CCc2cc(N)ccc21,2.887282024
CCN1C(=O)C(C#N)=C(C)/C(=C\Nc2ccc([N+](=O)[O-])cc2)C1=O,2.501067124
O=C(Nc1ccnn1Cc1cccc(Cl)c1)C1CCN(S(=O)(=O)c2ccccc2F)CC1,2.296421252
COc1cc([C@@H]2N[C@H](C(=O)[O-])CS2)ccc1OC(=O)c1ccccc1,3.273511868
CCOc1ccc(C2=CCN(C(=O)c3ccccc3F)CC2)cc1,2.000346516
CC(=O)N1CC[C@@H]([NH2+][C@@H](C)CSCC(C)C)C1,4.483674272
CCO[P@](=O)(CC(=O)[O-])c1ccccc1,3.510193788
O=C(CCc1nc2ccccc2c(=O)[nH]1)N1CCC[C@H]1C1CCCC1,2.774610155
O=C(/C=C/SCc1ccco1)Nc1cccc(N2CCCC2=O)c1,2.584851266
CC(=O)c1c(C)[nH]c(C(=O)OCC(=O)N2C[C@H](C)C[C@@H](C)C2)c1C,3.068452859
Cc1ccc(C(=O)N2CCC[C@@H](C(=O)N(C)CC(=O)NC(C)C)C2)cc1,2.492863438
C[C@@H](C#N)Sc1nnc(C23CC4CC(CC(C4)C2)C3)o1,4.576896432
CC1(C)CCC[C@H]1n1c(N)[nH+]c2ccccc21,4.151966768
CCOc1ccc(C[NH+]2CCS[C@H]3COCC[C@@H]32)cc1OC,4.514122047
COc1cc(OC)c([C@@H](C)[NH2+]C[C@H](O)C2CCOCC2)cc1Cl,3.960984106
COC(C[C@@](C)(O)C[C@@H]1CCCN(S(C)(=O)=O)C1)OC,3.669896168
C=CCSc1nnc(C2CCOCC2)n1N,2.825013358
O=C(Nc1c[nH]c2ccccc12)N1CCCC[C@@H]1CCO,2.674797385
Cc1nc(-c2ccc(-c3noc(C4CC4)n3)cc2)cs1,2.133219774
C[C@H]1CCCCN1C(=O)[C@@H](C)NC(=O)C(=O)Nc1ccnn1C(C)(C)C,3.418504414
c1cc2c(cc1N[C@@H]1CCOC3(CCC3)C1)CCC2,3.363061129
O=C(Nc1ccc(CNC(=O)N2CCCCCCC2)cc1)c1ccco1,1.931563902
CNC(=O)[C@H]1CCCCN1c1cccc(F)c1C(N)=S,2.82917551
O=C(NCc1nnc2n1CCC2)c1cc(-c2ccccc2)[nH]n1,2.376871447
Cc1nnc(-c2cc(CC(C)C)n(-c3cccc(C(=O)N[C@@H]4C[C@@H](C)CC(C)(C)C4)c3)n2)o1,3.548164288
COc1cccc([C@@H]2CC(=O)C3=C(C2)NC(=O)C[C@H]3c2ccc(OC)c(OC)c2OC)c1,3.226440403
C[NH2+]Cc1ccc(-c2cc(C)ccc2F)s1,3.128631552
CCC[C@H](NC(N)=O)C(=O)NC[C@@H]1CCCO1,2.869359841
COc1cc(F)cc(CNC(=O)[C@H]2CCCN2C(=O)Cc2ccccc2)c1,2.427176885
CCc1cc(C#N)c(NC(=O)CCC(=O)N(CC)CC)s1,2.493551313
C#Cc1cccc(NC(=O)C[NH+](C)[C@H](C)c2nnc(-c3ccccc3)o2)c1,3.779297449
CCOC[C@H](O)[C@](C)(CC)[NH+]1CCCC1,5.127905513
C[NH+](C)Cc1cccc(CNC(=O)NC[C@H](O)c2ccc(F)cc2)c1,3.24179101
Cc1ccsc1C(=O)NCCNC(=O)c1ccco1,2.114582277
COC(=O)[C@H](C)N(Cc1ccccc1)C(=O)Cc1ccon1,2.852054716
C/C=C(/C)C(=O)NC[C@H]1C[C@@]12CCc1ccccc12,3.940919906
O=C1Nc2ccccc2Oc2cc([N+](=O)[O-])cc(Oc3ccc(Cl)cc3)c21,2.400946341
CCc1onc(C)c1NC(=O)CCCC(C)(C)C,2.601600041
COC(=O)C(C)(C)COc1ccc2c(c1)OC[C@@H]2[NH3+],3.576525387
CC#CCCC(=O)Nc1cccc2c1C(=O)c1ccccc1C2=O,2.298959953
Cc1nc(C)c(S(=O)(=O)/N=C(\[O-])C[C@H]2CCCO2)s1,3.812988405
CCC(=O)N[C@H](CCSC)C(=O)Nc1ccc2[nH]c(C)cc2c1,2.732126257
COCC[NH+](C)Cc1c(C)cc(C)c(C(C)=O)c1C,3.955103091
Cc1cscc1CNC(=O)N[C@@H](C)c1nc(C(=O)[O-])cs1,3.628458628
CNC(=O)NCC(=O)NC[C@H](c1cccnc1)C(C)C,2.795116489
COc1cccc(C(=O)N2CCC[C@H]2[C@H]2CCC[C@@H]2O)c1,3.002385916
COc1ccc2c(c1)[C@H]([NH2+][C@H](C)CCN1CCOCC1)CCCO2,3.789130146
COc1cccc(OCC(=O)N2CCN(c3nc4ccc(S(C)(=O)=O)cc4s3)CC2)c1,2.245396536
COc1cccc(-c2nc(C(=O)O[C@@H](C)CNC(C)=O)c(C)[nH]2)c1,2.86962049
CCO[C@H]1C[C@@H]1C(=O)Nc1ccc(Sc2nncs2)c(Cl)c1,3.445502237
Cc1ccc(N(CC(=O)N2CCCC[C@H]2C)S(=O)(=O)c2c(C)nn(C)c2C)cc1,2.872766629
COc1ccc(C(=O)Nc2nc(CC(=O)Nc3ccc(N(C)C)cc3)cs2)cc1,1.983649537
O=c1c2ccccc2sn1-c1ncc(Br)s1,2.853247472
Cc1ocnc1CNC(=O)N(C)Cc1ccc(OC(F)F)cc1,2.60278761
COc1cc(Cl)c(C)cc1NC(=O)C(=O)NCc1ccccc1C[NH+](C)C,2.870575026
C[C@@H]([NH3+])c1cccc(Oc2cc(Br)ccc2Cl)c1,2.965149266
CC(C)NS(=O)(=O)c1cccc(C(=O)NCc2ccco2)c1,1.961616674
CN(Cc1ncnn1C)C(=O)CSCc1ccccn1,2.529083407
================================================
FILE: graphium/data/multitask/tiny_ZINC_logp.csv
================================================
SMILES,logp
Cc1cccc(COC(=O)Nc2ccc(-c3cn[nH]c3)cc2OCCN2CCCC2)c1.O=C(O)C(F)(F)F.O=C(O)C(F)(F)F,5.87502
CCC(C)C1=C(C=CC=C1)OCCC[N]2C=CN=C2.O=C(O)C(=O)O,3.0213
CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1,3.7897
C=CCOc1ccc(C(F)(F)F)cc1C(=O)NC[C@H]1CCC[C@H]1O,3.161
COc1cc(OC)cc([C@@H](NC(=O)Cn2ccccc2=O)c2nccn2C)c1,1.5048
CN1CCC[C@@]2(CCN(C(=O)CN3CCNC(=O)C3)C2)C1=O,-1.1109
COc1ccc(Cl)cc1NC(=O)c1nn(C)cc1[N+](=O)[O-],2.2426
Cc1nn(-c2ccccc2)c(O)c1/C=[NH+]/Cc1ccncc1,0.98102
C=CCN1C(=O)C(C)(C)COc2ccc(NC(=O)C3(c4ccc(OC)cc4)CCOCC3)cc21,4.3197
CON(C)C(=O)c1cnn2c(C)cc(C)nc12,0.97954
NC(=O)[C@@H](NC(=O)CCCc1cccs1)c1ccccc1,2.4136
COc1cccc(CN2CCC[NH+](CC(=O)Nc3ccc(F)cc3)S2(=O)=O)c1,0.8084
Cc1ccc(-c2nc3ccc(C)c(C)c3[nH]2)nc1,3.55016
Cc1n[nH]c(SCC(=O)Nc2cccc(C#Cc3ccccn3)c2)n1,2.63872
COC(=O)c1cncc(C(=O)Nc2cccc(COC(C)C)c2)c1,3.0455
Cc1ccc(C)c(Oc2ccc(CNC(=O)N3CCCSCC3)cn2)c1,4.13924
CC[NH2+][C@@]1(C(=O)[O-])CCC[C@H](Oc2ccc(CC)cc2)C1,0.6424
CC(=O)N1c2ccc(S(=O)(=O)N3CCCC3)cc2C[C@H]1C(=O)NCC[NH+](C)C1CCCCC1,0.7123
C=CCN1CC(=O)N2[C@@H](Cc3c([nH]c4ccccc34)[C@@H]2c2ccccc2OC)C1=O,3.0474
COc1ccc(-c2noc3ncnc(N4CCC[C@H](C(=O)[O-])C4)c23)cc1,1.2597
N#Cc1cc(NC(=O)Cc2ccccc2Cl)ccc1N1CCC(O)CC1,3.35398
Cc1ncc(-c2ccc(S(=O)(=O)NC3CCCCCC3)s2)o1,3.71262
Cc1ccc(NC(=O)c2ccc(F)cc2F)cc1S(=O)(=O)Nc1ccc(Cl)cc1,4.97972
CNC(=O)[C@@H]1CCC[NH+]1Cc1ccc(C)c(F)c1,0.42742
CN1C(=O)N[C@@H](c2cccc([N+](=O)[O-])c2)C2=C1CN(c1ccc(F)cc1)C2=O,2.7309
Cc1cc(CC(=O)N[C@H](c2ccc(F)cc2)C2CCC2)no1,3.32222
Cc1cc(N2CCN(CC(=O)NC3CC3)CC2)nc(-c2ccc([N+](=O)[O-])cc2)n1,1.76082
Cc1cc(C(=O)N2CCN(C(=O)N[C@H]3CC(=O)N(C4CC4)C3)CC2)c(C)o1,1.12714
O=c1c2c3nc4ccccc4nc3n(CCC3=CCCCC3)c2ncn1C[C@H]1CCCO1,4.3639
CC[NH+](C/C=C/c1ccc(C#N)cc1)C[C@H](C)C#N,1.63596
COc1c(Cl)cc(C[NH2+][C@@H](Cc2c[nH]c3ccccc23)C(=O)[O-])cc1Cl,1.9079
COc1ccccc1[C@H](C)NC(=O)C[C@H]1C[NH2+]CCO1,0.2247
Cc1ccc(NC(=O)C(=O)NCCc2ccccc2F)c(C)c1,2.73994
CC(C)(C)OC(=O)N[C@H]1CCN(c2cc(-c3cccs3)n[nH]2)C1,3.2416
CC[C@H](C)C[NH+]1CCCC[C@@H]1C(=O)NC(C)(C)C,1.3846
COC(=O)c1c(S(=O)(=O)NC2CC2)sc2c1CCN(Cc1ccc3ccccc3c1)C2,3.6869
Cc1cc(C)cc(NC(=O)[C@@H](Sc2nnnn2C2CC2)c2ccccc2)c1,4.09694
C=CCn1c([S-])nnc1-c1sc(NC(=O)c2cccc(NC(C)=O)c2)nc1C,3.01252
O=C(CNc1nc(C2CC2)no1)N1CCc2sccc2C1,2.0053
Cc1ccccc1NC(=O)NCCn1c([N+](=O)[O-])cnc1C,2.22984
C[C@H](NC=C(C#N)C#N)c1ccc(-n2cncn2)cc1,1.84896
CN1C[C@H](C(=O)NC[C@H]2CCC(C)(C)c3ccccc32)CC1=O,2.4361
CC(C)n1cccc1C(=O)N[C@H]1CCc2cc(N)ccc21,3.0685
CCN1C(=O)C(C#N)=C(C)/C(=C\Nc2ccc([N+](=O)[O-])cc2)C1=O,2.11938
O=C(Nc1ccnn1Cc1cccc(Cl)c1)C1CCN(S(=O)(=O)c2ccccc2F)CC1,3.7633
COc1cc([C@@H]2N[C@H](C(=O)[O-])CS2)ccc1OC(=O)c1ccccc1,1.3679
CCOc1ccc(C2=CCN(C(=O)c3ccccc3F)CC2)cc1,4.1539
CC(=O)N1CC[C@@H]([NH2+][C@@H](C)CSCC(C)C)C1,0.9483
CCO[P@](=O)(CC(=O)[O-])c1ccccc1,0.3764
O=C(CCc1nc2ccccc2c(=O)[nH]1)N1CCC[C@H]1C1CCCC1,3.0369
O=C(/C=C/SCc1ccco1)Nc1cccc(N2CCCC2=O)c1,3.792
CC(=O)c1c(C)[nH]c(C(=O)OCC(=O)N2C[C@H](C)C[C@@H](C)C2)c1C,2.49544
Cc1ccc(C(=O)N2CCC[C@@H](C(=O)N(C)CC(=O)NC(C)C)C2)cc1,1.83022
C[C@@H](C#N)Sc1nnc(C23CC4CC(CC(C4)C2)C3)o1,3.54158
CC1(C)CCC[C@H]1n1c(N)[nH+]c2ccccc21,2.7888
CCOc1ccc(C[NH+]2CCS[C@H]3COCC[C@@H]32)cc1OC,1.3831
COc1cc(OC)c([C@@H](C)[NH2+]C[C@H](O)C2CCOCC2)cc1Cl,1.7691
COC(C[C@@](C)(O)C[C@@H]1CCCN(S(C)(=O)=O)C1)OC,0.8081
C=CCSc1nnc(C2CCOCC2)n1N,1.164
O=C(Nc1c[nH]c2ccccc12)N1CCCC[C@@H]1CCO,2.9367
Cc1nc(-c2ccc(-c3noc(C4CC4)n3)cc2)cs1,4.04592
C[C@H]1CCCCN1C(=O)[C@@H](C)NC(=O)C(=O)Nc1ccnn1C(C)(C)C,1.4823
c1cc2c(cc1N[C@@H]1CCOC3(CCC3)C1)CCC2,3.6889
O=C(Nc1ccc(CNC(=O)N2CCCCCCC2)cc1)c1ccco1,4.0076
CNC(=O)[C@H]1CCCCN1c1cccc(F)c1C(N)=S,1.5648
O=C(NCc1nnc2n1CCC2)c1cc(-c2ccccc2)[nH]n1,1.5444
Cc1nnc(-c2cc(CC(C)C)n(-c3cccc(C(=O)N[C@@H]4C[C@@H](C)CC(C)(C)C4)c3)n2)o1,5.37382
COc1cccc([C@@H]2CC(=O)C3=C(C2)NC(=O)C[C@H]3c2ccc(OC)c(OC)c2OC)c1,3.7253
C[NH2+]Cc1ccc(-c2cc(C)ccc2F)s1,2.55582
CCC[C@H](NC(N)=O)C(=O)NC[C@@H]1CCCO1,0.1186
COc1cc(F)cc(CNC(=O)[C@H]2CCCN2C(=O)Cc2ccccc2)c1,2.6842
CCc1cc(C#N)c(NC(=O)CCC(=O)N(CC)CC)s1,2.76928
C#Cc1cccc(NC(=O)C[NH+](C)[C@H](C)c2nnc(-c3ccccc3)o2)c1,1.9323
CCOC[C@H](O)[C@](C)(CC)[NH+]1CCCC1,0.2312
C[NH+](C)Cc1cccc(CNC(=O)NC[C@H](O)c2ccc(F)cc2)c1,1.003
Cc1ccsc1C(=O)NCCNC(=O)c1ccco1,1.80932
COC(=O)[C@H](C)N(Cc1ccccc1)C(=O)Cc1ccon1,1.8074
C/C=C(/C)C(=O)NC[C@H]1C[C@@]12CCc1ccccc12,2.9729
O=C1Nc2ccccc2Oc2cc([N+](=O)[O-])cc(Oc3ccc(Cl)cc3)c21,5.3985
CCc1onc(C)c1NC(=O)CCCC(C)(C)C,3.70032
COC(=O)C(C)(C)COc1ccc2c(c1)OC[C@@H]2[NH3+],0.94
CC#CCCC(=O)Nc1cccc2c1C(=O)c1ccccc1C2=O,3.204
Cc1nc(C)c(S(=O)(=O)/N=C(\[O-])C[C@H]2CCCO2)s1,0.77654
CCC(=O)N[C@H](CCSC)C(=O)Nc1ccc2[nH]c(C)cc2c1,3.06272
COCC[NH+](C)Cc1c(C)cc(C)c(C(C)=O)c1C,1.47556
Cc1cscc1CNC(=O)N[C@@H](C)c1nc(C(=O)[O-])cs1,1.43692
CNC(=O)NCC(=O)NC[C@H](c1cccnc1)C(C)C,0.8664
COc1cccc(C(=O)N2CCC[C@H]2[C@H]2CCC[C@@H]2O)c1,2.4608
COc1ccc2c(c1)[C@H]([NH2+][C@H](C)CCN1CCOCC1)CCCO2,1.5831
COc1cccc(OCC(=O)N2CCN(c3nc4ccc(S(C)(=O)=O)cc4s3)CC2)c1,2.436
COc1cccc(-c2nc(C(=O)O[C@@H](C)CNC(C)=O)c(C)[nH]2)c1,2.07512
CCO[C@H]1C[C@@H]1C(=O)Nc1ccc(Sc2nncs2)c(Cl)c1,3.7062
Cc1ccc(N(CC(=O)N2CCCC[C@H]2C)S(=O)(=O)c2c(C)nn(C)c2C)cc1,2.94166
COc1ccc(C(=O)Nc2nc(CC(=O)Nc3ccc(N(C)C)cc3)cs2)cc1,3.6512
O=c1c2ccccc2sn1-c1ncc(Br)s1,3.2712
Cc1ocnc1CNC(=O)N(C)Cc1ccc(OC(F)F)cc1,2.92602
COc1cc(Cl)c(C)cc1NC(=O)C(=O)NCc1ccccc1C[NH+](C)C,1.55642
C[C@@H]([NH3+])c1cccc(Oc2cc(Br)ccc2Cl)c1,4.1977
CC(C)NS(=O)(=O)c1cccc(C(=O)NCc2ccco2)c1,1.8963
CN(Cc1ncnn1C)C(=O)CSCc1ccccn1,1.1019
================================================
FILE: graphium/data/multitask/tiny_ZINC_score.csv
================================================
SMILES,score
Cc1cccc(COC(=O)Nc2ccc(-c3cn[nH]c3)cc2OCCN2CCCC2)c1.O=C(O)C(F)(F)F.O=C(O)C(F)(F)F,2.978505049769262
CCC(C)C1=C(C=CC=C1)OCCC[N]2C=CN=C2.O=C(O)C(=O)O,0.35198008204428
CCc1ccc(C(=O)/C(=C(/S)NC2CC2)[n+]2ccc(CC)cc2)cc1,1.014510522
C=CCOc1ccc(C(F)(F)F)cc1C(=O)NC[C@H]1CCC[C@H]1O,0.089502102
COc1cc(OC)cc([C@@H](NC(=O)Cn2ccccc2=O)c2nccn2C)c1,-1.33520896
CN1CCC[C@@]2(CCN(C(=O)CN3CCNC(=O)C3)C2)C1=O,-4.716068457
COc1ccc(Cl)cc1NC(=O)c1nn(C)cc1[N+](=O)[O-],0.109935327
Cc1nn(-c2ccccc2)c(O)c1/C=[NH+]/Cc1ccncc1,-2.136619223
C=CCN1C(=O)C(C)(C)COc2ccc(NC(=O)C3(c4ccc(OC)cc4)CCOCC3)cc21,1.574910254
CON(C)C(=O)c1cnn2c(C)cc(C)nc12,-1.568307184
NC(=O)[C@@H](NC(=O)CCCc1cccs1)c1ccccc1,-0.031143614
COc1cccc(CN2CCC[NH+](CC(=O)Nc3ccc(F)cc3)S2(=O)=O)c1,-2.737911784
Cc1ccc(-c2nc3ccc(C)c(C)c3[nH]2)nc1,1.207643217
Cc1n[nH]c(SCC(=O)Nc2cccc(C#Cc3ccccn3)c2)n1,-0.046296252
COC(=O)c1cncc(C(=O)Nc2cccc(COC(C)C)c2)c1,1.00130715
Cc1ccc(C)c(Oc2ccc(CNC(=O)N3CCCSCC3)cn2)c1,1.769344664
CC[NH2+][C@@]1(C(=O)[O-])CCC[C@H](Oc2ccc(CC)cc2)C1,-3.523151538
CC(=O)N1c2ccc(S(=O)(=O)N3CCCC3)cc2C[C@H]1C(=O)NCC[NH+](C)C1CCCCC1,-3.276779408
C=CCN1CC(=O)N2[C@@H](Cc3c([nH]c4ccccc34)[C@@H]2c2ccccc2OC)C1=O,-0.240167863
COc1ccc(-c2noc3ncnc(N4CCC[C@H](C(=O)[O-])C4)c23)cc1,-1.893174099
N#Cc1cc(NC(=O)Cc2ccccc2Cl)ccc1N1CCC(O)CC1,1.180194377
Cc1ncc(-c2ccc(S(=O)(=O)NC3CCCCCC3)s2)o1,1.198896643
Cc1ccc(NC(=O)c2ccc(F)cc2F)cc1S(=O)(=O)Nc1ccc(Cl)cc1,3.032271232
CNC(=O)[C@@H]1CCC[NH+]1Cc1ccc(C)c(F)c1,-3.963918086
CN1C(=O)N[C@@H](c2cccc([N+](=O)[O-])c2)C2=C1CN(c1ccc(F)cc1)C2=O,-0.225172154
Cc1cc(CC(=O)N[C@H](c2ccc(F)cc2)C2CCC2)no1,0.657088361
Cc1cc(N2CCN(CC(=O)NC3CC3)CC2)nc(-c2ccc([N+](=O)[O-])cc2)n1,-0.488749987
Cc1cc(C(=O)N2CCN(C(=O)N[C@H]3CC(=O)N(C4CC4)C3)CC2)c(C)o1,-1.794585033
O=c1c2c3nc4ccccc4nc3n(CCC3=CCCCC3)c2ncn1C[C@H]1CCCO1,1.106639883
CC[NH+](C/C=C/c1ccc(C#N)cc1)C[C@H](C)C#N,-2.72697875
COc1c(Cl)cc(C[NH2+][C@@H](Cc2c[nH]c3ccccc23)C(=O)[O-])cc1Cl,-1.817468177
COc1ccccc1[C@H](C)NC(=O)C[C@H]1C[NH2+]CCO1,-3.588097047
Cc1ccc(NC(=O)C(=O)NCCc2ccccc2F)c(C)c1,0.913186044
CC(C)(C)OC(=O)N[C@H]1CCN(c2cc(-c3cccs3)n[nH]2)C1,0.013574908
CC[C@H](C)C[NH+]1CCCC[C@@H]1C(=O)NC(C)(C)C,-3.402467722
COC(=O)c1c(S(=O)(=O)NC2CC2)sc2c1CCN(Cc1ccc3ccccc3c1)C2,1.201373394
Cc1cc(C)cc(NC(=O)[C@@H](Sc2nnnn2C2CC2)c2ccccc2)c1,1.388461417
C=CCn1c([S-])nnc1-c1sc(NC(=O)c2cccc(NC(C)=O)c2)nc1C,0.069060976
O=C(CNc1nc(C2CC2)no1)N1CCc2sccc2C1,-0.776130253
Cc1ccccc1NC(=O)NCCn1c([N+](=O)[O-])cnc1C,-0.028472512
C[C@H](NC=C(C#N)C#N)c1ccc(-n2cncn2)cc1,-1.428595708
CN1C[C@H](C(=O)NC[C@H]2CCC(C)(C)c3ccccc32)CC1=O,-0.823312472
CC(C)n1cccc1C(=O)N[C@H]1CCc2cc(N)ccc21,0.181217976
CCN1C(=O)C(C#N)=C(C)/C(=C\Nc2ccc([N+](=O)[O-])cc2)C1=O,-0.381687124
O=C(Nc1ccnn1Cc1cccc(Cl)c1)C1CCN(S(=O)(=O)c2ccccc2F)CC1,1.466878748
COc1cc([C@@H]2N[C@H](C(=O)[O-])CS2)ccc1OC(=O)c1ccccc1,-1.905611868
CCOc1ccc(C2=CCN(C(=O)c3ccccc3F)CC2)cc1,2.153553484
CC(=O)N1CC[C@@H]([NH2+][C@@H](C)CSCC(C)C)C1,-3.535374272
CCO[P@](=O)(CC(=O)[O-])c1ccccc1,-3.133793788
O=C(CCc1nc2ccccc2c(=O)[nH]1)N1CCC[C@H]1C1CCCC1,0.262289845
O=C(/C=C/SCc1ccco1)Nc1cccc(N2CCCC2=O)c1,1.207148734
CC(=O)c1c(C)[nH]c(C(=O)OCC(=O)N2C[C@H](C)C[C@@H](C)C2)c1C,-0.573012859
Cc1ccc(C(=O)N2CCC[C@@H](C(=O)N(C)CC(=O)NC(C)C)C2)cc1,-0.662643438
C[C@@H](C#N)Sc1nnc(C23CC4CC(CC(C4)C2)C3)o1,-1.035316432
CC1(C)CCC[C@H]1n1c(N)[nH+]c2ccccc21,-1.363166768
CCOc1ccc(C[NH+]2CCS[C@H]3COCC[C@@H]32)cc1OC,-3.131022047
COc1cc(OC)c([C@@H](C)[NH2+]C[C@H](O)C2CCOCC2)cc1Cl,-2.191884106
COC(C[C@@](C)(O)C[C@@H]1CCCN(S(C)(=O)=O)C1)OC,-2.861796168
C=CCSc1nnc(C2CCOCC2)n1N,-1.661013358
O=C(Nc1c[nH]c2ccccc12)N1CCCC[C@@H]1CCO,0.261902615
Cc1nc(-c2ccc(-c3noc(C4CC4)n3)cc2)cs1,1.912700226
C[C@H]1CCCCN1C(=O)[C@@H](C)NC(=O)C(=O)Nc1ccnn1C(C)(C)C,-1.936204414
c1cc2c(cc1N[C@@H]1CCOC3(CCC3)C1)CCC2,0.325838871
O=C(Nc1ccc(CNC(=O)N2CCCCCCC2)cc1)c1ccco1,2.076036098
CNC(=O)[C@H]1CCCCN1c1cccc(F)c1C(N)=S,-1.26437551
O=C(NCc1nnc2n1CCC2)c1cc(-c2ccccc2)[nH]n1,-0.832471447
Cc1nnc(-c2cc(CC(C)C)n(-c3cccc(C(=O)N[C@@H]4C[C@@H](C)CC(C)(C)C4)c3)n2)o1,1.825655712
COc1cccc([C@@H]2CC(=O)C3=C(C2)NC(=O)C[C@H]3c2ccc(OC)c(OC)c2OC)c1,0.498859597
C[NH2+]Cc1ccc(-c2cc(C)ccc2F)s1,-0.572811552
CCC[C@H](NC(N)=O)C(=O)NC[C@@H]1CCCO1,-2.750759841
COc1cc(F)cc(CNC(=O)[C@H]2CCCN2C(=O)Cc2ccccc2)c1,0.257023115
CCc1cc(C#N)c(NC(=O)CCC(=O)N(CC)CC)s1,0.275728687
C#Cc1cccc(NC(=O)C[NH+](C)[C@H](C)c2nnc(-c3ccccc3)o2)c1,-1.846997449
CCOC[C@H](O)[C@](C)(CC)[NH+]1CCCC1,-4.896705513
C[NH+](C)Cc1cccc(CNC(=O)NC[C@H](O)c2ccc(F)cc2)c1,-2.23879101
Cc1ccsc1C(=O)NCCNC(=O)c1ccco1,-0.305262277
COC(=O)[C@H](C)N(Cc1ccccc1)C(=O)Cc1ccon1,-1.044654716
C/C=C(/C)C(=O)NC[C@H]1C[C@@]12CCc1ccccc12,-0.968019906
O=C1Nc2ccccc2Oc2cc([N+](=O)[O-])cc(Oc3ccc(Cl)cc3)c21,2.997553659
CCc1onc(C)c1NC(=O)CCCC(C)(C)C,1.098719959
COC(=O)C(C)(C)COc1ccc2c(c1)OC[C@@H]2[NH3+],-2.636525387
CC#CCCC(=O)Nc1cccc2c1C(=O)c1ccccc1C2=O,0.905040047
Cc1nc(C)c(S(=O)(=O)/N=C(\[O-])C[C@H]2CCCO2)s1,-3.036448405
CCC(=O)N[C@H](CCSC)C(=O)Nc1ccc2[nH]c(C)cc2c1,0.330593743
COCC[NH+](C)Cc1c(C)cc(C)c(C(C)=O)c1C,-2.479543091
Cc1cscc1CNC(=O)N[C@@H](C)c1nc(C(=O)[O-])cs1,-2.191538628
CNC(=O)NCC(=O)NC[C@H](c1cccnc1)C(C)C,-1.928716489
COc1cccc(C(=O)N2CCC[C@H]2[C@H]2CCC[C@@H]2O)c1,-0.541585916
COc1ccc2c(c1)[C@H]([NH2+][C@H](C)CCN1CCOCC1)CCCO2,-2.206030146
COc1cccc(OCC(=O)N2CCN(c3nc4ccc(S(C)(=O)=O)cc4s3)CC2)c1,0.190603464
COc1cccc(-c2nc(C(=O)O[C@@H](C)CNC(C)=O)c(C)[nH]2)c1,-0.79450049
CCO[C@H]1C[C@@H]1C(=O)Nc1ccc(Sc2nncs2)c(Cl)c1,0.260697763
Cc1ccc(N(CC(=O)N2CCCC[C@H]2C)S(=O)(=O)c2c(C)nn(C)c2C)cc1,0.068893371
COc1ccc(C(=O)Nc2nc(CC(=O)Nc3ccc(N(C)C)cc3)cs2)cc1,1.667550463
O=c1c2ccccc2sn1-c1ncc(Br)s1,0.417952528
Cc1ocnc1CNC(=O)N(C)Cc1ccc(OC(F)F)cc1,0.32323239
COc1cc(Cl)c(C)cc1NC(=O)C(=O)NCc1ccccc1C[NH+](C)C,-1.314155026
C[C@@H]([NH3+])c1cccc(Oc2cc(Br)ccc2Cl)c1,1.232550734
CC(C)NS(=O)(=O)c1cccc(C(=O)NCc2ccco2)c1,-0.065316674
CN(Cc1ncnn1C)C(=O)CSCc1ccccn1,-1.427183407
================================================
FILE: graphium/data/normalization.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Optional
from loguru import logger
import numpy as np
import torch
from torch import Tensor
class LabelNormalization:
    """Normalizes label arrays/tensors with optional clipping, and can invert the normalization."""

    def __init__(
        self,
        method: Optional[str] = None,
        min_clipping: Optional[int] = None,
        max_clipping: Optional[int] = None,
        normalize_val_test: Optional[bool] = False,
        verbose: Optional[bool] = True,
    ):
        """
        Parameters:
            method: str
                Normalization method. Supports the following values:
                - `None` (default): No normalization applied
                - `normal`: Normalize to have 0-mean and 1-variance
                - `unit`: Normalize to have all values in the range 0-1
            min_clipping: int
                Minimum value to clip to. If `None` (default), no clipping is applied.
                For example, if `min_clipping` is -2, all values below -2 will be clipped to -2.
                This is applied before the normalization.
            max_clipping: int
                Maximum value to clip to. If `None` (default), no clipping is applied.
                For example, if `max_clipping` is 2, all values above 2 will be clipped to 2.
                This is applied before the normalization.
            normalize_val_test: bool
                Whether validation/test labels are normalized as well.
            verbose: bool
                Whether to log the computed statistics.
        """
        self.method = method
        self.min_clipping = min_clipping
        self.max_clipping = max_clipping
        self.normalize_val_test = normalize_val_test
        self.verbose = verbose
        # Statistics are filled in by `calculate_statistics`
        self.data_max = None
        self.data_min = None
        self.data_mean = None
        self.data_std = None

    def calculate_statistics(self, array):
        """
        Saves the normalization parameters (e.g. mean and variance) to the object.
        """
        # add axis = 0 to make sure that the statistics for multiple column labels is a vector or list instead of a scalar
        self.data_max = np.nanmax(array, axis=0).tolist()
        self.data_min = np.nanmin(array, axis=0).tolist()
        self.data_mean = np.nanmean(array, axis=0).tolist()  # 5.380503871833475 for pcqm4mv2
        self.data_std = np.nanstd(array, axis=0).tolist()  # 1.17850688410978995 for pcqm4mv2
        if self.verbose:
            logger.info(f"Max value for normalization '{self.data_max}'")
            logger.info(f"Min value for normalization '{self.data_min}'")
            logger.info(f"Mean value for normalization '{self.data_mean}'")
            logger.info(f"STD value for normalization '{self.data_std}'")

    def normalize(self, input):
        """
        Apply the clipping (if configured) and the normalization method to the data.
        `calculate_statistics` must have been called beforehand.
        """
        assert self.data_max is not None, "calculate_statistic must be called before applying normalization"
        if self.min_clipping is not None:
            self.data_min = max(self.min_clipping, self.data_min)
        if self.max_clipping is not None:
            self.data_max = min(self.max_clipping, self.data_max)
        # Fix: clip when *at least one* bound is configured. The previous `and`
        # silently skipped clipping when only one of min/max was set; the guard
        # is only needed because np.clip fails if both a_min and a_max are None.
        if (self.min_clipping is not None) or (self.max_clipping is not None):
            lower = self.data_min if self.min_clipping is not None else None
            upper = self.data_max if self.max_clipping is not None else None
            if isinstance(input, np.ndarray):
                input = np.clip(input, a_min=lower, a_max=upper)
            elif isinstance(input, Tensor):
                input = torch.clip(input, min=lower, max=upper)
        if self.method is None:
            return input
        elif self.method == "normal":
            return (input - self.data_mean) / self.data_std
        elif self.method == "unit":
            return (input - self.data_min) / (self.data_max - self.data_min)
        else:
            raise ValueError(f"normalization method {self.method} not recognised.")

    def denormalize(self, input):
        """
        Apply the inverse of the normalization method to the data.
        Note: clipping is lossy and cannot be inverted.
        """
        if self.method is None:
            return input
        elif self.method == "normal":
            mean, std = torch.tensor(self.data_mean), torch.tensor(self.data_std)
            if input.device.type != "ipu":  # Cast to device if not on IPU
                mean, std = mean.to(input.device), std.to(input.device)
            return (input * std) + mean
        elif self.method == "unit":
            dmax, dmin = torch.tensor(self.data_max), torch.tensor(self.data_min)
            if input.device.type != "ipu":  # Cast to device if not on IPU
                dmax, dmin = dmax.to(input.device), dmin.to(input.device)
            return input * (dmax - dmin) + dmin
        else:
            raise ValueError(f"normalization method {self.method} not recognised.")
================================================
FILE: graphium/data/sampler.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Dict, Optional
from torch.utils.data.dataloader import Dataset
import torch.utils.data as data_utils
import numpy as np
import time
import os
import torch
class DatasetSubSampler(data_utils.Sampler):
    """Sampler that randomly draws a per-task fraction of the dataset each epoch."""

    def __init__(
        self, dataset: Dataset, sampler_task_dict: Dict[str, Optional[float]], data_path: str, data_hash: str
    ):
        """
        Random sample a subset of the dataset for each task each epoch and combine them for training.

        Parameters:
            dataset: the whole training dataset
            sampler_task_dict: a dict which indicates the sampled amount of data for each task.
            data_path: path to save the data indices.
            data_hash: hash folder name for the data indices.
        """
        self.dataset = dataset
        self.sampler_task_dict = sampler_task_dict
        self.task_indices = {}
        now = time.time()
        path_with_hash = os.path.join(data_path, data_hash)
        os.makedirs(path_with_hash, exist_ok=True)
        filename = os.path.join(path_with_hash, "dataset_indices.pkl")
        # if file dataset_indices.pkl exists, load indices from file.
        if os.path.isfile(filename):
            print(f"Sampler--loading dataset indices from disk.")
            self.task_indices = torch.load(filename)
        # if not, iterate through all data items in dataset and save indices to disk
        else:
            for i in range(len(dataset)):
                # the dataset[i]["labels"].keys() are not deterministic, need to sort by key length
                task_name = sorted(dataset[i]["labels"].keys(), key=len)[-1]
                self.task_indices.setdefault(task_name, []).append(i)
            elapsed = round(time.time() - now)
            print(f"Sampler-->time spent on getting indices: {elapsed}.")
            torch.save(self.task_indices, filename, pickle_protocol=4)
        # Fix: pre-compute the epoch size so `len(self)` works even before
        # `__iter__` runs (e.g. a DataLoader asking for the length first).
        # Each index belongs to exactly one task, so the per-task sizes add up.
        self.total_size = sum(
            int(len(indices) * self.sampler_task_dict[task_name])
            for task_name, indices in self.task_indices.items()
        )

    def __iter__(self):
        # Draw, without replacement, the configured fraction of each task's indices.
        indices = []
        for task_name in self.task_indices.keys():
            task_size = int(len(self.task_indices[task_name]) * self.sampler_task_dict[task_name])
            indices += np.random.choice(self.task_indices[task_name], task_size, replace=False).tolist()
        indices_set = set(indices)
        self.total_size = len(indices_set)
        return iter(indices_set)

    def __len__(self):
        # Number of samples drawn per epoch.
        return self.total_size

    @classmethod
    def check_sampling_required(cls, sampler_task_dict):
        """
        Check if we need subsampling: if all items in the sampler_task_dict are 1.0,
        skip subsampling.
        """
        return not all(value == 1.0 for value in sampler_task_dict.values())
================================================
FILE: graphium/data/sdf2csv.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs and Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import datamol as dm
import zipfile
import pandas as pd
from rdkit import Chem
from ogb.lsc import PCQM4Mv2Dataset
import graphium
from graphium.data.datamodule import BaseDataModule
import csv
def extract_zip(fname):
    """
    Extract all members of the zip archive `fname` into the current directory.

    Uses a context manager so the archive handle is always closed
    (the previous version leaked the open file handle).
    """
    with zipfile.ZipFile(fname) as zf:
        zf.extractall(".")
def extract_mols_from_sdf(fname):
    """
    Load an SDF file and return its column of RDKit molecule objects.
    """
    sdf_df = BaseDataModule._read_sdf(fname)
    return sdf_df["_rdkit_molecule_obj"]
def mols2cxs(mols):
    """
    Convert each molecule to a CXSMILES string that embeds its 3D structure.
    """
    return [dm.to_smiles(mol, cxsmiles=True) for mol in mols]
def write_csv(cxs: any, homos: any, fname: str):
    """
    Write cxsmiles strings and their homo-lumo gaps to `<fname>.csv`.
    The training molecules (with cxsmiles) are expected to come first.

    Parameters:
        cxs: sequence of cxsmiles strings
        homos: sequence of homo-lumo gap values, aligned with `cxs`
        fname: output file name, without the ".csv" extension
    """
    outname = fname + ".csv"
    fieldnames = ["cxsmiles", "homo_lumo_gap"]
    # newline="" is required by the csv module to avoid spurious blank
    # lines on platforms that translate line endings.
    with open(outname, "w", newline="") as outfile:
        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()
        for cx, homo in zip(cxs, homos):
            writer.writerow({"cxsmiles": cx, "homo_lumo_gap": homo})
def sdf2csv(sdf_name: str = "pcqm4m-v2-train", outname: str = "pcqm4m-v2-train"):
    """
    Convert the molecules of an sdf file to csv format using CXSmiles,
    then append the molecules without 3d positions from the ogb dataset.

    Parameters:
        sdf_name: name to the extracted sdf file
        outname: output name of the cxsmile file
    """
    mols = extract_mols_from_sdf(sdf_name + ".sdf")
    # download ogb smiles; each dataset entry is a (smiles, homo_lomo) pair
    dataset = PCQM4Mv2Dataset(root=".", only_smiles=True)
    num_train = len(mols)
    # the training set molecules come first, encoded as cxsmiles with 3D info
    cxs = mols2cxs(mols)
    homos = [dataset[i][1] for i in range(num_train)]
    # the remaining molecules have no 3D structure; use the plain ogb smiles
    for idx in range(num_train, len(dataset)):
        smiles, homo = dataset[idx]
        cxs.append(smiles)
        homos.append(homo)
    write_csv(cxs, homos, outname)
if __name__ == "__main__":
    """
    Entry point: convert the PCQM4Mv2 training molecules to a cxsmiles csv file.
    #! this script needs to be located at the specific data folder as it uses relative dependencies,
    for example #* graphium/data/PCQM4Mv2
    instruction on how to generate the csv file:
    1. download and extract the sdf file from ogb: https://ogb.stanford.edu/docs/lsc/pcqm4mv2/
        $ wget http://ogb-data.stanford.edu/data/lsc/pcqm4m-v2-train.sdf.tar.gz
        $ md5sum pcqm4m-v2-train.sdf.tar.gz
        $ tar -xf pcqm4m-v2-train.sdf.tar.gz
    2. run this script (the smiles csv file will be downloaded directly in code)
    """
    sdf_name = "pcqm4m-v2-train"
    outname = "pcqm4m-v2-train"
    sdf2csv(sdf_name=sdf_name, outname=outname)
    #! check how many warnings you get from loading cxsmiles, e.g.:
    # path = "pcqm4m-v2-train.csv"
    # df = pd.read_csv(path)
    # smiles = df["cxsmiles"]
    # print (smiles[0])
    # graphium.data.datamodule.smiles_to_unique_mol_ids(smiles)
================================================
FILE: graphium/data/single_atom_dataset/single_atom_dataset.csv
================================================
SMILES,score
[CL-],17
[Br-],35
C,6
O,8
S,16
[O-],8
N,7
================================================
FILE: graphium/data/smiles_transform.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs and Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable
import os
import datamol as dm
def smiles_to_unique_mol_id(smiles: str) -> Optional[str]:
    """
    Convert a smiles to a unique MD5 Hash ID. Returns an empty string if
    the molecule cannot be processed.

    Parameters:
        smiles: A smiles string to be converted to a unique ID

    Returns:
        mol_id: a string unique ID ("" on failure)
    """
    try:
        mol = dm.to_mol(
            mol=smiles
        )  # Doesn't need `ordered=True` because the unique_id doesn't depend on the atom order
        mol_id = dm.unique_id(mol)
    except Exception:
        # Fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit;
        # only genuine processing errors should map to an empty ID.
        mol_id = ""
    if mol_id is None:
        mol_id = ""
    return mol_id
def did_featurization_fail(features: Any) -> bool:
    """
    Tell whether `features` is a failure marker rather than actual features.
    Failures are encoded as `None` or an error string.
    """
    if features is None:
        return True
    return isinstance(features, str)
class BatchingSmilesTransform:
    """
    Class to transform a list of smiles using a transform function
    """

    def __init__(self, transform: Callable):
        """
        Parameters:
            transform: Callable function to transform a single smiles
        """
        self.transform = transform

    def __call__(self, smiles_list: Iterable[str]) -> Any:
        """
        Apply the transform to every smiles in `smiles_list`, returning
        the results in the same order.
        """
        mol_id_list = []
        for smiles in smiles_list:
            mol_id_list.append(self.transform(smiles))
        return mol_id_list

    @staticmethod
    def parse_batch_size(numel: int, desired_batch_size: int, n_jobs: int) -> int:
        """
        Function to parse the batch size.
        The batch size is limited by the number of elements divided by the number of jobs,
        and is always at least 1.

        Parameters:
            numel: total number of elements to process
            desired_batch_size: requested batch size (upper bound)
            n_jobs: number of parallel jobs; -1 means "use all CPUs"

        Returns:
            The effective batch size (>= 1).
        """
        # Fix: the error messages previously claimed "positive" while the
        # checks accept 0; the messages now match the actual conditions.
        assert ((n_jobs >= 0) or (n_jobs == -1)) and isinstance(
            n_jobs, int
        ), f"n_jobs must be a non-negative integer or -1, got {n_jobs}"
        assert (
            isinstance(desired_batch_size, int) and desired_batch_size >= 0
        ), f"desired_batch_size must be a non-negative integer, got {desired_batch_size}"
        if n_jobs == -1:
            n_jobs = os.cpu_count()
        if (n_jobs == 0) or (n_jobs == 1):
            batch_size = 1
        else:
            batch_size = min(desired_batch_size, numel // n_jobs)
            batch_size = max(1, batch_size)
        return batch_size
def smiles_to_unique_mol_ids(
    smiles: Iterable[str],
    n_jobs=-1,
    featurization_batch_size=1000,
    backend="loky",
    progress=True,
    progress_desc="mols to ids",
) -> List[Optional[str]]:
    """
    Map every smiles in `smiles` to its datamol `unique_id`, element-wise.

    The ID is an MD5 hash of the non-standard InChiKey provided
    by `dm.to_inchikey_non_standard()`. It guarantees uniqueness for
    different tautomeric forms of the same molecule.

    Parameters:
        smiles: a list of smiles to be converted to mol ids
        n_jobs: number of jobs to run in parallel (-1 uses all CPUs)
        featurization_batch_size: desired number of smiles per parallel batch
        backend: Parallelization backend
        progress: Whether to display the progress bar
        progress_desc: Label shown on the progress bar

    Returns:
        ids: A list of MD5 hash ids
    """
    # Pick a batch size compatible with the number of parallel workers
    batch_size = BatchingSmilesTransform.parse_batch_size(
        numel=len(smiles), desired_batch_size=featurization_batch_size, n_jobs=n_jobs
    )

    return dm.parallelized_with_batches(
        BatchingSmilesTransform(smiles_to_unique_mol_id),
        smiles,
        batch_size=batch_size,
        progress=progress,
        n_jobs=n_jobs,
        backend=backend,
        tqdm_kwargs={"desc": f"{progress_desc}, batch={batch_size}"},
    )
================================================
FILE: graphium/data/utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, List, Callable, Dict, Tuple, Any, Optional
import importlib.resources
import zipfile
from loguru import logger
import pandas as pd
import numpy as np
import graphium
from torch_geometric.data import Data
from graphium.features.featurizer import GraphDict
# Base URL under which all downloadable Graphium datasets are hosted.
GRAPHIUM_DATASETS_BASE_URL = "https://storage.valencelabs.com/graphium/datasets"
# Mapping from public dataset name to the archive filename on the server.
GRAPHIUM_DATASETS = {
    "graphium-zinc-micro": "zinc-micro.zip",
    "graphium-zinc-bench-gnn": "zinc-bench-gnn.zip",
}
def load_micro_zinc() -> pd.DataFrame:
    """
    Load the micro ZINC dataset (1000 data points) bundled with the package.

    Returns:
        pd.DataFrame: A dataframe of micro ZINC.
    """
    with importlib.resources.open_text("graphium.data.micro_ZINC", "micro_ZINC.csv") as csv_file:
        micro_zinc_df = pd.read_csv(csv_file)
    return micro_zinc_df  # type: ignore
def load_tiny_zinc() -> pd.DataFrame:
    """
    Load the tiny ZINC dataset (first 100 rows of the bundled micro ZINC file).

    Returns:
        pd.DataFrame: A dataframe of tiny ZINC.
    """
    with importlib.resources.open_text("graphium.data.micro_ZINC", "micro_ZINC.csv") as csv_file:
        tiny_zinc_df = pd.read_csv(csv_file, nrows=100)
    return tiny_zinc_df  # type: ignore
def graphium_package_path(graphium_path: str) -> str:
    """Resolve a 'graphium://package/resource' path to a filesystem path string."""
    assert graphium_path.startswith(
        "graphium://"
    ), f"Invalid graphium path, must start with 'graphium://': {graphium_path}"
    package, resource = graphium_path.replace("graphium://", "").split("/")
    # NOTE(review): the context manager is exited before returning, so for
    # packages served from a zip the extracted temp file may already be gone
    # by the time the caller uses the path — confirm all callers rely on
    # regular on-disk packages.
    with importlib.resources.path(package, resource) as data_path:
        pass
    return str(data_path)
def list_graphium_datasets() -> set:
    """
    List the names of the Graphium datasets available to download.

    Returns:
        set: A set of Graphium dataset names.
    """
    # Iterating a dict yields its keys, so this is equivalent to set(keys())
    return {dataset_name for dataset_name in GRAPHIUM_DATASETS}
def download_graphium_dataset(
    name: str, output_path: str, extract_zip: bool = True, progress: bool = False
) -> str:
    r"""Download a Graphium dataset to a specified location.

    Args:
        name: Name of the Graphium dataset from `graphium.data.utils.list_graphium_datasets()`.
        output_path: Directory path where to download the dataset to.
        extract_zip: Whether to extract the dataset if it's a zip file.
        progress: Whether to show a progress bar during download.

    Returns:
        str: Path to the downloaded dataset — the extracted folder when a zip
            archive was extracted, otherwise the downloaded file itself.

    Raises:
        ValueError: If `name` is not a known Graphium dataset.
    """
    if name not in GRAPHIUM_DATASETS:
        raise ValueError(f"'{name}' is not a valid Graphium dataset name. Choose from {GRAPHIUM_DATASETS}")

    fname = GRAPHIUM_DATASETS[name]
    dataset_path_source = graphium.utils.fs.join(GRAPHIUM_DATASETS_BASE_URL, fname)
    dataset_path_destination = graphium.utils.fs.join(output_path, fname)

    # Only download when the archive is not already present locally.
    if not graphium.utils.fs.exists(dataset_path_destination):
        graphium.utils.fs.copy(dataset_path_source, dataset_path_destination, progress=progress)

    if extract_zip and str(dataset_path_destination).endswith(".zip"):
        # Unzip the dataset
        with zipfile.ZipFile(dataset_path_destination, "r") as zf:
            zf.extractall(output_path)
        # Point the destination to the extracted folder. Strip only the
        # trailing ".zip" so directories containing dots are preserved
        # (the previous `split(".")[0]` truncated at the FIRST dot), and
        # only rewrite the path when something was actually extracted.
        dataset_path_destination = str(dataset_path_destination)[: -len(".zip")]

    return dataset_path_destination
def get_keys(pyg_data):
    """
    Return the keys of a PyG ``Data``-like object across PyG versions.

    In older PyG releases ``Data.keys`` is a property, while in newer
    releases it is a method; this helper handles both cases.
    """
    keys_attr = type(pyg_data).keys
    return pyg_data.keys if isinstance(keys_attr, property) else pyg_data.keys()
def found_size_mismatch(
    task: str, features: "Union[Data, GraphDict]", labels: np.ndarray, smiles: str
) -> bool:
    """Check if a size mismatch exists between features and labels with respect to node/edge/nodepair.

    Tasks are dispatched on their prefix: "graph_", "node_", "edge_" or "nodepair_".
    Graph-level tasks and labels containing NaNs are never considered mismatched.

    Args:
        task: The task name is needed to determine the task level (graph, node, edge or nodepair)
        features: Features/information of molecule/graph (e.g., edge_index, feat, edge_feat, num_nodes, etc.)
        labels: Target label of molecule for the task
        smiles: Smiles string of molecule

    Returns:
        mismatch: Boolean variable indicating if a size mismatch was found between features and labels.

    Raises:
        ValueError: If the task name does not start with a known task-level prefix.
    """
    mismatch = False
    if np.isnan(labels).any():
        # Missing labels cannot be compared against the graph size.
        pass
    elif task.startswith("graph_"):
        # Graph-level labels have no per-node/per-edge shape constraint.
        pass
    elif task.startswith("node_"):
        if labels.shape[0] != features.num_nodes:
            mismatch = True
            logger.warning(
                f"Inconsistent number of nodes between labels and features in {task} task for {smiles}: {labels.shape[0]} vs {features.num_nodes}"
            )
    elif task.startswith("edge_"):
        if labels.shape[0] != features.num_edges:
            mismatch = True
            logger.warning(
                f"Inconsistent number of edges between labels and features in {task} task for {smiles}: {labels.shape[0]} vs {features.num_edges}"
            )
    elif task.startswith("nodepair_"):
        # Nodepair labels must be shaped (num_nodes, num_nodes, ...)
        if list(labels.shape[:2]) != 2 * [features.num_nodes]:
            mismatch = True
            logger.warning(
                f"Inconsistent shape of nodepairs between labels and features in {task} task for {smiles}: {list(labels.shape[:2])} vs {2 * [features.num_nodes]}"
            )
    else:
        # Typo fixed: previously "Unkown task level"
        raise ValueError("Unknown task level")
    return mismatch
================================================
FILE: graphium/expts/pyg_batching_sparse.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 1.13.0+cpu\n",
"pyg version: 2.3.0.dev20230306\n",
"batch.x = tensor(indices=tensor([[0, 1, 1, 2, 3, 4],\n",
" [1, 0, 1, 1, 0, 1]]),\n",
" values=tensor([1, 2, 3, 4, 5, 6]),\n",
" size=(5, 2), nnz=6, layout=torch.sparse_coo)\n",
"Data(x=[2, 2])\n",
"Data(x=[2, 2])\n",
"[Data(x=[2, 2]), Data(x=[3, 2])]\n",
"[tensor(indices=tensor([[0, 1, 1],\n",
" [1, 0, 1]]),\n",
" values=tensor([1, 2, 3]),\n",
" size=(2, 2), nnz=3, layout=torch.sparse_coo), tensor(indices=tensor([[0, 1, 2],\n",
" [1, 0, 1]]),\n",
" values=tensor([4, 5, 6]),\n",
" size=(3, 2), nnz=3, layout=torch.sparse_coo)]\n"
]
}
],
"source": [
"import torch\n",
"import torch_geometric\n",
"from torch_geometric.data import Data, Batch\n",
"\n",
"data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))\n",
"data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))\n",
"batch = Batch.from_data_list([data1, data2])\n",
"\n",
"print(\"torch version: \", torch.__version__)\n",
"print(\"pyg version: \", torch_geometric.__version__)\n",
"print(\"batch.x = \", batch.x) # WORKS\n",
"print(batch.get_example(0)) # FAILS\n",
"print(batch[0]) # FAILS\n",
"print(batch.to_data_list())\n",
"print([b.x for b in batch.to_data_list()])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(indices=tensor([[0, 1, 1, 2],\n",
" [1, 0, 1, 0]]),\n",
" values=tensor([1, 2, 3, 5]),\n",
" size=(3, 2), nnz=4, layout=torch.sparse_coo)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch.x.index_select(dim=0, index=torch.tensor([0, 1, 3]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphium_ipu",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: graphium/features/README.md
================================================
The Graph Of LIfe Library.
## What is in this folder?
- ✅ `featurizer.py`: featurization code for the molecules, adding node, edge and graph features to the mol object
- `nmp.py`: check if a string can be converted to float, helper function for featurization
- `positional_encoding.py`: code for computing all raw positional and structural encoding of the graph, see `graph_positional_encoder` function
- `properties.py`: code for computing properties of the molecule
- `rw.py`: code for computing random walk positional encoding
- `spectral.py`: code for computing the spectral positional encoding such as the Laplacian eigenvalues and eigenvectors
================================================
FILE: graphium/features/__init__.py
================================================
from .featurizer import get_mol_atomic_features_onehot
from .featurizer import get_mol_atomic_features_float
from .featurizer import get_mol_edge_features
from .featurizer import mol_to_adj_and_features
from .featurizer import mol_to_graph_dict
from .featurizer import mol_to_graph_signature
from .featurizer import GraphDict
from .featurizer import mol_to_pyggraph
from .featurizer import to_dense_array
================================================
FILE: graphium/features/commute.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Dict, Any
import numpy as np
from scipy.sparse import spmatrix, issparse
from scipy.linalg import pinv
def compute_commute_distances(
    adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any]
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """
    Compute avg. commute time/distance between nodepairs. This is the avg. number of steps a random walker, starting
    at node i, will take before reaching a given node j for the first time, and then return to node i.

    Reference: Saerens et al. "The principal components analysis of a graph, and its relationships to spectral clustering." ECML. 2004.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects

    Returns:
        dist [num_nodes, num_nodes]: 2D array with avg. commute distances between nodepairs
        base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair
        cache: Updated dictionary of cached objects
    """
    base_level = "nodepair"

    if "commute" in cache:
        dist = cache["commute"]
    else:
        if issparse(adj):
            adj = adj.toarray()

        # Volume of the graph (total edge weight)
        volG = adj.sum()

        # Reuse the cached pseudo-inverse of the Laplacian when available
        if "pinvL" in cache:
            pinvL = cache["pinvL"]
        else:
            L = np.diagflat(np.sum(adj, axis=1)) - adj
            pinvL = pinv(L)
            cache["pinvL"] = pinvL

        # Commute distance: vol(G) * (L+[i,i] + L+[j,j] - 2 * L+[i,j]).
        # Computed with numpy broadcasting instead of the previous O(n^2)
        # nested Python loops — same values, vectorized.
        diag = np.diagonal(pinvL)
        dist = volG * (diag[:, None] + diag[None, :] - 2 * pinvL)

        cache["commute"] = dist
    return dist, base_level, cache
================================================
FILE: graphium/features/electrostatic.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Dict, Any
import numpy as np
from scipy.linalg import pinv
from scipy.sparse import spmatrix, issparse
def compute_electrostatic_interactions(
    adj: Union[np.ndarray, spmatrix], cache: Dict[str, Any]
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """
    Compute electrostatic interaction of nodepairs.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix
        cache: Dictionary of cached objects

    Returns:
        electrostatic [num_nodes, num_nodes]: 2D array with electrostatic interactions of node nodepairs
        base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair
        cache: Updated dictionary of cached objects
    """
    base_level = "nodepair"

    cached = cache.get("electrostatic")
    if cached is not None:
        return cached, base_level, cache

    pinvL = cache.get("pinvL")
    if pinvL is None:
        dense_adj = adj.toarray() if issparse(adj) else adj
        laplacian = np.diagflat(np.sum(dense_adj, axis=1)) - dense_adj
        pinvL = pinv(laplacian)
        cache["pinvL"] = pinvL

    # Subtracting the diagonal sets the "ground" potential at any given atom
    electrostatic = pinvL - np.diag(pinvL)
    cache["electrostatic"] = electrostatic

    return electrostatic, base_level, cache
================================================
FILE: graphium/features/featurizer.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, List, Callable, Dict, Tuple, Any, Optional
import inspect
from loguru import logger
import numpy as np
from scipy.sparse import issparse, coo_matrix
import torch
from torch import Tensor
from torch_geometric.data import Data
from rdkit import Chem
import datamol as dm
from graphium.features import nmp
from graphium.utils.tensor import one_of_k_encoding
from graphium.features.positional_encoding import get_all_positional_encodings
def to_dense_array(array: np.ndarray, dtype: str = None) -> np.ndarray:
    r"""
    Convert a (possibly scipy-sparse) array to a dense one, optionally casting its dtype.

    Parameters:
        array: The array to convert to dense
        dtype: The dtype of the array

    Returns:
        The dense array, or ``None`` when ``array`` is ``None``
    """
    if array is None:
        return None
    if issparse(array):
        # float16 doesn't support `todense`, so upcast to float32 first
        if array.dtype == np.float16:
            array = array.astype(np.float32)
        array = array.todense()
    if dtype is not None:
        array = array.astype(dtype)
    return array
def to_dense_tensor(tensor: Tensor, dtype: str = None) -> Tensor:
    r"""
    Convert a (possibly sparse) torch Tensor to a dense one, optionally casting its dtype.

    Parameters:
        tensor: The tensor to convert to dense
        dtype: The torch dtype to cast the tensor to (passed to ``Tensor.to``)

    Returns:
        The dense tensor, or ``None`` when ``tensor`` is ``None``
    """
    if tensor is not None:
        if tensor.is_sparse:
            # `Tensor.to_dense()` is the torch API; `todense()` does not exist
            # on torch tensors and raised an AttributeError on this path.
            tensor = tensor.to_dense()
        if dtype is not None:
            tensor = tensor.to(dtype)
    return tensor
def _mask_nans_inf(mask_nan: Optional[str], array: np.ndarray, array_name: str) -> np.ndarray:
r"""
mask the NaNs in the array
Parameters:
mask_nan: How to mask the NaNs
array: The array to mask
array_name: The name of the array
Returns:
The masked array
"""
if (mask_nan is None) or (array is None):
return array
new_array = array
if issparse(new_array):
new_array = new_array.data
nans = ~np.isfinite(new_array)
# Mask the NaNs
if nans.any():
msg = f"There are {np.sum(nans)} NaNs in `{array_name}`"
if mask_nan == "raise":
raise ValueError(msg)
elif mask_nan == "warn":
logger.warning(msg)
else:
new_array[nans] = mask_nan
if issparse(array):
array.data = new_array
new_array = array
return new_array
def get_mol_atomic_features_onehot(mol: dm.Mol, property_list: List[str]) -> Dict[str, Tensor]:
    r"""
    Get the following set of features for any given atom

    * One-hot representation of the atom
    * One-hot representation of the atom degree
    * One-hot representation of the atom implicit valence
    * One-hot representation of the the atom hybridization
    * Whether the atom is aromatic
    * The atom's formal charge
    * The atom's number of radical electrons

    Additionally, the following features can be set, depending on the value of input Parameters

    * One-hot representation of the number of hydrogen atom in the the current atom neighborhood if `explicit_H` is false
    * One-hot encoding of the atom chirality, and whether such configuration is even possible

    Parameters:
        mol:
            molecule from which to extract the properties
        property_list:
            A list of integer atomic properties to get from the molecule.
            The integer values are converted to a one-hot vector.
            Callables are not supported by this function.

            Accepted properties are:

            - "atomic-number"
            - "degree"
            - "valence", "total-valence"
            - "implicit-valence"
            - "hybridization"
            - "chirality"
            - "phase"
            - "type"
            - "group"
            - "period"

    Returns:
        prop_dict:
            A dictionnary where the element of ``property_list`` are the keys
            and the values are np.ndarray of shape (N, OH). N is the number of atoms
            in ``mol`` and OH the lenght of the one-hot encoding.

    Raises:
        ValueError: If an element of ``property_list`` is not a supported property.
    """
    prop_dict = {}

    for prop in property_list:
        prop = prop.lower()
        prop_name = prop

        property_array = []
        for ii, atom in enumerate(mol.GetAtoms()):
            if prop in ["atomic-number"]:
                one_hot = one_of_k_encoding(atom.GetSymbol(), nmp.ATOM_LIST)
            elif prop in ["degree"]:
                one_hot = one_of_k_encoding(atom.GetDegree(), nmp.ATOM_DEGREE_LIST)
            elif prop in ["valence", "total-valence"]:
                prop_name = "valence"
                one_hot = one_of_k_encoding(atom.GetTotalValence(), nmp.VALENCE)
            elif prop in ["implicit-valence"]:
                one_hot = one_of_k_encoding(atom.GetImplicitValence(), nmp.VALENCE)
            elif prop in ["hybridization"]:
                one_hot = one_of_k_encoding(atom.GetHybridization(), nmp.HYBRIDIZATION_LIST)
            elif prop in ["chirality"]:
                try:
                    one_hot = one_of_k_encoding(atom.GetProp("_CIPCode"), nmp.CHIRALITY_LIST)
                    one_hot.append(int(atom.HasProp("_ChiralityPossible")))
                except Exception:
                    # Atoms without an assigned CIP code get an all-zero chirality encoding
                    one_hot = [0, 0, int(atom.HasProp("_ChiralityPossible"))]
            elif prop in ["phase"]:
                # NOTE: these four branches previously used `prop in "phase"`
                # (a substring test that would also match e.g. "has");
                # list membership is the intended check.
                one_hot = one_of_k_encoding(nmp.PHASE[atom.GetAtomicNum() - 1], nmp.PHASE_SET)
            elif prop in ["type"]:
                one_hot = one_of_k_encoding(nmp.TYPE[atom.GetAtomicNum() - 1], nmp.TYPE_SET)
            elif prop in ["group"]:
                one_hot = one_of_k_encoding(nmp.GROUP[atom.GetAtomicNum() - 1], nmp.GROUP_SET)
            elif prop in ["period"]:
                one_hot = one_of_k_encoding(nmp.PERIOD[atom.GetAtomicNum() - 1], nmp.PERIOD_SET)
            else:
                raise ValueError(f"Unsupported property `{prop}`")

            property_array.append(np.asarray(one_hot, dtype=np.float16))

        prop_dict[prop_name] = np.stack(property_array, axis=0)

    return prop_dict
def get_mol_conformer_features(
    mol: dm.Mol,
    property_list: Union[List[str], List[Callable]],
    mask_nan: Optional[Union[float, str]] = None,
) -> Dict[str, np.ndarray]:
    r"""obtain the conformer features of a molecule

    Parameters:
        mol:
            molecule from which to extract the properties
        property_list:
            A list of conformer property to get from the molecule

            Accepted properties are:

            - "positions_3d"
        mask_nan: How to handle non-finite values in the features (see `_mask_nans_inf`)

    Returns:
        prop_dict: a dictionary where the element of ``property_list`` are the keys
    """
    prop_dict = {}

    # Probe whether the molecule carries a conformer at all
    try:
        mol.GetConformer()
        has_conf = True
    except:
        has_conf = False

    # * currently only accepts "positions_3d", raise errors otherwise
    for prop in property_list:
        if not isinstance(prop, str):
            raise ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`")
        if prop not in ["positions_3d"]:
            raise ValueError(
                str(prop) + " is not currently supported as a conformer property in `property_list`"
            )

        if has_conf:
            # Gather the (x, y, z) coordinate of every atom, row per atom
            conf = mol.GetConformer()
            coords = [
                [conf.GetAtomPosition(i).x, conf.GetAtomPosition(i).y, conf.GetAtomPosition(i).z]
                for i in range(mol.GetNumAtoms())
            ]
            positions = np.asarray(coords, dtype=np.float16)
        else:
            # No conformer available: fall back to an all-NaN (num_atoms, 3) array
            positions = np.full((mol.GetNumAtoms(), 3), float("nan"), dtype=np.float16)

        prop_dict[prop] = _mask_nans_inf(mask_nan, positions, prop)

    return prop_dict
def get_mol_atomic_features_float(
    mol: dm.Mol,
    property_list: Union[List[str], List[Callable]],
    offset_carbon: bool = True,
    mask_nan: Union[str, float, type(None)] = "raise",
) -> Dict[str, np.ndarray]:
    r"""
    Get a dictionary of floating-point arrays of atomic properties.
    To ensure all properties are at a similar scale, some of the properties
    are divided by a constant.

    There is also the possibility of offseting by the carbon value using
    the `offset_carbon` parameter.

    Parameters:
        mol:
            molecule from which to extract the properties
        property_list:
            A list of atomic properties to get from the molecule, such as 'atomic-number',
            'mass', 'valence', 'degree', 'electronegativity'.
            Some elements are divided by a factor to avoid feature explosion.

            Accepted properties are:

            - "atomic-number"
            - "mass", "weight"
            - "valence", "total-valence"
            - "implicit-valence"
            - "hybridization"
            - "chirality"
            - "aromatic"
            - "ring", "in-ring"
            - "min-ring"
            - "max-ring"
            - "num-ring"
            - "degree"
            - "radical-electron"
            - "formal-charge"
            - "vdw-radius"
            - "covalent-radius"
            - "electronegativity"
            - "ionization", "first-ionization"
            - "melting-point"
            - "metal"
            - "single-bond"
            - "aromatic-bond"
            - "double-bond"
            - "triple-bond"
            - "is-carbon"
            - "group"
            - "period"

        offset_carbon:
            Whether to subract the Carbon property from the desired atomic property.
            For example, if we want the mass of the Lithium (6.941), the mass of the
            Carbon (12.0107) will be subracted, resulting in a value of -5.0697

        mask_nan:
            Deal with molecules that fail a part of the featurization.
            NaNs can happen when taking the of a noble gas,
            or other properties that are not measured for specific atoms.

            - "raise": Raise an error when there is a nan or inf in the featurization
            - "warn": Raise a warning when there is a nan or inf in the featurization
            - "None": DEFAULT. Don't do anything
            - "Floating value": Replace nans or inf by the specified value

    Returns:
        prop_dict:
            A dictionnary where the element of ``property_list`` are the keys
            and the values are np.ndarray of shape (N,). N is the number of atoms
            in ``mol``.

    Raises:
        ValueError: If a property is unsupported, or an element of
            ``property_list`` is neither a string nor a callable.
    """
    periodic_table = Chem.GetPeriodicTable()
    prop_dict = {}
    C = Chem.Atom("C")
    C_num = C.GetAtomicNum()
    offC = bool(offset_carbon)
    atom_list = list(mol.GetAtoms())

    for prop in property_list:
        prop_name = None

        property_array = np.zeros(mol.GetNumAtoms(), dtype=np.float16)
        for ii, atom in enumerate(atom_list):
            val = None
            atomic_num = atom.GetAtomicNum()

            if isinstance(prop, str):
                prop = prop.lower()
                prop_name = prop

                if prop in ["atomic-number"]:
                    val = (atomic_num - (offC * C_num)) / 5
                elif prop in ["mass", "weight"]:
                    prop_name = "mass"
                    val = (atom.GetMass() - (offC * C.GetMass())) / 10
                elif prop in ["valence", "total-valence"]:
                    prop_name = "valence"
                    val = atom.GetTotalValence() - (offC * 4)
                elif prop in ["implicit-valence"]:
                    val = atom.GetImplicitValence()
                elif prop in ["hybridization"]:
                    # NOTE: a second, unreachable duplicate of this branch was removed
                    val = atom.GetHybridization()
                elif prop in ["chirality"]:
                    # 1 for "R", 0 for "S", 2 when no CIP code is assigned
                    val = (atom.GetProp("_CIPCode") == "R") if atom.HasProp("_CIPCode") else 2
                elif prop in ["aromatic"]:
                    val = atom.GetIsAromatic()
                elif prop in ["ring", "in-ring"]:
                    prop_name = "in-ring"
                    val = atom.IsInRing()
                elif prop in ["min-ring"]:
                    ring_info = mol.GetRingInfo()
                    val = ring_info.MinAtomRingSize(atom.GetIdx())
                elif prop in ["max-ring"]:
                    # Largest ring this atom belongs to (0 if not in a ring)
                    rings = mol.GetRingInfo().AtomRings()
                    val = 0
                    for ring in rings:
                        if atom.GetIdx() in ring:
                            if len(ring) > val:
                                val = len(ring)
                elif prop in ["num-ring"]:
                    ring_info = mol.GetRingInfo()
                    val = ring_info.NumAtomRings(atom.GetIdx())
                elif prop in ["degree"]:
                    val = atom.GetTotalDegree() - (offC * 2)
                elif prop in ["radical-electron"]:
                    val = atom.GetNumRadicalElectrons()
                elif prop in ["formal-charge"]:
                    val = atom.GetFormalCharge()
                elif prop in ["vdw-radius"]:
                    val = periodic_table.GetRvdw(atom.GetAtomicNum()) - offC * periodic_table.GetRvdw(C_num)
                elif prop in ["covalent-radius"]:
                    val = periodic_table.GetRcovalent(atomic_num) - offC * periodic_table.GetRcovalent(C_num)
                elif prop in ["electronegativity"]:
                    val = (
                        nmp.ELECTRONEGATIVITY[atom.GetAtomicNum() - 1]
                        - offC * nmp.ELECTRONEGATIVITY[C_num - 1]
                    )
                elif prop in ["ionization", "first-ionization"]:
                    prop_name = "ionization"
                    val = (nmp.FIRST_IONIZATION[atomic_num - 1] - offC * nmp.FIRST_IONIZATION[C_num - 1]) / 5
                elif prop in ["melting-point"]:
                    val = (nmp.MELTING_POINT[atomic_num - 1] - offC * nmp.MELTING_POINT[C_num - 1]) / 200
                elif prop in ["metal"]:
                    val = nmp.METAL[atomic_num - 1]
                elif prop in ["group"]:
                    # NOTE: was `prop in "group"` (substring test); list membership is the intended check
                    val = float(nmp.GROUP[atomic_num - 1]) - offC * float(nmp.GROUP[C_num - 1])
                elif prop in ["period"]:
                    val = float(nmp.PERIOD[atomic_num - 1]) - offC * float(nmp.PERIOD[C_num - 1])
                elif "-bond" in prop:
                    bonds = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
                    # Count only the bonds of the requested type. The previous
                    # `len([bond == x for bond in bonds])` counted ALL bonds,
                    # since it measured the length of a list of booleans.
                    if prop in ["single-bond"]:
                        val = len([bond for bond in bonds if bond == 1])
                    elif prop in ["aromatic-bond"]:
                        val = len([bond for bond in bonds if bond == 1.5])
                    elif prop in ["double-bond"]:
                        val = len([bond for bond in bonds if bond == 2])
                    elif prop in ["triple-bond"]:
                        val = len([bond for bond in bonds if bond == 3])
                    else:
                        raise ValueError(f"{prop} is not a correct bond.")
                    val -= offC * 1
                elif prop in ["is-carbon"]:
                    val = atom.GetAtomicNum() == 6
                    val -= offC * 1
                else:
                    raise ValueError(f"Unsupported property `{prop}`")
            elif callable(prop):
                prop_name = str(prop)
                val = prop(atom)
            else:
                # The ValueError was previously constructed but never raised
                raise ValueError(f"Elements in `property_list` must be str or callable, provided `{type(prop)}`")

            if val is None:
                raise ValueError("val is undefined.")
            property_array[ii] = val

        if prop_name is None:
            raise ValueError("prop_name is undefined.")

        # Mask the NaNs
        prop_dict[prop_name] = _mask_nans_inf(mask_nan, property_array, "atom featurization")

    return prop_dict
def get_simple_mol_conformer(mol: dm.Mol) -> Union[Chem.rdchem.Conformer, None]:
    r"""
    If the molecule has a conformer, then it will return the conformer at idx `0`.
    Otherwise, it generates a simple molecule conformer using `rdkit.Chem.rdDistGeom.EmbedMolecule`
    and returns it. This is meant to be used in simple functions like `GetBondLength`,
    not in functions requiring complex 3D structure.

    Parameters:
        mol: Rdkit Molecule

    Returns:
        conf: A conformer of the molecule, or `None` if it fails
    """
    val = 0
    if mol.GetNumConformers() == 0:
        val = Chem.rdDistGeom.EmbedMolecule(mol)

    if val == -1:
        # First embedding failed; retry with more permissive settings
        val = Chem.rdDistGeom.EmbedMolecule(
            mol,
            enforceChirality=False,
            ignoreSmoothingFailures=True,
            useBasicKnowledge=True,
            useExpTorsionAnglePrefs=True,
            forceTol=0.1,
        )

    if val == -1:
        conf = None
        # `loguru.logger` exposes `warning`, not `warn` — the previous
        # `logger.warn(...)` raised an AttributeError on this failure path.
        logger.warning("Couldn't compute conformer for molecule `{}`".format(Chem.MolToSmiles(mol)))
    else:
        conf = mol.GetConformer(0)

    return conf
def get_estimated_bond_length(bond: Chem.rdchem.Bond, mol: dm.Mol) -> float:
    r"""
    Estimate the bond length between atoms by looking at the estimated atomic radius
    that depends both on the atom type and the bond type. The resulting bond-length is
    then the sum of the radius.

    Keep in mind that this function only provides an estimate of the bond length and not
    the true one based on a conformer. The vast majority of estimated bond lengths will
    have an error below 5% while some bonds can have an error up to 20%. This function
    is mostly useful when conformer generation fails for some molecules, or for
    increased computation speed.

    Parameters:
        bond: The bond to measure its lenght
        mol: The molecule containing the bond (used to get neighbouring atoms)

    Returns:
        bond_length: The bond length in Angstrom, typically a value around 1-2.
    """
    # Get the atoms connected by the bond
    idx1 = bond.GetBeginAtomIdx()
    idx2 = bond.GetEndAtomIdx()
    atom1 = mol.GetAtomWithIdx(idx1).GetAtomicNum()
    atom2 = mol.GetAtomWithIdx(idx2).GetAtomicNum()
    bond_type = bond.GetBondType()

    # Get single bond atomic radius
    if bond_type == Chem.rdchem.BondType.SINGLE:
        rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1]]
        rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1]]
    # Get double bond atomic radius
    elif bond_type == Chem.rdchem.BondType.DOUBLE:
        rad1 = [nmp.BOND_RADIUS_DOUBLE[atom1 - 1]]
        rad2 = [nmp.BOND_RADIUS_DOUBLE[atom2 - 1]]
    # Get triple bond atomic radius
    elif bond_type == Chem.rdchem.BondType.TRIPLE:
        rad1 = [nmp.BOND_RADIUS_TRIPLE[atom1 - 1]]
        rad2 = [nmp.BOND_RADIUS_TRIPLE[atom2 - 1]]
    # Get average of single bond and double bond atomic radius
    elif bond_type == Chem.rdchem.BondType.AROMATIC:
        rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1], nmp.BOND_RADIUS_DOUBLE[atom1 - 1]]
        rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1], nmp.BOND_RADIUS_DOUBLE[atom2 - 1]]
    else:
        # Other bond types (e.g. dative) previously left `rad1`/`rad2`
        # undefined and crashed with a NameError; fall back to the
        # single-bond radius as a reasonable estimate.
        rad1 = [nmp.BOND_RADIUS_SINGLE[atom1 - 1]]
        rad2 = [nmp.BOND_RADIUS_SINGLE[atom2 - 1]]

    # Average the bond lengths, while ignoring nans in case some missing value
    rad1_float = [elem for elem in rad1 if elem is not None]
    rad2_float = [elem for elem in rad2 if elem is not None]

    if len(rad1_float) > 0:
        rad1_float = sum(rad1_float) / len(rad1_float)
    else:
        rad1_float = float(nmp.BOND_RADIUS_SINGLE[atom1 - 1])

    if len(rad2_float) > 0:
        rad2_float = sum(rad2_float) / len(rad2_float)
    else:
        rad2_float = float(nmp.BOND_RADIUS_SINGLE[atom2 - 1])

    bond_length = rad1_float + rad2_float
    return bond_length
def get_mol_edge_features(
    mol: dm.Mol, property_list: List[str], mask_nan: Union[str, float, type(None)] = "raise"
) -> Dict[str, np.ndarray]:
    r"""
    Get the following set of features for any given bond
    See `graphium.features.nmp` for allowed values in one hot encoding

    * One-hot representation of the bond type. Note that you should not kekulize your
      molecules, if you expect this to take aromatic bond into account.
    * Bond stereo type, following CIP classification
    * Whether the bond is conjugated
    * Whether the bond is in a ring

    Parameters:
        mol: rdkit.Chem.Molecule
            the molecule of interest
        property_list:
            A list of edge properties to return for the given molecule.
            Accepted properties are:

            - "bond-type-onehot"
            - "bond-type-float"
            - "stereo"
            - "in-ring"
            - "conjugated"
            - "conformer-bond-length" (might cause problems with complex molecules)
            - "estimated-bond-length"
        mask_nan: How to handle non-finite values in the features (see `_mask_nans_inf`)

    Returns:
        prop_dict:
            A dictionnary where the element of ``property_list`` are the keys
            and the values are np.ndarray of shape (N,). N is the number of atoms
            in ``mol``.

    Raises:
        ValueError: If an element of ``property_list`` is not a supported property.
    """
    prop_dict = {}

    # Compute features for each bond
    num_bonds = mol.GetNumBonds()
    for prop in property_list:
        # Lowercase once, outside the bond loop: previously this happened
        # inside the loop, so molecules with zero bonds kept the
        # original-cased key in `prop_dict`, inconsistent with all others.
        prop = prop.lower()
        property_array = []
        for ii in range(num_bonds):
            bond = mol.GetBondWithIdx(ii)

            if prop in ["bond-type-onehot"]:
                encoding = one_of_k_encoding(bond.GetBondType(), nmp.BOND_TYPES)
            elif prop in ["bond-type-float"]:
                encoding = [bond.GetBondTypeAsDouble()]
            elif prop in ["stereo"]:
                encoding = one_of_k_encoding(bond.GetStereo(), nmp.BOND_STEREO)
            elif prop in ["in-ring"]:
                encoding = [bond.IsInRing()]
            elif prop in ["conjugated"]:
                encoding = [bond.GetIsConjugated()]
            elif prop in ["conformer-bond-length"]:
                conf = get_simple_mol_conformer(mol)
                if conf is not None:
                    idx1 = bond.GetBeginAtomIdx()
                    idx2 = bond.GetEndAtomIdx()
                    encoding = [Chem.rdMolTransforms.GetBondLength(conf, idx1, idx2)]
                else:
                    encoding = [0]
            elif prop in ["estimated-bond-length"]:
                encoding = [get_estimated_bond_length(bond, mol)]
            else:
                raise ValueError(f"Unsupported property `{prop}`")

            property_array.append(np.asarray(encoding, dtype=np.float16))

        if num_bonds > 0:
            property_array = np.stack(property_array, axis=0)
            # Mask the NaNs
            prop_dict[prop] = _mask_nans_inf(mask_nan, property_array, "edge property")
        else:
            # Add an empty vector with the right shape
            arr_len = 1
            if prop in ["bond-type-onehot"]:
                arr_len = len(nmp.BOND_TYPES) + 1
            elif prop in ["stereo"]:
                arr_len = len(nmp.BOND_STEREO) + 1

            prop_dict[prop] = np.zeros((0, arr_len))

    return prop_dict
def mol_to_adj_and_features(
    mol: Union[str, dm.Mol],
    atom_property_list_onehot: List[str] = [],
    atom_property_list_float: List[Union[str, Callable]] = [],
    conformer_property_list: List[str] = [],
    edge_property_list: List[str] = [],
    add_self_loop: bool = False,
    explicit_H: bool = False,
    use_bonds_weights: bool = False,
    pos_encoding_as_features: Dict[str, Any] = None,
    dtype: np.dtype = np.float16,
    mask_nan: Union[str, float, type(None)] = "raise",
) -> "Tuple[coo_matrix, Union[np.ndarray, None], Union[np.ndarray, None], Dict[str, np.ndarray], Dict[str, np.ndarray]]":
    r"""
    Transforms a molecule into an adjacency matrix representing the molecular graph
    and a set of atom and bond features.

    It also returns the positional encodings associated to the graph.

    Parameters:

        mol:
            The molecule to be converted, either as a SMILES string or an RDKit molecule

        atom_property_list_onehot:
            List of the properties used to get one-hot encoding of the atom type,
            such as the atom index represented as a one-hot vector.
            See function `get_mol_atomic_features_onehot`

        atom_property_list_float:
            List of the properties used to get floating-point encoding of the atom type,
            such as the atomic mass or electronegativity.
            See function `get_mol_atomic_features_float`

        conformer_property_list:
            list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d"

        edge_property_list:
            List of the properties used to encode the edges, such as the edge type
            and the stereo type.

        add_self_loop:
            Whether to add a value of `1` on the diagonal of the adjacency matrix.

        explicit_H:
            Whether to consider the Hydrogens explicitely. If `False`, the hydrogens
            are implicit.

        use_bonds_weights:
            Whether to use the floating-point value of the bonds in the adjacency matrix,
            such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5

        pos_encoding_as_features: keyword arguments for function `graph_positional_encoder`
            to generate positional encoding for node features.

        dtype:
            The numpy data type used to build the feature arrays

        mask_nan:
            Deal with molecules that fail a part of the featurization.
            NaNs can happen when computing a property of a noble gas,
            or other properties that are not measured for specific atoms.

            - "raise": DEFAULT. Raise an error when there is a nan or inf in the featurization
            - "warn": Raise a warning when there is a nan or inf in the featurization
            - "None": Don't do anything
            - "Floating value": Replace nans or inf by the specified value

    Returns:

        adj:
            Scipy coo sparse adjacency matrix of the molecule

        ndata:
            Concatenated node data of the atoms, based on the properties from
            `atom_property_list_onehot` and `atom_property_list_float`.
            If no properties are given, it returns `None`

        edata:
            Concatenated edge data of the molecule, based on the properties from
            `edge_property_list`.
            If no properties are given, it returns `None`

        pe_dict:
            Dictionary of all positional encodings. Current supported keys:

            - "pos_enc_feats_sign_flip":
                Node positional encoding that requires augmentation via sign-flip.
                For example, eigenvectors of the Laplacian are ambiguous to the
                sign and are returned here.

            - "pos_enc_feats_no_flip":
                Node positional encoding that does not use sign-flip.
                For example, distance from centroid are returned here.

            - "rwse":
                Node structural encoding corresponding to the diagonal of the random
                walk matrix

        conf_dict:
            contains the 3d positions of a conformer of the molecule or 0s if none is found
    """

    if isinstance(mol, str):
        mol = dm.to_mol(mol, ordered=True)

    # Add or remove explicit hydrogens
    if explicit_H:
        mol = Chem.AddHs(mol)
    else:
        mol = Chem.RemoveHs(mol)

    num_nodes = mol.GetNumAtoms()

    adj = mol_to_adjacency_matrix(
        mol, use_bonds_weights=use_bonds_weights, add_self_loop=add_self_loop, dtype=dtype
    )

    # Get the node features
    atom_features_onehot = get_mol_atomic_features_onehot(mol, atom_property_list_onehot)
    atom_features_float = get_mol_atomic_features_float(mol, atom_property_list_float, mask_nan=mask_nan)
    conf_dict = get_mol_conformer_features(mol, conformer_property_list, mask_nan=mask_nan)
    ndata = list(atom_features_float.values()) + list(atom_features_onehot.values())
    # Promote 1D per-atom arrays to column vectors so they can be concatenated
    ndata = [d[:, np.newaxis] if d.ndim == 1 else d for d in ndata]

    if len(ndata) > 0:
        ndata = np.concatenate(ndata, axis=1).astype(dtype=dtype)
    else:
        ndata = None

    # Get the edge features, promoting 1D arrays to column vectors as above
    edge_features = get_mol_edge_features(mol, edge_property_list, mask_nan=mask_nan)
    edata = list(edge_features.values())
    edata = [np.expand_dims(d, axis=1) if d.ndim == 1 else d for d in edata]
    if len(edata) > 0:
        edata = np.concatenate(edata, axis=1).astype(dtype=dtype)
    else:
        edata = None

    # Get all positional encodings
    pe_dict = get_all_positional_encodings(adj, num_nodes, pos_encoding_as_features)

    # Mask the NaNs / infs in the positional encodings
    for pe_key, pe_val in pe_dict.items():
        pe_val = np.asarray(pe_val, dtype=dtype)
        pe_dict[pe_key] = _mask_nans_inf(mask_nan, pe_val, pe_key)

    return adj, ndata, edata, pe_dict, conf_dict
def mol_to_adjacency_matrix(
    mol: "dm.Mol",
    use_bonds_weights: bool = False,
    add_self_loop: bool = False,
    dtype: np.dtype = np.float32,
) -> coo_matrix:
    r"""
    Convert a molecule to a sparse adjacency matrix.

    Instead of using the Rdkit `GetAdjacencyMatrix()` method, this method
    uses the bond ordering from the molecule object, which is the same as
    the bond ordering in the bond features.

    Warning:
        Do not use `Tensor.coalesce()` on the returned adjacency matrix, as it
        will change the ordering of the bonds.

    Args:
        mol: A molecule in the form of a SMILES string or an RDKit molecule object.
        use_bonds_weights:
            If `True`, the adjacency matrix will contain the floating-point bond
            order as the value of the edge, such that single bonds are represented
            by 1, double bonds 2, triple 3, aromatic 1.5.
            If `False`, the adjacency matrix will contain `1` as the value of the edge.
        add_self_loop:
            If `True`, the adjacency matrix will contain a self-loop (value `1`)
            for each node, appended after all the bond entries.
        dtype:
            The numpy data type used to build the matrix
    Returns:
        adj:
            coo sparse adjacency matrix of the molecule
    """

    # Get the indices for the adjacency matrix, and the bond values.
    # Each bond is inserted twice, once per direction, to keep the matrix symmetric
    # while preserving the bond ordering of the molecule.
    adj_idx, adj_val = [], []
    for bond in mol.GetBonds():
        adj_idx.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
        adj_idx.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()])
        if use_bonds_weights:
            # Bugfix: indexing the `nmp.BOND_TYPES` list with the RDKit bond-type
            # enum (an int subclass: SINGLE=1 ... AROMATIC=12) returned wrong
            # values and raised IndexError for aromatic bonds. The fractional
            # bond order (1, 2, 3, 1.5) is what `GetBondTypeAsDouble` provides.
            val = bond.GetBondTypeAsDouble()
        else:
            val = 1
        adj_val.extend([val, val])

    num_atoms = mol.GetNumAtoms()
    if len(adj_val) > 0:  # ensure matrix is not empty
        idx = np.asarray(adj_idx, dtype=np.int64).T
        adj = coo_matrix(
            (np.asarray(adj_val), (idx[0], idx[1])),
            shape=(num_atoms, num_atoms),
            dtype=dtype,
        )
    else:
        # Special case for molecules without bonds (e.g. single atom)
        adj = coo_matrix(([], np.array([[], []])), shape=(num_atoms, num_atoms), dtype=dtype)

    # Add self loops by appending diagonal entries AFTER the bond entries, so the
    # bond ordering of the existing data is preserved.
    # Bugfix: `coo_matrix` does not support item assignment, so the previous
    # `adj[arange, arange] = 1` raised a TypeError whenever `add_self_loop=True`.
    if add_self_loop:
        diag = np.arange(num_atoms, dtype=np.int64)
        row = np.concatenate([np.asarray(adj.row, dtype=np.int64), diag])
        col = np.concatenate([np.asarray(adj.col, dtype=np.int64), diag])
        data = np.concatenate([adj.data, np.ones(num_atoms, dtype=adj.data.dtype)])
        adj = coo_matrix((data, (row, col)), shape=adj.shape, dtype=dtype)
    return adj
class GraphDict(dict):
    def __init__(
        self,
        dic: Dict,
    ):
        r"""
        Store the parameters required to initialize a `pyg.data.Data`, but
        as a dictionary to reduce memory consumption.

        Possible keys for the dictionary:

        - adj: A sparse coo array containing the adjacency matrix

        - data: A dictionary containing the different keys and numpy arrays
          associated to the (node, edge & graph) features. Its entries are
          flattened into the top level of this `GraphDict`.

        - dtype: The dtype for the floating data.

        - mask_nan:
            Deal with molecules that fail a part of the featurization.
            NaNs can happen when computing a property of a noble gas,
            or other properties that are not measured for specific atoms.

            - "raise": DEFAULT. Raise an error when there is a nan or inf in the featurization
            - "warn": Raise a warning when there is a nan or inf in the featurization
            - "None": Don't do anything
            - "Floating value": Replace nans or inf by the specified value
        """
        default_dic = {
            "dtype": np.float16,
            "mask_nan": "raise",
        }
        # NOTE: mutates the caller's `dic` by popping the "data" key
        data = dic.pop("data", {})
        default_dic.update(dic)
        # Flatten the nested feature dict into the top level of this dict
        default_dic.update(data)
        super().__init__(default_dic)

    @property
    def keys(self):
        # Shadows `dict.keys()`: returns a list instead of a view
        return list(super().keys())

    @property
    def values(self):
        # Shadows `dict.values()`: returns a list instead of a view.
        # Bugfix: was `list(super().self.values())`, which raised AttributeError.
        return list(super().values())

    def make_pyg_graph(self, **kwargs) -> "Data":
        """
        Convert the current dictionary of parameters, containing an adjacency matrix with node/edge data
        into a `pyg.data.Data` of torch Tensors.

        `**kwargs` can be used to overwrite any parameter from the current dictionary. See `GraphDict.__init__`
        for a list of parameters
        """
        num_nodes = self.adj.shape[0]
        data_dict = {}

        # Convert the numpy and numpy sparse data to torch
        for key, val in self.items():
            if key in ["adj", "dtype", "mask_nan"]:  # Skip the parameters
                continue
            elif isinstance(val, np.ndarray):
                # Convert the data to the specified dtype in torch format
                val = val.astype(self.dtype)
                data_dict[key] = torch.as_tensor(val)
            elif issparse(val):
                data_dict[key] = torch.as_tensor(val.astype(np.float32).todense())
                # `torch.sparse_coo_tensor` is too slow. Slows down the multiprocessing of features by >3x on 32 cores.
                # indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64))
                # data_dict[key] = torch.sparse_coo_tensor(indices=indices, values=val.data, size=val.shape)
            elif isinstance(val, torch.Tensor):
                data_dict[key] = val
            else:
                pass  # Skip the other parameters

        # Create the PyG graph object `Data`
        edge_index = torch.as_tensor(np.vstack((self.adj.row, self.adj.col)))
        edge_weight = torch.as_tensor(self.adj.data)
        data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, **data_dict)
        return data

    @property
    def adj(self):
        # The (sparse) adjacency matrix of the graph
        return self["adj"]

    @property
    def dtype(self):
        # Floating dtype used when converting arrays in `make_pyg_graph`
        return self["dtype"]

    @property
    def mask_nan(self):
        # NaN-handling policy; see `__init__`
        return self["mask_nan"]

    @property
    def num_nodes(self) -> int:
        # Number of nodes, inferred from the adjacency matrix shape
        return self.adj.shape[0]

    @property
    def num_edges(self) -> int:
        # Number of directed edges stored in the adjacency matrix
        if issparse(self.adj):
            return self.adj.nnz
        else:
            return np.count_nonzero(self.adj)  # No division by 2 because edges are counted twice
def mol_to_graph_dict(
    mol: "dm.Mol",
    atom_property_list_onehot: List[str] = [],
    atom_property_list_float: List[Union[str, Callable]] = [],
    conformer_property_list: List[str] = [],
    edge_property_list: List[str] = [],
    add_self_loop: bool = False,
    explicit_H: bool = False,
    use_bonds_weights: bool = False,
    pos_encoding_as_features: Dict[str, Any] = None,
    dtype: np.dtype = np.float16,
    on_error: str = "ignore",
    mask_nan: Union[str, float, type(None)] = "raise",
    max_num_atoms: Optional[int] = None,
) -> "Union[GraphDict, str]":
    r"""
    Transforms a molecule into an adjacency matrix representing the molecular graph
    and a set of atom and bond features, and re-organizes them into a dictionary
    that allows to build a `pyg.data.Data` object.

    Compared to `mol_to_pyggraph`, this function does not build the graph directly,
    and is thus faster, less memory heavy, and compatible with other frameworks.

    Parameters:

        mol:
            The molecule to be converted

        atom_property_list_onehot:
            List of the properties used to get one-hot encoding of the atom type,
            such as the atom index represented as a one-hot vector.
            See function `get_mol_atomic_features_onehot`

        atom_property_list_float:
            List of the properties used to get floating-point encoding of the atom type,
            such as the atomic mass or electronegativity.
            See function `get_mol_atomic_features_float`

        conformer_property_list:
            list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d"

        edge_property_list:
            List of the properties used to encode the edges, such as the edge type
            and the stereo type.

        add_self_loop:
            Whether to add a value of `1` on the diagonal of the adjacency matrix.

        explicit_H:
            Whether to consider the Hydrogens explicitely. If `False`, the hydrogens
            are implicit.

        use_bonds_weights:
            Whether to use the floating-point value of the bonds in the adjacency matrix,
            such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5

        pos_encoding_as_features: keyword arguments for function `graph_positional_encoder`
            to generate positional encoding for node features.

        dtype:
            The numpy data type used to build the graph

        on_error:
            What to do when the featurization fails. This can change the
            behavior of `mask_nan`.

            - "raise": Raise an error
            - "warn": Raise a warning and return a string of the error
            - "ignore": Ignore the error and return a string of the error

        mask_nan:
            Deal with molecules that fail a part of the featurization.
            NaNs can happen when computing a property of a noble gas,
            or other properties that are not measured for specific atoms.

            - "raise": DEFAULT. Raise an error when there is a nan or inf in the featurization
            - "warn": Raise a warning when there is a nan or inf in the featurization
            - "None": Don't do anything
            - "Floating value": Replace nans or inf by the specified value

        max_num_atoms:
            Maximum number of atoms for a given molecule. If a molecule with more atoms
            is given, an error is raised, but captured according to the rules of
            `on_error`.

    Returns:

        graph_dict:
            A dictionary `GraphDict` containing the keys required to build a graph,
            and which can be used to build a PyG graph. If it fails
            to featurize the molecule, it returns a string with the error.

            - "adj": A sparse int-array containing the adjacency matrix

            - "data": A dictionary containing different keys and numpy
              arrays associated to the (node, edge & graph) features.

            - "dtype": The numpy dtype for the floating data.
    """
    input_mol = mol
    try:
        if isinstance(mol, str):
            mol = dm.to_mol(mol, ordered=True)
        if explicit_H:
            mol = Chem.AddHs(mol)
        else:
            mol = Chem.RemoveHs(mol)
        num_atoms = mol.GetNumAtoms()
        if (max_num_atoms is not None) and (num_atoms > max_num_atoms):
            raise ValueError(f"Maximum number of atoms greater than permitted {num_atoms}>{max_num_atoms}")
        (
            adj,
            ndata,
            edata,
            pe_dict,
            conf_dict,
        ) = mol_to_adj_and_features(
            mol=mol,
            atom_property_list_onehot=atom_property_list_onehot,
            atom_property_list_float=atom_property_list_float,
            conformer_property_list=conformer_property_list,
            edge_property_list=edge_property_list,
            add_self_loop=add_self_loop,
            explicit_H=explicit_H,
            use_bonds_weights=use_bonds_weights,
            pos_encoding_as_features=pos_encoding_as_features,
            mask_nan=mask_nan,
        )
    except Exception as e:
        if on_error.lower() == "raise":
            raise e
        elif on_error.lower() == "warn":
            smiles = input_mol
            if isinstance(smiles, dm.Mol):
                smiles = Chem.MolToSmiles(input_mol)
            msg = str(e) + "\nIgnoring following molecule:" + smiles
            logger.warning(msg)
            return str(e)
        elif on_error.lower() == "ignore":
            return str(e)
        else:
            # Bugfix: an unrecognized `on_error` value used to fall through and
            # crash later with a confusing `NameError` on `adj`.
            raise ValueError(f"Unsupported `on_error` value: {on_error}") from e

    graph_dict = {"adj": adj, "data": {}, "dtype": dtype}

    # Assign the node data
    if ndata is not None:
        graph_dict["data"]["feat"] = ndata

    # Assign the edge data, duplicated once per edge direction
    if edata is not None:
        if issparse(edata):
            edata = to_dense_array(edata, dtype=dtype)
        hetero_edata = edata.repeat(2, axis=0)
        graph_dict["data"]["edge_feat"] = hetero_edata

    # Put the positional encodings as node features
    # TODO: add support for PE on edges
    for key, pe in pe_dict.items():
        graph_dict["data"][key] = pe

    # put the conformer positions here
    for key, val in conf_dict.items():
        graph_dict["data"][key] = val
    graph_dict = GraphDict(graph_dict)
    return graph_dict
def mol_to_pyggraph(
    mol: dm.Mol,
    atom_property_list_onehot: List[str] = [],
    atom_property_list_float: List[Union[str, Callable]] = [],
    conformer_property_list: List[str] = [],
    edge_property_list: List[str] = [],
    add_self_loop: bool = False,
    explicit_H: bool = False,
    use_bonds_weights: bool = False,
    pos_encoding_as_features: Dict[str, Any] = None,
    dtype: np.dtype = np.float16,
    on_error: str = "ignore",
    mask_nan: Union[str, float, type(None)] = "raise",
    max_num_atoms: Optional[int] = None,
) -> Union[Data, str]:
    r"""
    Convert a molecule into a `pyg.data.Data` graph of torch Tensors.

    The molecule is first featurized with `mol_to_graph_dict` (adjacency matrix
    plus node/edge features and positional encodings), then the resulting
    `GraphDict` is materialized into a PyG graph.

    Parameters:

        mol:
            The molecule to be converted

        atom_property_list_onehot:
            List of the properties used to get one-hot encoding of the atom type,
            such as the atom index represented as a one-hot vector.
            See function `get_mol_atomic_features_onehot`

        atom_property_list_float:
            List of the properties used to get floating-point encoding of the atom type,
            such as the atomic mass or electronegativity.
            See function `get_mol_atomic_features_float`

        conformer_property_list:
            list of properties used to encode the conformer information, outside of atom properties, currently support "positions_3d"

        edge_property_list:
            List of the properties used to encode the edges, such as the edge type
            and the stereo type.

        add_self_loop:
            Whether to add a value of `1` on the diagonal of the adjacency matrix.

        explicit_H:
            Whether to consider the Hydrogens explicitely. If `False`, the hydrogens
            are implicit.

        use_bonds_weights:
            Whether to use the floating-point value of the bonds in the adjacency matrix,
            such that single bonds are represented by 1, double bonds 2, triple 3, aromatic 1.5

        pos_encoding_as_features: keyword arguments for function `graph_positional_encoder`
            to generate positional encoding for node features.

        dtype:
            The numpy data type used to build the graph

        on_error:
            What to do when the featurization fails. This can change the
            behavior of `mask_nan`.

            - "raise": Raise an error
            - "warn": Raise a warning and return a string of the error
            - "ignore": Ignore the error and return a string of the error

        mask_nan:
            Deal with molecules that fail a part of the featurization.
            NaNs can happen when computing a property of a noble gas,
            or other properties that are not measured for specific atoms.

            - "raise": Raise an error when there is a nan in the featurization
            - "warn": Raise a warning when there is a nan in the featurization
            - "None": Don't do anything
            - "Floating value": Replace nans by the specified value

        max_num_atoms:
            Maximum number of atoms for a given molecule. If a molecule with more atoms
            is given, an error is raised, but captured according to the rules of
            `on_error`.

    Returns:

        graph:
            Pyg graph, with `graph['feat']` corresponding to the concatenated
            node data from `atom_property_list_onehot` and `atom_property_list_float`,
            `graph['edge_feat']` corresponding to the concatenated edge data from `edge_property_list`.
            There are also additional entries for the positional encodings.
    """
    dict_or_err = mol_to_graph_dict(
        mol=mol,
        atom_property_list_onehot=atom_property_list_onehot,
        atom_property_list_float=atom_property_list_float,
        conformer_property_list=conformer_property_list,
        edge_property_list=edge_property_list,
        add_self_loop=add_self_loop,
        explicit_H=explicit_H,
        use_bonds_weights=use_bonds_weights,
        pos_encoding_as_features=pos_encoding_as_features,
        dtype=dtype,
        on_error=on_error,
        mask_nan=mask_nan,
        max_num_atoms=max_num_atoms,
    )
    # Featurization failures are propagated as-is (error string, or `None`)
    if (dict_or_err is None) or isinstance(dict_or_err, str):
        return dict_or_err
    return dict_or_err.make_pyg_graph()
def mol_to_graph_signature(featurizer_args: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Build the full signature of the featurizer arguments actually used for the
    feature computation, by overlaying a provided dict of arguments on top of
    the default arguments of `mol_to_graph_dict`.

    Parameters:
        featurizer_args: A dictionary of featurizer arguments to update
    Returns:
        A dictionary of featurizer arguments
    """
    # Collect every `mol_to_graph_dict` parameter that carries a default value
    signature = inspect.signature(mol_to_graph_dict)
    parameters = {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # Overlay the user-supplied arguments on top of the defaults
    if featurizer_args is not None:
        parameters.update(featurizer_args)
    return parameters
================================================
FILE: graphium/features/graphormer.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Dict, Any
import numpy as np
import networkx as nx
from scipy.sparse import spmatrix, issparse
def compute_graphormer_distances(
    adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Dict[str, Any]
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """
    Compute the Graphormer shortest-path (hop-count) distance between node pairs.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects
    Returns:
        dist [num_nodes, num_nodes]: 2D array with Graphormer distances between nodepairs
        base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here nodepair
        cache: Updated dictionary of cached objects
    """
    base_level = "nodepair"

    if "graphormer" in cache:
        dist = cache["graphormer"]
    else:
        dense_adj = adj.toarray() if issparse(adj) else adj
        graph = nx.from_numpy_array(dense_adj)
        # One unweighted BFS per source node gives the hop count to every target.
        # NOTE(review): raises KeyError for disconnected graphs, since unreachable
        # targets have no shortest-path entry — same behavior as before.
        rows = []
        for src in range(num_nodes):
            hops = nx.single_source_shortest_path_length(graph, src)
            rows.append([hops[dst] for dst in range(num_nodes)])
        dist = np.asarray(rows)
        cache["graphormer"] = dist

    return dist, base_level, cache
================================================
FILE: graphium/features/nmp.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Optional, Dict, Union
import importlib.resources
from copy import deepcopy
import pandas as pd
import numpy as np
import math
from rdkit import Chem
# NOTE(hadim): usually it's best to embed this in a function.
# Load the periodic-table data bundled with the package once, at import time.
# NOTE(review): `importlib.resources.open_text` is deprecated since Python 3.11
# in favour of `importlib.resources.files(...)` — confirm minimum supported version.
with importlib.resources.open_text("graphium.features", "periodic_table.csv") as f:
    PERIODIC_TABLE = pd.read_csv(f)
# Index rows by atomic number so elements can be looked up directly by Z
PERIODIC_TABLE = PERIODIC_TABLE.set_index("AtomicNumber")
# Small function to convert strings to floats
def float_or_none(string: str) -> Union[float, None]:
    """
    Try to convert a value to ``float``; return ``None`` when the conversion fails.

    Used to parse the periodic-table CSV, where missing values appear as
    non-numeric placeholders such as ``"-"``.

    Parameters:
        string: the value to convert (a string, or any value accepted by ``float``)
    Returns:
        val: the converted float, or ``None`` if the value is not convertible
    """
    try:
        # ValueError: non-numeric string (e.g. "-"); TypeError: non-string/non-number.
        # Narrowed from a bare `except:` so that e.g. KeyboardInterrupt is not swallowed.
        val = float(string)
    except (ValueError, TypeError):
        val = None
    return val
# It's much faster to index from a list than a DataFrame, which can have big impact
# when featurizing millions of molecules
# Per-element properties as plain lists; list position corresponds to AtomicNumber - 1
# (the CSV rows are ordered by atomic number). Unparsable entries (e.g. "-") become None.
BOND_RADIUS_SINGLE = [float_or_none(elem) for elem in PERIODIC_TABLE["SingleBondRadius"]]
BOND_RADIUS_DOUBLE = [float_or_none(elem) for elem in PERIODIC_TABLE["DoubleBondRadius"]]
BOND_RADIUS_TRIPLE = [float_or_none(elem) for elem in PERIODIC_TABLE["TripleBondRadius"]]
ELECTRONEGATIVITY = [float_or_none(elem) for elem in PERIODIC_TABLE["Electronegativity"]]
FIRST_IONIZATION = [float_or_none(elem) for elem in PERIODIC_TABLE["FirstIonization"]]
MELTING_POINT = [float_or_none(elem) for elem in PERIODIC_TABLE["MeltingPoint"]]
# Encodes metallic character: 2 = metal, 1 = metalloid, 0 = neither
# (a row flagged as both would yield 3 — the CSV marks elements exclusively)
METAL = (2 * (PERIODIC_TABLE["Metal"] == "yes") + (PERIODIC_TABLE["Metalloid"] == "yes")).tolist()
# Phase at standard conditions (e.g. "gas", "solid", "liq") and its unique values
PHASE = list(PERIODIC_TABLE["Phase"].values)
PHASE_SET = list(set(PHASE))
TYPE = list(deepcopy(PERIODIC_TABLE["Type"]))
# Missing element types are NaN floats in pandas; replace them with the sentinel "none"
TYPE = ["none" if (isinstance(t, float) and math.isnan(t)) else t for t in TYPE]
TYPE_SET = list(set(TYPE))
GROUP = deepcopy(PERIODIC_TABLE["Group"].values)
# Elements without a group in the CSV (e.g. lanthanides such as Cerium) are bucketed as group 19
GROUP[np.isnan(GROUP)] = 19
GROUP_SET = list(set(GROUP))
PERIOD = list(PERIODIC_TABLE["Period"].values)
PERIOD_SET = list(set(PERIOD))
# Atom symbols considered by the featurizer — presumably the vocabulary for
# one-hot atom-type encoding (list order defines the encoding index); verify against callers.
ATOM_LIST = [
    "C",
    "N",
    "O",
    "S",
    "F",
    "Si",
    "P",
    "Cl",
    "Br",
    "Mg",
    "Na",
    "Ca",
    "Fe",
    "As",
    "Al",
    "I",
    "B",
    "V",
    "K",
    "Tl",
    "Yb",
    "Sb",
    "Sn",
    "Ag",
    "Pd",
    "Co",
    "Se",
    "Ti",
    "Zn",
    "H",
    "Li",
    "Ge",
    "Cu",
    "Au",
    "Ni",
    "Cd",
    "In",
    "Mn",
    "Zr",
    "Cr",
    "Pt",
    "Hg",
    "Pb",
]
# Supported hydrogen counts per atom
ATOM_NUM_H = [0, 1, 2, 3, 4]
# Supported valence values per atom
VALENCE = [0, 1, 2, 3, 4, 5, 6]
# Supported formal charges per atom
CHARGE_LIST = [-3, -2, -1, 0, 1, 2, 3]
# Supported numbers of radical electrons per atom
RADICAL_E_LIST = [0, 1, 2]
# Supported atom degrees (number of directly-bonded neighbors)
ATOM_DEGREE_LIST = [0, 1, 2, 3, 4]
# All RDKit hybridization names except "OTHER", in reverse lexicographic order of the names
HYBRIDIZATION_LIST = [
    Chem.rdchem.HybridizationType.names[k]
    for k in sorted(Chem.rdchem.HybridizationType.names.keys(), reverse=True)
    if k != "OTHER"
]
CHIRALITY_LIST = ["R"]  # alternative is just S
# Bond types used as the vocabulary for one-hot bond-type encoding
# (also the basis of the "bond-type-onehot" edge property)
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
# Bond stereochemistry values used as the vocabulary for the "stereo" edge property
BOND_STEREO = [
    Chem.rdchem.BondStereo.STEREONONE,
    Chem.rdchem.BondStereo.STEREOANY,
    Chem.rdchem.BondStereo.STEREOZ,
    Chem.rdchem.BondStereo.STEREOE,
    Chem.rdchem.BondStereo.STEREOCIS,
    Chem.rdchem.BondStereo.STEREOTRANS,
]
================================================
FILE: graphium/features/periodic_table.csv
================================================
AtomicNumber,Element,Symbol,AtomicMass,NumberofNeutrons,NumberofProtons,NumberofElectrons,Period,Group,Phase,Radioactive,Natural,Metal,Nonmetal,Metalloid,Type,AtomicRadius,Electronegativity,FirstIonization,Density,MeltingPoint,BoilingPoint,NumberOfIsotopes,Discoverer,Year,SpecificHeat,NumberofShells,NumberofValence,SingleBondRadius,DoubleBondRadius,TripleBondRadius
1,Hydrogen,H,1.007,0,1,1,1,1,gas,,yes,,yes,,Nonmetal,0.79,2.2,13.5984,8.99E-05,14.175,20.28,3,Cavendish,1766,14.304,1,1,0.32,-,-
2,Helium,He,4.002,2,2,2,1,18,gas,,yes,,yes,,Noble Gas,0.49,,24.5874,1.79E-04,,4.22,5,Janssen,1868,5.193,1,,0.46,-,-
3,Lithium,Li,6.941,4,3,3,2,1,solid,,yes,yes,,,Alkali Metal,2.1,0.98,5.3917,5.34E-01,453.85,1615,5,Arfvedson,1817,3.582,2,1,1.33,1.24,-
4,Beryllium,Be,9.012,5,4,4,2,2,solid,,yes,yes,,,Alkaline Earth Metal,1.4,1.57,9.3227,1.85E+00,1560.15,2742,6,Vaulquelin,1798,1.825,2,2,1.02,0.9,0.85
5,Boron,B,10.811,6,5,5,2,13,solid,,yes,,,yes,Metalloid,1.2,2.04,8.298,2.34E+00,2573.15,4200,6,Gay-Lussac,1808,1.026,2,3,0.85,0.78,0.73
6,Carbon,C,12.011,6,6,6,2,14,solid,,yes,,yes,,Nonmetal,0.91,2.55,11.2603,2.27E+00,3948.15,4300,7,Prehistoric,,0.709,2,4,0.75,0.67,0.6
7,Nitrogen,N,14.007,7,7,7,2,15,gas,,yes,,yes,,Nonmetal,0.75,3.04,14.5341,1.25E-03,63.29,77.36,8,Rutherford,1772,1.04,2,5,0.71,0.6,0.54
8,Oxygen,O,15.999,8,8,8,2,16,gas,,yes,,yes,,Nonmetal,0.65,3.44,13.6181,1.43E-03,50.5,90.2,8,Priestley/Scheele,1774,0.918,2,6,0.63,0.57,0.53
9,Fluorine,F,18.998,10,9,9,2,17,gas,,yes,,yes,,Halogen,0.57,3.98,17.4228,1.70E-03,53.63,85.03,6,Moissan,1886,0.824,2,7,0.64,0.59,0.53
10,Neon,Ne,20.18,10,10,10,2,18,gas,,yes,,yes,,Noble Gas,0.51,,21.5645,9.00E-04,24.703,27.07,8,Ramsay and Travers,1898,1.03,2,8,0.67,0.96,-
11,Sodium,Na,22.99,12,11,11,3,1,solid,,yes,yes,,,Alkali Metal,2.2,0.93,5.1391,9.71E-01,371.15,1156,7,Davy,1807,1.228,3,1,1.55,1.6,-
12,Magnesium,Mg,24.305,12,12,12,3,2,solid,,yes,yes,,,Alkaline Earth Metal,1.7,1.31,7.6462,1.74E+00,923.15,1363,8,Black,1755,1.023,3,2,1.39,1.32,1.27
13,Aluminum,Al,26.982,14,13,13,3,13,solid,,yes,yes,,,Metal,1.8,1.61,5.9858,2.70E+00,933.4,2792,8,Wshler,1827,0.897,3,3,1.26,1.13,1.11
14,Silicon,Si,28.086,14,14,14,3,14,solid,,yes,,,yes,Metalloid,1.5,1.9,8.1517,2.33E+00,1683.15,3538,8,Berzelius,1824,0.705,3,4,1.16,1.07,1.02
15,Phosphorus,P,30.974,16,15,15,3,15,solid,,yes,,yes,,Nonmetal,1.2,2.19,10.4867,1.82E+00,317.25,553,7,BranBrand,1669,0.769,3,5,1.11,1.02,0.94
16,Sulfur,S,32.065,16,16,16,3,16,solid,,yes,,yes,,Nonmetal,1.1,2.58,10.36,2.07E+00,388.51,717.8,10,Prehistoric,,0.71,3,6,1.03,0.94,0.95
17,Chlorine,Cl,35.453,18,17,17,3,17,gas,,yes,,yes,,Halogen,0.97,3.16,12.9676,3.21E-03,172.31,239.11,11,Scheele,1774,0.479,3,7,0.99,0.95,0.93
18,Argon,Ar,39.948,22,18,18,3,18,gas,,yes,,yes,,Noble Gas,0.88,,15.7596,1.78E-03,83.96,87.3,8,Rayleigh and Ramsay,1894,0.52,3,8,0.96,1.07,0.96
19,Potassium,K,39.098,20,19,19,4,1,solid,,yes,yes,,,Alkali Metal,2.8,0.82,4.3407,8.62E-01,336.5,1032,10,Davy,1807,0.757,4,1,1.96,1.93,-
20,Calcium,Ca,40.078,20,20,20,4,2,solid,,yes,yes,,,Alkaline Earth Metal,2.2,1,6.1132,1.54E+00,1112.15,1757,14,Davy,1808,0.647,4,2,1.71,1.47,1.33
21,Scandium,Sc,44.956,24,21,21,4,3,solid,,yes,yes,,,Transition Metal,2.1,1.36,6.5615,2.99E+00,1812.15,3109,15,Nilson,1878,0.568,4,,1.48,1.16,1.14
22,Titanium,Ti,47.867,26,22,22,4,4,solid,,yes,yes,,,Transition Metal,2,1.54,6.8281,4.54E+00,1933.15,3560,9,Gregor,1791,0.523,4,,1.36,1.17,1.08
23,Vanadium,V,50.942,28,23,23,4,5,solid,,yes,yes,,,Transition Metal,1.9,1.63,6.7462,6.11E+00,2175.15,3680,9, del Rio,1801,0.489,4,,1.34,1.12,1.06
24,Chromium,Cr,51.996,28,24,24,4,6,solid,,yes,yes,,,Transition Metal,1.9,1.66,6.7665,7.15E+00,2130.15,2944,9,Vauquelin,1797,0.449,4,,1.22,1.11,1.03
25,Manganese,Mn,54.938,30,25,25,4,7,solid,,yes,yes,,,Transition Metal,1.8,1.55,7.434,7.44E+00,1519.15,2334,11,"Gahn, Scheele",1774,0.479,4,,1.19,1.05,1.03
26,Iron,Fe,55.845,30,26,26,4,8,solid,,yes,yes,,,Transition Metal,1.7,1.83,7.9024,7.87E+00,1808.15,3134,10,Prehistoric,,0.449,4,,1.16,1.09,1.02
27,Cobalt,Co,58.933,32,27,27,4,9,solid,,yes,yes,,,Transition Metal,1.7,1.88,7.881,8.86E+00,1768.15,3200,14,Brandt,1735,0.421,4,,1.11,1.03,0.96
28,Nickel,Ni,58.693,31,28,28,4,10,solid,,yes,yes,,,Transition Metal,1.6,1.91,7.6398,8.91E+00,1726.15,3186,11,Cronstedt,1751,0.444,4,,1.1,1.01,1.01
29,Copper,Cu,63.546,35,29,29,4,11,solid,,yes,yes,,,Transition Metal,1.6,1.9,7.7264,8.96E+00,1357.75,2835,11,Prehistoric,,0.385,4,,1.12,1.15,1.2
30,Zinc,Zn,65.38,35,30,30,4,12,solid,,yes,yes,,,Transition Metal,1.5,1.65,9.3942,7.13E+00,692.88,1180,15,Prehistoric,,0.388,4,,1.18,1.2,-
31,Gallium,Ga,69.723,39,31,31,4,13,solid,,yes,yes,,,Metal,1.8,1.81,5.9993,5.91E+00,302.91,2477,14,de Boisbaudran,1875,0.371,4,3,1.24,1.17,1.21
32,Germanium,Ge,72.64,41,32,32,4,14,solid,,yes,,,yes,Metalloid,1.5,2.01,7.8994,5.32E+00,1211.45,3106,17,Winkler,1886,0.32,4,4,1.21,1.11,1.14
33,Arsenic,As,74.922,42,33,33,4,15,solid,,yes,,,yes,Metalloid,1.3,2.18,9.7886,5.78E+00,1090.15,887,14,Albertus Magnus,1250,0.329,4,5,1.21,1.14,1.06
34,Selenium,Se,78.96,45,34,34,4,16,solid,,yes,,yes,,Nonmetal,1.2,2.55,9.7524,4.81E+00,494.15,958,20,Berzelius,1817,0.321,4,6,1.16,1.07,1.07
35,Bromine,Br,79.904,45,35,35,4,17,liq,,yes,,yes,,Halogen,1.1,2.96,11.8138,3.12E+00,266.05,332,19,Balard,1826,0.474,4,7,1.14,1.09,1.1
36,Krypton,Kr,83.798,48,36,36,4,18,gas,,yes,,yes,,Noble Gas,1,,13.9996,3.73E-03,115.93,119.93,23,Ramsay and Travers,1898,0.248,4,8,1.17,1.21,1.08
37,Rubidium,Rb,85.468,48,37,37,5,1,solid,,yes,yes,,,Alkali Metal,3,0.82,4.1771,1.53E+00,312.79,961,20,Bunsen and Kirchoff,1861,0.363,5,1,2.1,2.02,-
38,Strontium,Sr,87.62,50,38,38,5,2,solid,,yes,yes,,,Alkaline Earth Metal,2.5,0.95,5.6949,2.64E+00,1042.15,1655,18,Davy,1808,0.301,5,2,1.85,1.57,1.39
39,Yttrium,Y,88.906,50,39,39,5,3,solid,,yes,yes,,,Transition Metal,2.3,1.22,6.2173,4.47E+00,1799.15,3609,21,Gadolin,1794,0.298,5,,1.63,1.3,1.24
40,Zirconium,Zr,91.224,51,40,40,5,4,solid,,yes,yes,,,Transition Metal,2.2,1.33,6.6339,6.51E+00,2125.15,4682,20,Klaproth,1789,0.278,5,,1.54,1.27,1.21
41,Niobium,Nb,92.906,52,41,41,5,5,solid,,yes,yes,,,Transition Metal,2.1,1.6,6.7589,8.57E+00,2741.15,5017,24,Hatchett,1801,0.265,5,,1.47,1.25,1.16
42,Molybdenum,Mo,95.96,54,42,42,5,6,solid,,yes,yes,,,Transition Metal,2,2.16,7.0924,1.02E+01,2890.15,4912,20,Scheele,1778,0.251,5,,1.38,1.21,1.13
43,Technetium,Tc,98,55,43,43,5,7,artificial,yes,,yes,,,Transition Metal,2,1.9,7.28,1.15E+01,2473.15,5150,23,Perrier and Segr?,1937,,5,,1.28,1.2,1.1
44,Ruthenium,Ru,101.07,57,44,44,5,8,solid,,yes,yes,,,Transition Metal,1.9,2.2,7.3605,1.24E+01,2523.15,4423,16,Klaus,1844,0.238,5,,1.25,1.14,1.03
45,Rhodium,Rh,102.906,58,45,45,5,9,solid,,yes,yes,,,Transition Metal,1.8,2.28,7.4589,1.24E+01,2239.15,3968,20,Wollaston,1803,0.243,5,,1.25,1.1,1.06
46,Palladium,Pd,106.42,60,46,46,5,10,solid,,yes,yes,,,Transition Metal,1.8,2.2,8.3369,1.20E+01,1825.15,3236,21,Wollaston,1803,0.244,5,,1.2,1.17,1.12
47,Silver,Ag,107.868,61,47,47,5,11,solid,,yes,yes,,,Transition Metal,1.8,1.93,7.5762,1.05E+01,1234.15,2435,27,Prehistoric,,0.235,5,,1.28,1.39,1.37
48,Cadmium,Cd,112.411,64,48,48,5,12,solid,,yes,yes,,,Transition Metal,1.7,1.69,8.9938,8.69E+00,594.33,1040,22,Stromeyer,1817,0.232,5,,1.36,1.44,-
49,Indium,In,114.818,66,49,49,5,13,solid,,yes,yes,,,Metal,2,1.78,5.7864,7.31E+00,429.91,2345,34,Reich and Richter,1863,0.233,5,3,1.42,1.36,1.46
50,Tin,Sn,118.71,69,50,50,5,14,solid,,yes,yes,,,Metal,1.7,1.96,7.3439,7.29E+00,505.21,2875,28,Prehistoric,,0.228,5,4,1.4,1.3,1.32
51,Antimony,Sb,121.76,71,51,51,5,15,solid,,yes,,,yes,Metalloid,1.5,2.05,8.6084,6.69E+00,904.05,1860,29,Early historic times,,0.207,5,5,1.4,1.33,1.27
52,Tellurium,Te,127.6,76,52,52,5,16,solid,,yes,,,yes,Metalloid,1.4,2.1,9.0096,6.23E+00,722.8,1261,29,von Reichenstein,1782,0.202,5,6,1.36,1.28,1.21
53,Iodine,I,126.904,74,53,53,5,17,solid,,yes,,yes,,Halogen,1.3,2.66,10.4513,4.93E+00,386.65,457.4,24,Courtois,1811,0.214,5,7,1.33,1.29,1.25
54,Xenon,Xe,131.293,77,54,54,5,18,gas,,yes,,yes,,Noble Gas,1.2,,12.1298,5.89E-03,161.45,165.03,31,Ramsay and Travers,1898,0.158,5,8,1.31,1.35,1.22
55,Cesium,Cs,132.905,78,55,55,6,1,solid,,yes,yes,,,Alkali Metal,3.3,0.79,3.8939,1.87E+00,301.7,944,22,Bunsen and Kirchoff,1860,0.242,6,1,2.32,2.09,-
56,Barium,Ba,137.327,81,56,56,6,2,solid,,yes,yes,,,Alkaline Earth Metal,2.8,0.89,5.2117,3.59E+00,1002.15,2170,25,Davy,1808,0.204,6,2,1.96,1.61,1.49
57,Lanthanum,La,138.905,82,57,57,6,3,solid,,yes,yes,,,Lanthanide,2.7,1.1,5.5769,6.15E+00,1193.15,3737,19,Mosander,1839,0.195,6,,1.8,1.39,1.39
58,Cerium,Ce,140.116,82,58,58,6,,solid,,yes,yes,,,Lanthanide,2.7,1.12,5.5387,6.77E+00,1071.15,3716,19,Berzelius,1803,0.192,6,,1.63,1.37,1.31
59,Praseodymium,Pr,140.908,82,59,59,6,,solid,,yes,yes,,,Lanthanide,2.7,1.13,5.473,6.77E+00,1204.15,3793,15,von Welsbach,1885,0.193,6,,1.76,1.38,1.28
60,Neodymium,Nd,144.242,84,60,60,6,,solid,,yes,yes,,,Lanthanide,2.6,1.14,5.525,7.01E+00,1289.15,3347,16,von Welsbach,1885,0.19,6,,1.74,1.37,-
61,Promethium,Pm,145,84,61,61,6,,artificial,yes,,yes,,,Lanthanide,2.6,1.13,5.582,7.26E+00,1204.15,3273,14,Marinsky et al.,1945,,6,,1.73,1.35,-
62,Samarium,Sm,150.36,88,62,62,6,,solid,,yes,yes,,,Lanthanide,2.6,1.17,5.6437,7.52E+00,1345.15,2067,17,Boisbaudran,1879,0.197,6,,1.72,1.34,-
63,Europium,Eu,151.964,89,63,63,6,,solid,,yes,yes,,,Lanthanide,2.6,1.2,5.6704,5.24E+00,1095.15,1802,21,Demarcay,1901,0.182,6,,1.68,1.34,-
64,Gadolinium,Gd,157.25,93,64,64,6,,solid,,yes,yes,,,Lanthanide,2.5,1.2,6.1501,7.90E+00,1585.15,3546,17,de Marignac,1880,0.236,6,,1.69,1.35,1.32
65,Terbium,Tb,158.925,94,65,65,6,,solid,,yes,yes,,,Lanthanide,2.5,1.2,5.8638,8.23E+00,1630.15,3503,24,Mosander,1843,0.182,6,,1.68,1.35,-
66,Dysprosium,Dy,162.5,97,66,66,6,,solid,,yes,yes,,,Lanthanide,2.5,1.22,5.9389,8.55E+00,1680.15,2840,21,de Boisbaudran,1886,0.17,6,,1.67,1.33,-
67,Holmium,Ho,164.93,98,67,67,6,,solid,,yes,yes,,,Lanthanide,2.5,1.23,6.0215,8.80E+00,1743.15,2993,29,Delafontaine and Soret,1878,0.165,6,,1.66,1.33,-
68,Erbium,Er,167.259,99,68,68,6,,solid,,yes,yes,,,Lanthanide,2.5,1.24,6.1077,9.07E+00,1795.15,3503,16,Mosander,1843,0.168,6,,1.65,1.33,-
69,Thulium,Tm,168.934,100,69,69,6,,solid,,yes,yes,,,Lanthanide,2.4,1.25,6.1843,9.32E+00,1818.15,2223,18,Cleve,1879,0.16,6,,1.64,1.31,-
70,Ytterbium,Yb,173.054,103,70,70,6,,solid,,yes,yes,,,Lanthanide,2.4,1.1,6.2542,6.97E+00,1097.15,1469,16,Marignac,1878,0.155,6,,1.7,1.29,-
71,Lutetium,Lu,174.967,104,71,71,6,,solid,,yes,yes,,,Lanthanide,2.3,1.27,5.4259,9.84E+00,1936.15,3675,22,Urbain/ von Welsbach,1907,0.154,6,,1.62,1.31,1.31
72,Hafnium,Hf,178.49,106,72,72,6,4,solid,,yes,yes,,,Transition Metal,2.2,1.3,6.8251,1.33E+01,2500.15,4876,17,Coster and von Hevesy,1923,0.144,6,,1.52,1.28,1.22
73,Tantalum,Ta,180.948,108,73,73,6,5,solid,,yes,yes,,,Transition Metal,2.1,1.5,7.5496,1.67E+01,3269.15,5731,19,Ekeberg,1801,0.14,6,,1.46,1.26,1.19
74,Wolfram,W,183.84,110,74,74,6,6,solid,,yes,yes,,,Transition Metal,2,2.36,7.864,1.93E+01,3680.15,5828,22,J. and F. d'Elhuyar,1783,0.132,6,,1.37,1.2,1.15
75,Rhenium,Re,186.207,111,75,75,6,7,solid,,yes,yes,,,Transition Metal,2,1.9,7.8335,2.10E+01,3453.15,5869,21,"Noddack, Berg, and Tacke",1925,0.137,6,,1.31,1.19,1.1
76,Osmium,Os,190.23,114,76,76,6,8,solid,,yes,yes,,,Transition Metal,1.9,2.2,8.4382,2.26E+01,3300.15,5285,19,Tennant,1803,0.13,6,,1.29,1.16,1.09
77,Iridium,Ir,192.217,115,77,77,6,9,solid,,yes,yes,,,Transition Metal,1.9,2.2,8.967,2.26E+01,2716.15,4701,25,Tennant,1804,0.131,6,,1.22,1.15,1.07
78,Platinum,Pt,195.084,117,78,78,6,10,solid,,yes,yes,,,Transition Metal,1.8,2.28,8.9587,2.15E+01,2045.15,4098,32,Ulloa/Wood,1735,0.133,6,,1.23,1.12,1.1
79,Gold,Au,196.967,118,79,79,6,11,solid,,yes,yes,,,Transition Metal,1.8,2.54,9.2255,1.93E+01,1337.73,3129,21,Prehistoric,,0.129,6,,1.24,1.21,1.23
80,Mercury,Hg,200.59,121,80,80,6,12,liq,,yes,yes,,,Transition Metal,1.8,2,10.4375,1.35E+01,234.43,630,26,Prehistoric,,0.14,6,,1.33,1.42,-
81,Thallium,Tl,204.383,123,81,81,6,13,solid,,yes,yes,,,Metal,2.1,2.04,6.1082,1.19E+01,577.15,1746,28,Crookes,1861,0.129,6,3,1.44,1.42,1.5
82,Lead,Pb,207.2,125,82,82,6,14,solid,,yes,yes,,,Metal,1.8,2.33,7.4167,1.13E+01,600.75,2022,29,Prehistoric,,0.129,6,4,1.44,1.35,1.37
83,Bismuth,Bi,208.98,126,83,83,6,15,solid,,yes,yes,,,Metal,1.6,2.02,7.2856,9.81E+00,544.67,1837,19,Geoffroy the Younger,1753,0.122,6,5,1.51,1.41,1.35
84,Polonium,Po,210,126,84,84,6,16,solid,yes,yes,,,yes,Metalloid,1.5,2,8.417,9.32E+00,527.15,1235,34,Curie,1898,,6,6,1.45,1.35,1.29
85,Astatine,At,210,125,85,85,6,17,solid,yes,yes,,yes,,Noble Gas,1.4,2.2,9.3,7.00E+00,575.15,610,21,Corson et al.,1940,,6,7,1.47,1.38,1.38
86,Radon,Rn,222,136,86,86,6,18,gas,yes,yes,yes,,,Alkali Metal,1.3,,10.7485,9.73E-03,202.15,211.3,20,Dorn,1900,0.094,6,8,1.42,1.45,1.33
87,Francium,Fr,223,136,87,87,7,1,solid,yes,yes,yes,,,Alkaline Earth Metal,,0.7,4.0727,1.87E+00,300.15,950,21,Perey,1939,,7,1,2.23,2.18,-
88,Radium,Ra,226,138,88,88,7,2,solid,yes,yes,yes,,,Actinide,,0.9,5.2784,5.50E+00,973.15,2010,15,Pierre and Marie Curie,1898,,7,2,2.01,1.73,1.59
89,Actinium,Ac,227,138,89,89,7,3,solid,yes,yes,yes,,,Actinide,,1.1,5.17,1.01E+01,1323.15,3471,11,Debierne/Giesel,1899,0.12,7,,1.86,1.53,1.4
90,Thorium,Th,232.038,142,90,90,7,,solid,yes,yes,yes,,,Actinide,,1.3,6.3067,1.17E+01,2028.15,5061,12,Berzelius,1828,0.113,7,,1.75,1.43,1.36
91,Protactinium,Pa,231.036,140,91,91,7,,solid,yes,yes,yes,,,Actinide,,1.5,5.89,1.54E+01,1873.15,4300,14,Hahn and Meitner,1917,,7,,1.69,1.38,1.29
92,Uranium,U,238.029,146,92,92,7,,solid,yes,yes,yes,,,Actinide,,1.38,6.1941,1.90E+01,1405.15,4404,15,Peligot,1841,0.116,7,,1.7,1.34,1.18
93,Neptunium,Np,237,144,93,93,7,,artificial,yes,,yes,,,Actinide,,1.36,6.2657,2.05E+01,913.15,4273,153,McMillan and Abelson,1940,,7,,1.71,1.36,1.16
94,Plutonium,Pu,244,150,94,94,7,,artificial,yes,,yes,,,Actinide,,1.28,6.0262,1.98E+01,913.15,3501,163,Seaborg et al.,1940,,7,,1.72,1.35,-
95,Americium,Am,243,148,95,95,7,,artificial,yes,,yes,,,Actinide,,1.3,5.9738,1.37E+01,1267.15,2880,133,Seaborg et al.,1944,,7,,1.66,1.35,-
96,Curium,Cm,247,151,96,96,7,,artificial,yes,,yes,,,Actinide,,1.3,5.9915,1.35E+01,1340.15,3383,133,Seaborg et al.,1944,,7,,1.66,1.36,-
97,Berkelium,Bk,247,150,97,97,7,,artificial,yes,,yes,,,Actinide,,1.3,6.1979,1.48E+01,1259.15,983,83,Seaborg et al.,1949,,7,,1.68,1.39,-
98,Californium,Cf,251,153,98,98,7,,artificial,yes,,yes,,,Actinide,,1.3,6.2817,1.51E+01,1925.15,1173,123,Seaborg et al.,1950,,7,,1.68,1.4,-
99,Einsteinium,Es,252,153,99,99,7,,artificial,yes,,yes,,,Actinide,,1.3,6.42,1.35E+01,1133.15,,123,Ghiorso et al.,1952,,7,,1.65,1.4,-
100,Fermium,Fm,257,157,100,100,7,,artificial,yes,,yes,,,Actinide,,1.3,6.5,,,,103,Ghiorso et al.,1953,,7,,1.67,-,-
101,Mendelevium,Md,258,157,101,101,7,,artificial,yes,,yes,,,Actinide,,1.3,6.58,,,,33,Ghiorso et al.,1955,,7,,1.73,1.39,-
102,Nobelium,No,259,157,102,102,7,,artificial,yes,,yes,,,Actinide,,1.3,6.65,,,,73,Ghiorso et al.,1958,,7,,1.76,-,-
103,Lawrencium,Lr,262,159,103,103,7,,artificial,yes,,yes,,,Actinide,,,,,,,203,Ghiorso et al.,1961,,7,,1.61,1.41,-
104,Rutherfordium,Rf,261,157,104,104,7,4,artificial,yes,,yes,,,Transactinide,,,,1.81E+01,,,,Ghiorso et al.,1969,,7,,1.57,1.4,1.31
105,Dubnium,Db,262,157,105,105,7,5,artificial,yes,,yes,,,Transactinide,,,,3.90E+01,,,,Ghiorso et al.,1970,,7,,1.49,1.36,1.26
106,Seaborgium,Sg,266,160,106,106,7,6,artificial,yes,,yes,,,Transactinide,,,,3.50E+01,,,,Ghiorso et al.,1974,,7,,1.43,1.28,1.21
107,Bohrium,Bh,264,157,107,107,7,7,artificial,yes,,yes,,,Transactinide,,,,3.70E+01,,,,Armbruster and Münzenberg,1981,,7,,1.41,1.28,1.19
108,Hassium,Hs,267,159,108,108,7,8,artificial,yes,,yes,,,Transactinide,,,,4.10E+01,,,,Armbruster and Münzenberg,1983,,7,,1.34,1.25,1.18
109,Meitnerium,Mt,268,159,109,109,7,9,artificial,yes,,yes,,,Transactinide,,,,3.50E+01,,,,"GSI, Darmstadt, West Germany",1982,,7,,1.29,1.25,1.13
110,Darmstadtium ,Ds ,271,161,110,110,7,10,artificial,yes,,yes,,,Transactinide,,,,,,,,,1994,,7,,1.28,1.16,1.12
111,Roentgenium ,Rg ,272,161,111,111,7,11,artificial,yes,,yes,,,Transactinide,,,,,,,,,1994,,7,,1.21,1.16,1.18
112,Copernicium ,Cn ,285,173,112,112,7,12,artificial,yes,,yes,,,Transactinide,,,,,,,,,1996,,7,,1.22,1.37,1.3
113,Nihonium,Nh,284,171,113,113,7,13,artificial,yes,,yes,,,,,,,,,,,,2004,,7,3,1.36,-,-
114,Flerovium,Fl,289,175,114,114,7,14,artificial,yes,,yes,,,Transactinide,,,,,,,,,1999,,7,4,1.43,-,-
115,Moscovium,Mc,288,173,115,115,7,15,artificial,yes,,yes,,,,,,,,,,,,2010,,7,5,1.62,-,-
116,Livermorium,Lv,292,176,116,116,7,16,artificial,yes,,yes,,,Transactinide,,,,,,,,,2000,,7,6,1.75,-,-
117,Tennessine,Ts,295,178,117,117,7,17,artificial,yes,,,yes,,,,,,,,,,,2010,,7,7,1.65,-,-
118,Oganesson,Og,294,176,118,118,7,18,artificial,yes,,,yes,,Noble Gas,,,,,,,,,2006,,7,8,1.57,-,-
================================================
FILE: graphium/features/positional_encoding.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Optional, Dict, Any, OrderedDict
from copy import deepcopy
import numpy as np
import torch
from scipy.sparse import spmatrix
from collections import OrderedDict as OderedDictClass
from graphium.features.spectral import compute_laplacian_pe
from graphium.features.rw import compute_rwse
from graphium.features.electrostatic import compute_electrostatic_interactions
from graphium.features.commute import compute_commute_distances
from graphium.features.graphormer import compute_graphormer_distances
from graphium.features.transfer_pos_level import transfer_pos_level
def get_all_positional_encodings(
    adj: Union[np.ndarray, spmatrix],
    num_nodes: int,
    pos_kwargs: Optional[Dict] = None,
) -> Tuple["OrderedDict[str, np.ndarray]"]:
    r"""
    Compute every positional/structural encoding requested in ``pos_kwargs``.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        num_nodes: Number of nodes in the graph
        pos_kwargs: Dictionary with a ``"pos_types"`` key mapping each encoding name
            to the keyword arguments passed to `graph_positional_encoder`.
    Returns:
        pe_dict: Ordered dictionary of positional and structural encodings, keyed by
            ``pos_type`` for node-level encodings and ``{pos_level}_{pos_type}`` otherwise.
    """
    if pos_kwargs is None:
        pos_kwargs = {}

    pe_dict = OderedDictClass()

    # Cache of intermediate objects (Laplacians, eigendecompositions, ...)
    # shared across the different encodings of this graph.
    cache = {}

    if len(pos_kwargs) > 0:
        for _, encoding_kwargs in pos_kwargs["pos_types"].items():
            # Pop type/level from a copy so the caller's dict is left untouched.
            encoding_kwargs = deepcopy(encoding_kwargs)
            pos_type = encoding_kwargs.pop("pos_type", None)
            pos_level = encoding_kwargs.pop("pos_level", None)

            this_pe, cache = graph_positional_encoder(
                deepcopy(adj),
                num_nodes,
                pos_type=pos_type,
                pos_level=pos_level,
                pos_kwargs=encoding_kwargs,
                cache=cache,
            )

            key = f"{pos_type}" if pos_level == "node" else f"{pos_level}_{pos_type}"
            pe_dict[key] = this_pe

    return pe_dict
def graph_positional_encoder(
    adj: Union[np.ndarray, spmatrix],
    num_nodes: int,
    pos_type: Optional[str] = None,
    pos_level: Optional[str] = None,
    pos_kwargs: Optional[Dict[str, Any]] = None,
    cache: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
    r"""
    Compute a single positional/structural encoding selected by ``pos_type``.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        num_nodes: Number of nodes in the graph
        pos_type: The type of positional encoding to use. If None, it must be provided
            by `pos_kwargs["pos_type"]`. Supported types (grouped by shared cache entries):
            - laplacian_eigvec / laplacian_eigval -> cache connected comps. & eigendecomp.
            - rw_return_probs / rw_transition_probs -> cache random-walk powers
            - electrostatic / commute -> cache pinvL
            - graphormer
        pos_level: Positional level to output. If None, it must be provided by
            `pos_kwargs["pos_level"]`. One of: node, edge, nodepair, graph.
        pos_kwargs: Extra keyword arguments for the positional encoding.
            Can include the keys pos_type and pos_level.
        cache: Dictionary of cached objects
    Returns:
        pe: Positional or structural encoding
        cache: Updated dictionary of cached objects
    """
    pos_kwargs = deepcopy(pos_kwargs)
    if pos_kwargs is None:
        pos_kwargs = {}
    if cache is None:
        cache = {}

    # `pos_type` may come from the explicit argument or from the kwargs;
    # when both are provided they must agree.
    kw_pos_type = pos_kwargs.pop("pos_type", None)
    pos_type = kw_pos_type if pos_type is None else pos_type
    if kw_pos_type is not None:
        assert (
            pos_type == kw_pos_type
        ), f"The positional type must be the same in `pos_type` and `pos_kwargs['pos_type']`. Provided: {pos_type} and {kw_pos_type}"
    assert pos_type is not None, "Either `pos_type` or `pos_kwargs['pos_type']` must be provided."

    # Same resolution logic for `pos_level`.
    kw_pos_level = pos_kwargs.pop("pos_level", None)
    pos_level = kw_pos_level if pos_level is None else pos_level
    if kw_pos_level is not None:
        assert (
            pos_level == kw_pos_level
        ), f"The positional level must be the same in `pos_level` and `pos_kwargs['pos_level']`. Provided: {pos_level} and {kw_pos_level}"
    assert pos_level is not None, "Either `pos_level` or `pos_kwargs['pos_level']` must be provided."

    # The encoding functions below expect a dense float64 numpy adjacency.
    if isinstance(adj, torch.sparse.Tensor):
        adj = adj.to_dense().numpy()
    elif isinstance(adj, torch.Tensor):
        adj = adj.numpy()
    adj = adj.astype(np.float64)

    # Dispatch to the requested encoding; each returns its natural ("base") level.
    if pos_type == "laplacian_eigvec":
        _, pe, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs)
    elif pos_type == "laplacian_eigval":
        pe, _, base_level, cache = compute_laplacian_pe(adj, cache=cache, **pos_kwargs)
    elif pos_type in ("rw_return_probs", "rw_transition_probs"):
        pe, base_level, cache = compute_rwse(
            adj.astype(np.float32), num_nodes=num_nodes, cache=cache, pos_type=pos_type, **pos_kwargs
        )
    elif pos_type == "electrostatic":
        pe, base_level, cache = compute_electrostatic_interactions(adj, cache, **pos_kwargs)
    elif pos_type == "commute":
        pe, base_level, cache = compute_commute_distances(adj, num_nodes, cache, **pos_kwargs)
    elif pos_type == "graphormer":
        pe, base_level, cache = compute_graphormer_distances(adj, num_nodes, cache, **pos_kwargs)
    else:
        raise ValueError(f"Unknown `pos_type`: {pos_type}")

    # Cast to float32 and convert from the base level to the requested level.
    if isinstance(pe, (list, tuple)):
        pe = [single_pe.astype(np.float32) for single_pe in pe]
        pe = [
            transfer_pos_level(single_pe, base_level, pos_level, adj, num_nodes, cache)
            for single_pe in pe
        ]
    else:
        pe = np.real(pe).astype(np.float32)
        pe = transfer_pos_level(pe, base_level, pos_level, adj, num_nodes, cache)

    return pe, cache
================================================
FILE: graphium/features/properties.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, List, Callable
import numpy as np
import datamol as dm
from rdkit.Chem import rdMolDescriptors as rdMD
from loguru import logger
def get_prop_or_none(
    prop: Callable, n: int, *args: Union[dm.Mol, str], **kwargs: Union[dm.Mol, str]
) -> Union[List[float], List[None]]:
    r"""
    Evaluate a property function, falling back to a list of `None` of length `n`
    if the computation raises a ``RuntimeError``.

    Parameters:
        prop: The property function to call.
        n: Length of the `None` placeholder list returned on failure.
        *args: Positional arguments forwarded to ``prop``.
        **kwargs: Keyword arguments forwarded to ``prop``.
    Returns:
        The computed property, or a list of `None` with length `n` on failure.
    """
    logger.warning("get_prop_or_none is deprecated. Use `datamol.to_fp` instead.")
    try:
        result = prop(*args, **kwargs)
    except RuntimeError:
        # Descriptor computation failed for this molecule; keep the output
        # length consistent so downstream arrays stay aligned.
        result = [None] * n
    return result
def get_props_from_mol(
    mol: Union[dm.Mol, str],
    properties: Union[List[str], str] = "autocorr3d",
) -> np.ndarray:
    r"""
    Compute a selected set of 3D molecular descriptors and concatenate them
    into a single property vector.

    Parameters:
        mol: The molecule (or SMILES string) from which to compute the properties.
        properties:
            The list of properties to compute for each molecule. It can be the following:
            - 'descriptors'
            - 'autocorr3d'
            - 'rdf'
            - 'morse'
            - 'whim'
            - 'all'
    Returns:
        props: np.array(float)
            The array of properties for the desired molecule
        classes_start_idx: list(int)
            Index where each descriptor class starts inside ``props``. For example,
            if props has 20 elements, the first 5 from one class, the next 8 from
            another, and the rest from a third, then ``classes_start_idx = [0, 5, 13]``.
            Mainly useful to normalize the features of each class.
        classes_names: list(str)
            The class name associated with each starting index.
    """
    logger.warning("get_props_from_mol is deprecated. Use `datamol.to_fp` instead.")

    if isinstance(mol, str):
        mol = dm.to_mol(
            mol
        )  # Doesn't need `ordered=True` because the fingerprints don't depend on the atom order

    if isinstance(properties, str):
        properties = [properties]
    properties = [p.lower() for p in properties]

    props = []  # Property vector for the features
    classes_start_idx = []  # The starting index for each property class
    classes_names = []

    # Add explicit hydrogens before computing descriptors.
    # NOTE(review): no conformer is generated here, yet the descriptors below are
    # 3D-based — confirm callers embed a conformer first.
    mol = dm.add_hs(mol)

    # (class name, descriptor function, expected output length) in the same order
    # the classes are appended to `props`:
    # - autocorr3d: 3D autocorrelation description of the molecule
    # - rdf: radial distribution function (https://en.wikipedia.org/wiki/Radial_distribution_function)
    # - morse: Molecule Representation of Structures based on Electron diffraction
    # - whim: WHIM descriptors, 3D structural descriptors from atomic coordinates, used in QSAR
    descriptor_table = [
        ("autocorr3d", rdMD.CalcAUTOCORR3D, 80),
        ("rdf", rdMD.CalcRDF, 210),
        ("morse", rdMD.CalcMORSE, 224),
        ("whim", rdMD.CalcWHIM, 114),
    ]
    for class_name, descriptor_fn, length in descriptor_table:
        if (class_name in properties) or ("all" in properties):
            classes_names.append(class_name)
            classes_start_idx.append(len(props))
            props.extend(get_prop_or_none(descriptor_fn, length, mol))

    return np.array(props), classes_start_idx, classes_names
================================================
FILE: graphium/features/rw.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Optional, List, Dict, Any, Iterable
from scipy.sparse import issparse, spmatrix, coo_matrix
import numpy as np
import torch
from torch_geometric.utils import to_dense_adj, from_scipy_sparse_matrix
from torch_scatter import scatter_add
from torch_geometric.utils.num_nodes import maybe_num_nodes
def compute_rwse(
    adj: Union[np.ndarray, spmatrix],
    ksteps: Union[int, List[int]],
    num_nodes: int,
    cache: Dict[str, Any],
    pos_type: str = "rw_return_probs",
    space_dim: int = 0,
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """
    Compute Random Walk Spectral Embedding (RWSE) for given list of K steps.

    Note: the default of ``pos_type`` was previously written as
    ``"rw_return_probs" or "rw_transition_probs"``, which always evaluates to
    ``"rw_return_probs"`` — the expression was misleading, the value is unchanged.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix
        ksteps: List of numbers of steps for the random walks. If int, a list is generated from 1 to ksteps.
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects; uses/updates the keys
            "coo_adj"/"csr_adj" (sparse adjacency), "ksteps" (computed exponents)
            and "Pk" (dict of random-walk transition matrices).
        pos_type: Desired output, either "rw_return_probs" or "rw_transition_probs"
        space_dim: Estimated dimensionality of the space. Used to
            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
            In euclidean space, this correction means that the height of
            the gaussian distribution stays almost constant across the number of
            steps, if `space_dim` is the dimension of the euclidean space.
    Returns:
        Two possible outputs:
            rw_return_probs [num_nodes, len(ksteps)]: Random-Walk k-step landing probabilities
            rw_transition_probs [num_nodes, num_nodes, len(ksteps)]: Random-Walk k-step transition probabilities
        base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here either node or nodepair
        cache: Updated dictionary of cached objects
    """
    base_level = "node" if pos_type == "rw_return_probs" else "nodepair"

    # An int `ksteps` means "all steps from 1 to ksteps".
    if not isinstance(ksteps, Iterable):
        ksteps = list(range(1, ksteps + 1))

    # Manually handle the edge case of 1-atom molecules: the walk always stays put.
    if num_nodes == 1:
        if pos_type == "rw_return_probs":
            return np.ones((1, len(ksteps))), base_level, cache
        else:
            return np.ones((1, 1, len(ksteps))), base_level, cache

    # Get the edge indices from the (cached, sparse) adjacency matrix
    if not issparse(adj):
        if "coo_adj" in cache:
            adj = cache["coo_adj"]
        elif "csr_adj" in cache:
            adj = cache["csr_adj"]
        else:
            adj = coo_matrix(adj, dtype=np.float64)
            cache["coo_adj"] = adj
    edge_index, edge_weight = from_scipy_sparse_matrix(adj)

    # Compute the random-walk transition probabilities, reusing cached powers
    # where possible and only computing the missing exponents.
    if "ksteps" in cache:
        cached_k = cache["ksteps"]
        missing_k = [k for k in ksteps if k not in cached_k]
        if missing_k:
            if min(missing_k) < min(cached_k):
                # Cached powers cannot seed the iteration; recompute from scratch.
                Pk_dict = get_Pks(
                    missing_k, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes
                )
            else:
                # Resume the power iteration from the closest cached exponent.
                # NOTE(review): assumes `start_k` is present in cache["Pk"]; this holds
                # because get_Pks fills the whole min..max range — confirm if the
                # caching policy ever changes.
                start_k = min([max(cached_k), min(missing_k)])
                start_Pk = cache["Pk"][start_k]
                Pk_dict = get_Pks(
                    missing_k,
                    edge_index=edge_index,
                    edge_weight=edge_weight,
                    num_nodes=num_nodes,
                    start_Pk=start_Pk,
                    start_k=start_k,
                )
            cache["ksteps"] = sorted(cache["ksteps"] + missing_k)
            for k in missing_k:
                cache["Pk"][k] = Pk_dict[k]
    else:
        Pk_dict = get_Pks(ksteps, edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
        cache["ksteps"] = list(Pk_dict.keys())
        cache["Pk"] = Pk_dict

    pe_list = []
    if pos_type == "rw_return_probs":
        # Diagonal of P^k = probability of returning to the start node after k steps,
        # corrected by k^(space_dim/2) (see docstring).
        for k in ksteps:
            pe_list.append(torch.diagonal(cache["Pk"][k], dim1=-2, dim2=-1) * (k ** (space_dim / 2)))
    else:
        for k in ksteps:
            pe_list.append(cache["Pk"][k])

    pe = torch.stack(pe_list, dim=-1).numpy()

    return pe, base_level, cache
def get_Pks(
    ksteps: List[int],
    edge_index: torch.Tensor,
    edge_weight: Optional[torch.Tensor] = None,
    num_nodes: Optional[int] = None,
    start_Pk: Optional[torch.Tensor] = None,
    start_k: Optional[int] = None,
) -> Dict[int, torch.Tensor]:
    """
    Compute dense random-walk transition matrices P^k, where P = D^-1 * A,
    for every exponent k between min(ksteps) and max(ksteps) inclusive.

    Parameters:
        ksteps: List of k-steps; matrices are produced for the full range
            ``min(ksteps)..max(ksteps)``, not only the listed values.
        edge_index: Edge indices of the graph as a single [2, num_edges] tensor
            (as produced by ``from_scipy_sparse_matrix``); the original
            ``Tuple[Tensor, Tensor]`` annotation did not match how it is used.
        edge_weight: Edge weights; defaults to all-ones.
        num_nodes: Number of nodes in the graph; inferred from ``edge_index``
            when not provided.
        start_Pk: Optional previously computed P^start_k used to resume the
            power iteration instead of starting from P^min(ksteps).
        start_k: Exponent corresponding to ``start_Pk``.
    Returns:
        Dictionary mapping each k in ``min(ksteps)..max(ksteps)`` to the dense
        [num_nodes, num_nodes] tensor P^k.
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    src = edge_index[0]
    deg = scatter_add(edge_weight, src, dim=0, dim_size=num_nodes)  # Out degrees.
    deg_inv = deg.pow(-1.0)
    # Isolated nodes have zero out-degree; replace the resulting inf with 0.
    deg_inv.masked_fill_(deg_inv == float("inf"), 0)
    if edge_index.numel() == 0:
        # Graph with no edges: the transition matrix is all-zero.
        P = edge_index.new_zeros((1, num_nodes, num_nodes))
    else:
        # P = D^-1 * A
        P = torch.diag(deg_inv).float() @ to_dense_adj(
            edge_index, max_num_nodes=num_nodes
        )  # 1 x (Num nodes) x (Num nodes)
    # Seed the power iteration either from a cached P^start_k or from scratch.
    if start_Pk is not None:
        Pk = start_Pk @ P.clone().detach().matrix_power(min(ksteps) - start_k)
    else:
        Pk = P.clone().detach().matrix_power(min(ksteps))
    Pk_dict = {}
    for k in range(min(ksteps), max(ksteps) + 1):
        Pk_dict[k] = Pk.squeeze(0)
        Pk = Pk @ P
    return Pk_dict
================================================
FILE: graphium/features/spectral.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, Dict, Any
from scipy.linalg import eig
from scipy.sparse import csr_matrix, diags, issparse, spmatrix
import numpy as np
import torch
import networkx as nx
from graphium.utils.tensor import is_dtype_torch_tensor, is_dtype_numpy_array
def compute_laplacian_pe(
    adj: Union[np.ndarray, spmatrix],
    num_pos: int,
    cache: Dict[str, Any],
    disconnected_comp: bool = True,
    normalization: str = "none",
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    r"""
    Compute the Laplacian eigenvalues and eigenvectors of the Laplacian of the graph.

    Parameters:
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        num_pos: Number of Laplacian eigenvectors to compute
        cache: Dictionary of cached objects; uses/updates "csr_adj" (sparse adjacency),
            "L_{normalization}_sp" (normalized Laplacian), "components" (connected
            components) and "lap_eig"/"lap_eig_comp" (eigendecompositions).
        disconnected_comp: Whether to compute the eigenvectors for each connected component
        normalization: Normalization to apply to the Laplacian ("none", "sym" or "inv")
    Returns:
        Two possible outputs:
            eigvals [num_nodes, num_pos]: Eigenvalues of the Laplacian repeated for each node.
                This repetition is necessary in case of disconnected components, where
                the eigenvalues of the Laplacian are not the same for each node.
            eigvecs [num_nodes, num_pos]: Eigenvectors of the Laplacian
        base_level: Indicator of the output pos_level (node, edge, nodepair, graph) -> here node
        cache: Updated dictionary of cached objects
    """
    base_level = "node"

    # Sparsify the adjacency matrix
    if not issparse(adj):
        if "csr_adj" not in cache:
            adj = csr_matrix(adj, dtype=np.float64)
            cache["csr_adj"] = adj
        else:
            adj = cache["csr_adj"]

    # Compute the Laplacian, and normalize it (cached per normalization mode)
    if f"L_{normalization}_sp" not in cache:
        D = np.array(np.sum(adj, axis=1)).flatten()
        D_mat = diags(D)
        L = -adj + D_mat
        L_norm = normalize_matrix(L, degree_vector=D, normalization=normalization)
        cache[f"L_{normalization}_sp"] = L_norm
    else:
        L_norm = cache[f"L_{normalization}_sp"]

    components = []

    if disconnected_comp:
        if "components" not in cache:
            # Get the list of connected components
            components = list(nx.connected_components(nx.from_scipy_sparse_array(adj)))
            cache["components"] = components
        else:
            components = cache["components"]

    # Compute the eigenvectors for each connected component, and stack them together
    if len(components) > 1:
        if "lap_eig_comp" not in cache:
            eigvals = np.zeros((adj.shape[0], num_pos), dtype=np.complex64)
            eigvecs = np.zeros((adj.shape[0], num_pos), dtype=np.complex64)
            for component in components:
                comp = list(component)
                this_L = L_norm[comp][:, comp]
                this_eigvals, this_eigvecs = _get_positional_eigvecs(this_L, num_pos=num_pos)

                # Eigenvalues previously set to infinity are now set to 0
                # Any NaN in the eigvals or eigvecs will be set to 0
                this_eigvecs[~np.isfinite(this_eigvecs)] = 0.0
                this_eigvals[~np.isfinite(this_eigvals)] = 0.0

                eigvals[comp, :] = np.expand_dims(this_eigvals, axis=0)
                eigvecs[comp, :] = this_eigvecs
            cache["lap_eig_comp"] = (eigvals, eigvecs)
        else:
            eigvals, eigvecs = cache["lap_eig_comp"]
    else:
        if "lap_eig" not in cache:
            # BUGFIX: previously decomposed the raw Laplacian `L`, which (a) ignored
            # `normalization` for connected graphs, and (b) raised NameError whenever
            # the Laplacian came from the cache (in which case `L` was never defined).
            # `L_norm` is defined on both cache paths and matches the multi-component branch.
            eigvals, eigvecs = _get_positional_eigvecs(L_norm, num_pos=num_pos)

            # Eigenvalues previously set to infinity are now set to 0
            # Any NaN in the eigvals or eigvecs will be set to 0
            eigvecs[~np.isfinite(eigvecs)] = 0.0
            eigvals[~np.isfinite(eigvals)] = 0.0
            eigvals = np.repeat(np.expand_dims(eigvals, axis=0), adj.shape[0], axis=0)

            cache["lap_eig"] = (eigvals, eigvecs)
        else:
            eigvals, eigvecs = cache["lap_eig"]

    return eigvals, eigvecs, base_level, cache
def _get_positional_eigvecs(
matrix: Union[np.ndarray, spmatrix],
num_pos: int,
) -> Tuple[np.ndarray, np.ndarray]:
r"""
compute the eigenvalues and eigenvectors of a matrix
Parameters:
matrix: Matrix to compute the eigenvalues and eigenvectors of
num_pos: Number of eigenvalues and eigenvectors to compute
Returns:
eigvals: Eigenvalues of the matrix
eigvecs: Eigenvectors of the matrix
"""
mat_len = matrix.shape[0]
eigvals, eigvecs = eig(matrix.todense())
# Pad with non-sense eigenvectors if required
if num_pos > mat_len:
temp_EigVal = np.ones(num_pos - mat_len, dtype=np.float64) + float("inf")
temp_EigVec = np.zeros((mat_len, num_pos - mat_len), dtype=np.float64)
eigvals = np.concatenate([eigvals, temp_EigVal], axis=0)
eigvecs = np.concatenate([eigvecs, temp_EigVec], axis=1)
# Sort and keep only the first `num_pos` elements
sort_idx = eigvals.argsort()
eigvals = eigvals[sort_idx]
eigvals = eigvals[:num_pos]
eigvecs = eigvecs[:, sort_idx]
eigvecs = eigvecs[:, :num_pos]
# Normalize the eigvecs
eigvecs = eigvecs / np.maximum(np.sqrt(np.sum(eigvecs**2, axis=0, keepdims=True)), 1e-4)
return eigvals, eigvecs
def normalize_matrix(
    matrix: Union[np.ndarray, spmatrix],
    degree_vector=None,
    normalization: str = None,
) -> Union[np.ndarray, spmatrix]:
    r"""
    Apply a degree-based normalization to a square matrix.

    Parameters
    ---------------
        matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N)
            Square matrix, either an Adjacency matrix or a Laplacian.
        degree_vector: torch.tensor(N) or np.ndarray(N) or None
            Degree of each row/column of ``matrix``.
            ``None`` is only allowed when ``normalization`` is also none.
        normalization: str or None, Default='none'
            Which normalization to apply

            - 'none' or ``None``: no normalization
            - 'sym': Symmetric normalization ``D^-0.5 L D^-0.5``
            - 'inv': Inverse normalization ``D^-1 L``

    Returns
    -----------
        matrix: torch.tensor(N, N) or scipy.sparse.spmatrix(N, N)
            The normalized matrix
    """
    norm_mode = "none" if normalization is None else normalization.lower()

    if degree_vector is None:
        # Without a degree vector, only the identity normalization is possible
        if norm_mode != "none":
            raise ValueError("`degree_vector` cannot be `None` if `normalization` is not `None`")
    else:
        # Build D^-0.5 as a column vector, mapping divisions by zero to 0
        if is_dtype_numpy_array(matrix.dtype):
            with np.errstate(divide="ignore", invalid="ignore"):
                degree_inv = np.expand_dims(degree_vector**-0.5, axis=1)
            degree_inv[np.isinf(degree_inv)] = 0
        elif is_dtype_torch_tensor(matrix.dtype):
            degree_inv = torch.unsqueeze(degree_vector**-0.5, dim=1)
            degree_inv[torch.isinf(degree_inv)] = 0

    # Apply the requested normalization
    if norm_mode == "none":
        return matrix
    if norm_mode == "sym":
        return degree_inv * matrix * degree_inv.T
    if norm_mode == "inv":
        return (degree_inv**2) * matrix
    raise ValueError(
        f'`normalization` should be `None`, `"None"`, `"sym"` or `"inv"`, but `{normalization}` was provided'
    )
================================================
FILE: graphium/features/transfer_pos_level.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Tuple, Union, List, Dict, Any, Optional
import numpy as np
from scipy.sparse import spmatrix, issparse, coo_matrix
from torch_geometric.utils import from_scipy_sparse_matrix
def transfer_pos_level(
    pe: np.ndarray,
    in_level: str,
    out_level: str,
    adj: Union[np.ndarray, spmatrix],
    num_nodes: int,
    cache: Optional[Dict[str, Any]] = None,
) -> np.ndarray:
    r"""
    Convert a positional encoding from one positional level to another.

    Supported levels are ``node``, ``edge``, ``nodepair`` and ``graph``.

    Parameters:
        pe: Input positional encoding, at level ``in_level``
        in_level: Positional level of the input ``pe``
        out_level: Desired positional level of the output ``pe``
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects

    Returns:
        pe: Output positional encoding, at level ``out_level``
    """
    cache = {} if cache is None else cache

    if in_level == "node":
        if out_level == "node":
            return pe
        if out_level == "edge":
            pe, cache = node_to_edge(pe, adj, cache)
            return pe
        if out_level == "nodepair":
            return node_to_nodepair(pe, num_nodes)
        if out_level == "graph":
            raise NotImplementedError("Transfer function (node -> graph) not yet implemented.")
        raise ValueError(f"Unknown `pos_level`: {out_level}")

    if in_level == "edge":
        raise NotImplementedError("Transfer function (edge -> *) not yet implemented.")

    if in_level == "nodepair":
        # A nodepair encoding without a feature axis gets one appended first
        if len(pe.shape) == 2:
            pe = np.expand_dims(pe, -1)
        if out_level == "node":
            return nodepair_to_node(pe)
        if out_level == "edge":
            pe, cache = nodepair_to_edge(pe, adj, cache)
            return pe
        if out_level == "nodepair":
            return pe
        if out_level == "graph":
            raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.")
        raise ValueError(f"Unknown `pos_level`: {out_level}")

    if in_level == "graph":
        if out_level == "node":
            return graph_to_node(pe, num_nodes, cache)
        if out_level in ["edge", "nodepair"]:
            raise NotImplementedError("Transfer function (graph -> edge/nodepair) not yet implemented.")
        raise ValueError(f"Unknown `pos_level`: {out_level}")

    raise ValueError(f"Unknown `pos_level`: {in_level}")
# Transfer functions between different levels, i.e., node, edge, nodepair and graph level.
# TODO:
# - Implement missing transfer functions below
# - Are transfer functions graph -> edge/nodepair and edge -> graph needed?
def node_to_edge(
    pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None
) -> Tuple[np.ndarray, Dict[str, Any]]:
    r"""
    Derive an edge-level positional encoding from a node-level one.

    For each directed edge, the sum and the absolute difference of the source
    and destination node encodings are concatenated feature-wise.

    Parameters:
        pe [num_nodes, num_feat]: Node-level positional encoding
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        cache: Dictionary of cached objects

    Returns:
        edge_pe [2 * num_edges, 2 * num_feat]: Edge-level positional encoding
        cache: Updated dictionary of cached objects
    """
    cache = {} if cache is None else cache

    # Reuse any cached sparse adjacency; otherwise build (and cache) a COO copy
    # so the edge list can be extracted
    if not issparse(adj):
        if "coo_adj" in cache:
            adj = cache["coo_adj"]
        elif "csr_adj" in cache:
            adj = cache["csr_adj"]
        else:
            adj = coo_matrix(adj, dtype=np.float64)
            cache["coo_adj"] = adj

    edge_index, _ = from_scipy_sparse_matrix(adj)
    source, target = edge_index[0], edge_index[1]

    summed = pe[source] + pe[target]
    abs_diff = np.abs(pe[source] - pe[target])
    edge_pe = np.concatenate((summed, abs_diff), axis=-1)

    return edge_pe, cache
def node_to_nodepair(pe: np.ndarray, num_nodes: int) -> np.ndarray:
    r"""
    Derive a nodepair-level positional encoding from a node-level one.

    For each node pair ``(i, j)``, the sum and the absolute difference of the
    encodings at nodes ``i`` and ``j`` are concatenated feature-wise.

    Parameters:
        pe [num_nodes, num_feat]: Node-level positional encoding
        num_nodes: Number of nodes in the graph

    Returns:
        nodepair_pe [num_nodes, num_nodes, 2 * num_feat]: Nodepair-level positional encoding
    """
    # tiled[i, j, :] == pe[i]; its transpose yields pe[j] at position (i, j)
    tiled = np.tile(pe[:, np.newaxis, :], (1, num_nodes, 1))
    mirrored = tiled.transpose([1, 0, 2])
    return np.concatenate((tiled + mirrored, np.abs(tiled - mirrored)), axis=-1)
def node_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray:
    r"""
    Derive a graph-level positional encoding from a node-level one,
    e.g. by min/max/mean-pooling over the node axis.

    Parameters:
        pe [num_nodes, num_feat]: Node-level positional encoding
        num_nodes: Number of nodes in the graph

    Returns:
        graph_pe [1, num_feat]: Graph-level positional encoding

    Raises:
        NotImplementedError: always; this transfer is not implemented yet.
    """
    raise NotImplementedError("Transfer function (node -> graph) not yet implemented.")
def edge_to_node(pe: np.ndarray, adj: Union[np.ndarray, spmatrix]) -> np.ndarray:
    r"""
    Derive a node-level positional encoding from an edge-level one,
    e.g. by min/max/mean-pooling over the edges (i,j) incident to node i.

    Parameters:
        pe [num_edges, num_feat]: Edge-level positional encoding
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph

    Returns:
        node_pe [num_edges, num_feat]: Node-level positional encoding

    Raises:
        NotImplementedError: always; this transfer is not implemented yet.
    """
    raise NotImplementedError("Transfer function (edge -> node) not yet implemented.")
def edge_to_nodepair(
    pe: np.ndarray, adj: Union[np.ndarray, spmatrix], num_nodes: int, cache: Optional[Dict[str, Any]] = None
) -> Tuple[np.ndarray, Dict[str, Any]]:
    r"""
    Get a nodepair-level positional encoding from an edge-level positional encoding.
    -> Zero-padding of non-existing edges.

    Parameters:
        pe [num_edges, num_feat]: Edge-level positional encoding
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects

    Returns:
        nodepair_pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding
        cache: Updated dictionary of cached objects
    """
    if cache is None:
        cache = {}
    num_feat = pe.shape[-1]
    # COO format is needed to read the (row, col) coordinates of each edge;
    # reuse the cached conversion when available.
    if not isinstance(adj, coo_matrix):
        if "coo_adj" in cache:
            adj = cache["coo_adj"]
        else:
            adj = coo_matrix(adj, dtype=np.float64)
            cache["coo_adj"] = adj
    dst, src = adj.row, adj.col
    # Scatter each edge's encoding into a dense [num_nodes, num_nodes] grid;
    # node pairs without an edge keep the zero padding.
    nodepair_pe = np.zeros((num_nodes, num_nodes, num_feat))
    for i in range(len(dst)):
        nodepair_pe[dst[i], src[i], ...] = pe[i, ...]
    return nodepair_pe, cache
def edge_to_graph(pe: np.ndarray) -> np.ndarray:
    r"""
    Derive a graph-level positional encoding from an edge-level one.

    Parameters:
        pe [num_edges, num_feat]: Edge-level positional encoding

    Returns:
        graph_pe [1, num_feat]: Graph-level positional encoding

    Raises:
        NotImplementedError: always; this transfer is not implemented yet.
    """
    raise NotImplementedError("Transfer function (edge -> graph) not yet implemented.")
def nodepair_to_node(pe: np.ndarray, stats_list: List = [np.min, np.mean, np.std]) -> np.ndarray:
    r"""
    Derive a node-level positional encoding from a nodepair-level one by
    computing statistics over the rows and columns of the input encoding.

    Parameters:
        pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding
        stats_list: Statistics computed per row/column of the nodepair-level pe
            (the shared default list is never mutated)

    Returns:
        node_pe [num_nodes, 2 * len(stats_list) * num_feat]: Node-level positional encoding
    """
    num_feat = pe.shape[-1]
    # For every statistic and feature channel, reduce once over columns
    # (axis=0) and once over rows (axis=1)
    columns = [
        stat(pe[..., feat], axis=axis)
        for stat in stats_list
        for feat in range(num_feat)
        for axis in (0, 1)
    ]
    return np.stack(columns, axis=-1)
def nodepair_to_edge(
    pe: np.ndarray, adj: Union[np.ndarray, spmatrix], cache: Optional[Dict[str, Any]] = None
) -> Tuple[np.ndarray, Dict[str, Any]]:
    r"""
    Get an edge-level positional encoding from a nodepair-level positional encoding.
    -> Mask and sparsify nodepair-level positional encoding

    Parameters:
        pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding
        adj [num_nodes, num_nodes]: Adjacency matrix of the graph
        cache: Dictionary of cached objects

    Returns:
        edge_pe [num_edges, num_feat]: Edge-level positional encoding
        cache: Updated dictionary of cached objects
    """
    if cache is None:
        cache = {}
    num_feat = pe.shape[-1]
    # COO format is needed to read the (row, col) coordinates of each edge;
    # reuse the cached conversion when available.
    if not isinstance(adj, coo_matrix):
        if "coo_adj" in cache:
            adj = cache["coo_adj"]
        else:
            adj = coo_matrix(adj, dtype=np.float64)
            cache["coo_adj"] = adj
    dst, src = adj.row, adj.col
    # Gather the nodepair encoding at each existing edge's coordinates
    edge_pe = np.zeros((len(dst), num_feat))
    for i in range(len(src)):
        edge_pe[i, ...] = pe[dst[i], src[i]]
    return edge_pe, cache
def nodepair_to_graph(pe: np.ndarray, num_nodes: int) -> np.ndarray:
    r"""
    Derive a graph-level positional encoding from a nodepair-level one,
    e.g. by min/max/mean-pooling the entries of the input encoding.

    Parameters:
        pe [num_nodes, num_nodes, num_feat]: Nodepair-level positional encoding
        num_nodes: Number of nodes in the graph

    Returns:
        graph_pe [1, num_feat]: Graph-level positional encoding

    Raises:
        NotImplementedError: always; this transfer is not implemented yet.
    """
    raise NotImplementedError("Transfer function (nodepair -> graph) not yet implemented.")
def graph_to_node(
    pe: Union[np.ndarray, List], num_nodes: int, cache: Optional[Dict[str, Any]] = None
) -> np.ndarray:
    r"""
    Derive a node-level positional encoding from a graph-level one by
    broadcasting the graph encoding to every node.

    Parameters:
        pe [num_feat]: Graph-level positional encoding
            (or a list of per-component encodings if the graph is disconnected)
        num_nodes: Number of nodes in the graph
        cache: Dictionary of cached objects

    Returns:
        node_pe [num_nodes, num_feat]: Node-level positional encoding
    """
    cache = {} if cache is None else cache

    # The key 'components' is only present in the cache when the base pe was
    # computed with disconnected_comp == True
    if "components" in cache and len(cache["components"]) > 1:
        # Scatter each component's graph-level pe to the nodes of that component.
        # NOTE(review): the feature dimension is taken as len(pe), i.e. the
        # number of components — verify this matches num_feat upstream.
        node_pe = np.zeros((num_nodes, len(pe)))
        for i, component in enumerate(cache["components"]):
            node_pe[list(component), :] = np.real(pe[i])
        return node_pe

    # Connected graph (or no component info): replicate the pe on every node
    return np.tile(pe, (num_nodes, 1))
================================================
FILE: graphium/finetuning/__init__.py
================================================
from .utils import modify_cfg_for_finetuning
from .finetuning import GraphFinetuning
from .finetuning_architecture import FullGraphFinetuningNetwork
FINETUNING_CONFIG_KEY = "finetuning"
================================================
FILE: graphium/finetuning/finetuning.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type
from collections import OrderedDict
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim.optimizer import Optimizer
from pytorch_lightning.callbacks import BaseFinetuning
class GraphFinetuning(BaseFinetuning):
    def __init__(
        self,
        finetuning_module: str,
        added_depth: int = 0,
        unfreeze_pretrained_depth: Optional[int] = None,
        epoch_unfreeze_all: int = 0,
        train_bn: bool = False,
    ):
        """
        Finetuning training callback that (un)freezes modules as specified in the configuration file.
        By default, the modified layers of the finetuning module and the finetuning head are unfrozen.

        Parameters:
            finetuning_module: Module to finetune from
            added_depth: Number of layers of finetuning module that have been modified rel. to pretrained model
            unfreeze_pretrained_depth: Number of additional layers to unfreeze before layers modified rel. to pretrained model
            epoch_unfreeze_all: Epoch to unfreeze entire model
            train_bn: Boolean value indicating if batchnorm layers stay in training mode
        """
        super().__init__()

        self.finetuning_module = finetuning_module
        # Total number of trailing layers of the finetuning module kept trainable
        self.training_depth = added_depth
        if unfreeze_pretrained_depth is not None:
            self.training_depth += unfreeze_pretrained_depth
        self.epoch_unfreeze_all = epoch_unfreeze_all
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: pl.LightningModule):
        """
        Freeze everything up to the finetuning module (and parts of the finetuning module)

        Parameters:
            pl_module: PredictorModule used for finetuning
        """
        # Access module map of pretrained module; iteration order determines
        # which modules come "before" the finetuning module
        module_map = pl_module.model.pretrained_model.net._module_map

        for module_name in module_map.keys():
            self.freeze_module(pl_module, module_name, module_map)
            if module_name.startswith(self.finetuning_module):
                # Do not freeze modules after finetuning module
                break

    def freeze_module(self, pl_module, module_name: str, module_map: Dict[str, Union[nn.ModuleList, Any]]):
        """
        Freeze specific modules

        Parameters:
            pl_module: PredictorModule used for finetuning
            module_name: Name of module to (partially) freeze
            module_map: Dictionary mapping from module_name to corresponding module(s)
        """
        modules = module_map[module_name]

        # The positional-encoder parameters are frozen directly on the network
        if module_name == "pe_encoders":
            for param in pl_module.model.pretrained_model.net.encoder_manager.parameters():
                param.requires_grad = False

        # We only partially freeze the finetuning module: its last
        # `training_depth` layers stay trainable
        if module_name.startswith(self.finetuning_module) and self.training_depth != 0:
            modules = modules[: -self.training_depth]

        self.freeze(modules=modules, train_bn=self.train_bn)

    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer):
        """
        Unfreeze the entire model at the specified epoch

        Parameters:
            pl_module: PredictorModule used for finetuning
            epoch: Current training epoch
            optimizer: Optimizer used for finetuning
        """
        if epoch == self.epoch_unfreeze_all:
            self.unfreeze_and_add_param_group(modules=pl_module, optimizer=optimizer, train_bn=self.train_bn)
================================================
FILE: graphium/finetuning/finetuning_architecture.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Batch
from graphium.nn.utils import MupMixin
from graphium.trainer.predictor import PredictorModule
from graphium.utils.spaces import FINETUNING_HEADS_DICT
class FullGraphFinetuningNetwork(nn.Module, MupMixin):
    def __init__(
        self,
        pretrained_model: Union[str, "PretrainedModel"],
        pretrained_model_kwargs: Optional[Dict[str, Any]] = None,
        pretrained_overwriting_kwargs: Optional[Dict[str, Any]] = None,
        finetuning_head_kwargs: Optional[Dict[str, Any]] = None,
        num_inference_to_average: int = 1,
        last_layer_is_readout: bool = False,
        name: str = "FullFinetuningGNN",
    ):
        r"""
        Flexible class that allows to implement an end-to-end graph finetuning network architecture, supporting flexible pretrained models and finetuning heads.
        The network decomposes into two parts of class PretrainedModel and FinetuningHead. The PretrainedModel class allows basic finetuning such as
        finetuning from a specified module of the pretrained model and dropping/adding layers in this module. The (optional) FinetuningHead class allows more
        flexible finetuning with a custom network applied after the pretrained model. If not specified, we fall back to basic finetuning integrated in PretrainedModel.

        Parameters:
            pretrained_model:
                A PretrainedModel or an identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or a valid .ckpt checkpoint path
            pretrained_model_kwargs:
                Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork).
                ``None`` (the default) is equivalent to an empty dict.
            pretrained_overwriting_kwargs:
                Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork.
                ``None`` (the default) is equivalent to an empty dict.
            finetuning_head_kwargs:
                Key-word arguments to use for the finetuning head.
                It must respect the following criteria:

                - pretrained_model_kwargs[last_used_module]["out_level"] must be equal to finetuning_head_kwargs["in_level"]
                - pretrained_model_kwargs[last_used_module]["out_dim"] must be equal to finetuning_head_kwargs["in_dim"]

                Here, [last_used_module] represents the module that is finetuned from,
                e.g., gnn, graph_output or (one of the) task_heads
            num_inference_to_average:
                Number of inferences to average at val/test time. This is used to avoid the noise introduced
                by positional encodings with sign-flips. In case no such encoding is given,
                this parameter is ignored.
                NOTE: The inference time will be slowed-down proportionaly to this parameter.
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            name:
                Name attributed to the current network, for display and printing
                purposes.
        """
        super().__init__()

        # `None` defaults instead of `={}` avoid the shared mutable-default
        # pitfall; normalize to fresh per-instance dicts here.
        if pretrained_model_kwargs is None:
            pretrained_model_kwargs = {}
        if pretrained_overwriting_kwargs is None:
            pretrained_overwriting_kwargs = {}

        self.name = name
        self.num_inference_to_average = num_inference_to_average
        self.last_layer_is_readout = last_layer_is_readout
        self._concat_last_layers = None
        self.pretrained_model = pretrained_model
        self.pretrained_overwriting_kwargs = pretrained_overwriting_kwargs
        self.finetuning_head_kwargs = finetuning_head_kwargs
        self.max_num_nodes_per_graph = None
        self.max_num_edges_per_graph = None
        self.finetuning_head = None

        # Wrap plain identifiers/checkpoint paths into a PretrainedModel
        if not isinstance(self.pretrained_model, PretrainedModel):
            self.pretrained_model = PretrainedModel(
                self.pretrained_model, pretrained_model_kwargs, pretrained_overwriting_kwargs
            )

        if finetuning_head_kwargs is not None:
            self.finetuning_head = FinetuningHead(finetuning_head_kwargs)

    def forward(self, g: Batch) -> Tensor:
        r"""
        Apply the pre-processing neural network, the graph neural network,
        and the post-processing neural network on the graph features.

        Parameters:
            g:
                pyg Batch graph on which the convolution is done.
                Must contain the following elements:

                - Node key `"feat"`: `torch.Tensor[..., N, Din]`.
                  Input node feature tensor, before the network.
                  `N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim``

                - Edge key `"edge_feat"`: `torch.Tensor[..., N, Ein]` **Optional**.
                  The edge features to use. It will be ignored if the
                  model doesn't supporte edge features or if
                  `self.in_dim_edges==0`.

                - Other keys related to positional encodings `"pos_enc_feats_sign_flip"`,
                  `"pos_enc_feats_no_flip"`.

        Returns:
            `torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`:
                Node or graph feature tensor, after the network.
                `N` is the number of nodes, `M` is the number of graphs,
                `Dout` is the output dimension ``self.graph_output_nn.out_dim``
                If the `self.gnn.pooling` is [`None`], then it returns node features and the output dimension is `N`,
                otherwise it returns graph features and the output dimension is `M`
        """
        g = self.pretrained_model.forward(g)

        # The finetuning head is optional; without it we return the
        # pretrained model's output directly
        if self.finetuning_head is not None:
            g = self.finetuning_head.forward(g)

        return g

    def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.

        Returns:
            Dictionary with the kwargs to create the base model.
        """
        kwargs = dict(
            pretrained_model=self.pretrained_model,
            pretrained_model_kwargs=None,
            finetuning_head_kwargs=None,
            num_inference_to_average=self.num_inference_to_average,
            last_layer_is_readout=self.last_layer_is_readout,
            name=self.name,
        )

        kwargs["pretrained_model_kwargs"] = self.pretrained_model.make_mup_base_kwargs(
            divide_factor=divide_factor
        )

        if self.finetuning_head is not None:
            kwargs["finetuning_head_kwargs"] = self.finetuning_head.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=True
            )

        kwargs["pretrained_overwriting_kwargs"] = self.pretrained_overwriting_kwargs

        return kwargs

    def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges: Optional[int]) -> None:
        r"""
        Set the maximum number of nodes and edges for all gnn layers and encoder layers

        Parameters:
            max_nodes: Maximum number of nodes in the dataset.
                This will be useful for certain architecture, but ignored by others.

            max_edges: Maximum number of edges in the dataset.
                This will be useful for certain architecture, but ignored by others.
        """
        self.pretrained_model.net.set_max_num_nodes_edges_per_graph(max_nodes, max_edges)
class PretrainedModel(nn.Module, MupMixin):
    def __init__(
        self,
        pretrained_model: str,
        pretrained_model_kwargs: Dict[str, Any],
        pretrained_overwriting_kwargs: Dict[str, Any],
    ):
        r"""
        Flexible class allowing to finetune pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT or from a checkpoint path.
        Can be any model that inherits from nn.Module, MupMixin and comes with a module map (e.g., FullGraphMultitaskNetwork)

        Parameters:
            pretrained_model:
                Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or a checkpoint path
            pretrained_model_kwargs:
                Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)
            pretrained_overwriting_kwargs:
                Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork
        """
        super().__init__()

        # Load pretrained model
        pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device="cpu").model
        pretrained_model.create_module_map()

        # Initialize new model with architecture after desired modifications to architecture.
        net = type(pretrained_model)
        self.net = net(**pretrained_model_kwargs)
        self.net.create_module_map()

        # Overwrite parameters shared between loaded and modified pretrained model
        self.overwrite_with_pretrained(pretrained_model, **pretrained_overwriting_kwargs)

    def forward(self, g: Union[torch.Tensor, Batch]):
        """Run the (partially overwritten) network on the input graph/batch."""
        g = self.net.forward(g)

        return g

    def overwrite_with_pretrained(
        self,
        pretrained_model,
        finetuning_module: str,
        added_depth: int = 0,
        sub_module_from_pretrained: Optional[str] = None,
    ):
        """
        Overwrite parameters shared between loaded and modified pretrained model

        Parameters:
            pretrained_model: Model from GRAPHIUM_PRETRAINED_MODELS_DICT
            finetuning_module: Module to finetune from
            added_depth: Number of modified layers at the end of finetuning module
            sub_module_from_pretrained: Optional submodule to finetune from
        """
        module_map = self.net._module_map
        module_map_from_pretrained = pretrained_model._module_map

        module_names_from_pretrained = module_map_from_pretrained.keys()
        super_module_names_from_pretrained = set(
            [module_name.split("-")[0] for module_name in module_names_from_pretrained]
        )

        for module_name in module_map.keys():
            # Some modules (e.g., pe_encoders in FullGraphMultitaskNetwork) do not
            # support len() and can always be replaced entirely. Catch only
            # TypeError (what len() raises) instead of a bare `except`, which
            # would also swallow KeyboardInterrupt/SystemExit.
            try:
                shared_depth = len(module_map[module_name])
            except TypeError:
                module_map[module_name] = module_map_from_pretrained[module_name]
                continue

            # Within the finetuning module, the last `added_depth` layers were
            # modified and must keep their freshly initialized parameters
            if module_name.startswith(finetuning_module):
                shared_depth -= added_depth

            if module_name in module_map_from_pretrained.keys():
                for idx in range(shared_depth):
                    module_map[module_name][idx] = module_map_from_pretrained[module_name][idx]
            elif module_name.split("-")[0] in super_module_names_from_pretrained:
                for idx in range(shared_depth):
                    module_map[module_name][idx] = module_map_from_pretrained[
                        f"{module_name.split('-')[0]}-{sub_module_from_pretrained}"
                    ][idx]
            else:
                raise RuntimeError("Mismatch between loaded pretrained model and model to be overwritten.")

            # Modules after the finetuning module are left untouched
            if module_name.startswith(finetuning_module):
                break

    def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.

        Returns:
            Dictionary with the kwargs to create the base model.
        """
        # For the post-nn network, all the dimension are divided
        return self.net.make_mup_base_kwargs(divide_factor=divide_factor)
class FinetuningHead(nn.Module, MupMixin):
    def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
        r"""
        Flexible class allowing to use a custom finetuning head on top of the pretrained model.
        Can be any model that inherits from nn.Module, MupMixin.

        Parameters:
            finetuning_head_kwargs: Key-word arguments needed to instantiate a custom (or existing) finetuning head from FINETUNING_HEADS_DICT.
                NOTE: the keys ``task``, ``previous_module``, ``incoming_level``
                and ``model_type`` are popped (mutating the dict) before the
                remaining entries are forwarded to the head constructor.
        """
        super().__init__()
        self.task = finetuning_head_kwargs.pop("task", None)
        self.previous_module = finetuning_head_kwargs.pop("previous_module", "task_heads")
        self.incoming_level = finetuning_head_kwargs.pop("incoming_level", "graph")
        head_class = FINETUNING_HEADS_DICT[finetuning_head_kwargs.pop("model_type", "mlp")]
        self.net = head_class(**finetuning_head_kwargs)

    def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]):
        """Apply the finetuning head and wrap the result in a single-task dict."""
        if isinstance(g, (torch.Tensor, Batch)):
            head_input = g
        elif isinstance(g, Dict) and len(g) == 1:
            # Unwrap the single entry of a one-task output dict
            head_input = next(iter(g.values()))
        else:
            raise TypeError("Output type from pretrained model not appropriate for finetuning head")
        return {self.task: self.net.forward(head_input)}

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model used by the `mup`/`muTransfer` scaling: identical
        to the regular model, but with layer widths divided by ``divide_factor``.

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            Dictionary with the kwargs to create the base model.
        """
        # All dimensions of the wrapped head network are divided
        return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
================================================
FILE: graphium/finetuning/fingerprinting.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import torch
from collections import defaultdict
from typing import Literal, Union, List
import tqdm
from graphium.nn.architectures.global_architectures import FullGraphMultiTaskNetwork
from graphium.trainer.predictor import PredictorModule
class Fingerprinter:
    """Extract fingerprints from a [`FullGraphMultiTaskNetwork`][graphium.nn.architectures.global_architectures.FullGraphMultiTaskNetwork]

    Fingerprints are hidden representations of a pre-trained network.
    They can be used as an additional representation for task-specific ML models
    in downstream tasks.

    Info: Connection to linear probing.
        This two-stage process is similar in concept to linear-probing,
        but pre-computing the fingerprints further decouples the two stages
        and allows for more flexibility.

    Note: CLI support
        You can extract fingerprints easily with the CLI. For more info, see

        ```sh
        graphium finetune fingerprint --help
        ```

    To return intermediate representations, the network will be altered
    to save the readouts of the specified layers. This class is designed to be used
    as a context manager that automatically reverts the network to its original state.

    Examples:

        Basic usage:

        ```python
        # Return layer 1 of the gcn submodule.
        # For a full list of options, see the `_module_map` attribute of the FullGraphMultiTaskNetwork.
        fp_spec = "gcn:1"

        # Create the object
        fingerprinter = Fingerprinter(predictor, "gcn:1")

        # Setup, run and teardown the fingerprinting process
        fingerprinter.setup()
        fp = fingerprinter.get_fingerprints_for_batch(batch)
        fingerprinter.teardown()
        ```

        As context manager:

        ```python
        with Fingerprinter(predictor, "gcn:1") as fingerprinter:
            fp = fingerprinter.get_fingerprints_for_batch(batch)
        ```

        Concatenating multiple hidden representations:

        ```python
        with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
            fp = fingerprinter.get_fingerprints_for_batch(batch)
        ```

        For an entire dataset (expects a PyG dataloader):

        ```python
        with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
            fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
        ```

        Returning numpy arrays instead of torch tensors:

        ```python
        with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
            fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
        ```
    """

    def __init__(
        self,
        model: Union[FullGraphMultiTaskNetwork, PredictorModule],
        fingerprint_spec: Union[str, List[str]],
        out_type: Literal["torch", "numpy"] = "torch",
    ):
        """
        Args:
            model: Either a `FullGraphMultiTaskNetwork` or a `PredictorModule` that contains one. The benefit
                of using a `PredictorModule` is that it will automatically convert the features to the correct dtype.
            fingerprint_spec: The fingerprint specification. Of the format `module:layer`. If specified as a list,
                the fingerprints from all the specified layers will be concatenated.
            out_type: The output type of the fingerprints. Either `torch` or `numpy`.

        Raises:
            NotImplementedError: If the (possibly wrapped) model is not a `FullGraphMultiTaskNetwork`.
            ValueError: If `out_type` is not one of `"torch"` or `"numpy"`.
        """
        network = model
        predictor = None

        # Unwrap the network from the PredictorModule, keeping a handle on the
        # predictor so its dtype-conversion utility can be used per batch.
        if isinstance(model, PredictorModule):
            predictor = model
            network = model.model

        if not isinstance(network, FullGraphMultiTaskNetwork):
            raise NotImplementedError(
                f"{self.__class__.__name__} only supports fingerprints for the FullGraphMultiTaskNetwork"
            )

        if out_type not in ["torch", "numpy"]:
            raise ValueError(f"Invalid output type {out_type}, pick from 'torch' or 'numpy'")

        self.network = network
        self.predictor = predictor
        self._out_type = out_type

        if isinstance(fingerprint_spec, str):
            fingerprint_spec = [fingerprint_spec]

        # Parse "module:layer" spec strings into {module_name: [layer, ...]}
        self._spec = defaultdict(list)
        for spec_str in fingerprint_spec:
            module_name, layer = spec_str.split(":")
            self._spec[module_name].append(int(layer.strip()))

        # Backup of the network's module map so teardown() can restore it
        self._module_map_backup = None

    def setup(self):
        """Prepare the network for fingerprinting by enabling readout caching.

        Returns:
            self, so the call can be chained or used via the context-manager protocol.
        """
        # Save any pre-existing module map so it can be restored on teardown
        if hasattr(self.network, "_module_map"):
            self._module_map_backup = self.network._module_map
        self.network._enable_readout_cache(list(self._spec.keys()))
        return self

    def get_fingerprints_for_batch(self, batch):
        """Get the fingerprints for a single batch

        Raises:
            RuntimeError: If readout caching is not enabled on the network
                (i.e. `setup()` was not called first).
        """
        if not self.network._cache_readouts:
            raise RuntimeError(
                f"To use {self.__class__.__name__}, you must enable readout caching in the network. "
                f"Alternatively, you can use the {self.__class__.__name__} as a context manager "
                "to automatically setup the network for fingerprinting and revert the changes afterwards"
            )

        # Run the batch through the model. We only need the side effect of
        # populating the per-module readout caches; the output is discarded.
        with torch.inference_mode():
            if self.predictor is not None:
                batch["features"] = self.predictor._convert_features_dtype(batch["features"])
            self.network(batch["features"])

        # Gather the cached readouts in spec order and concatenate along the
        # feature dimension.
        readout_list = []
        for module_name, layers in self._spec.items():
            readout_list.extend(
                [self.network._module_map[module_name]._readout_cache[layer].cpu() for layer in layers]
            )

        feats = torch.cat(readout_list, dim=-1)
        return self._convert_output_type(feats)

    def get_fingerprints_for_dataset(self, dataloader):
        """Return the fingerprints for an entire dataset"""

        original_out_type = self._out_type

        # Accumulate per-batch results as torch tensors so they can be
        # concatenated once at the end. The requested output type is restored
        # even if a batch raises (previously an exception left it as "torch").
        self._out_type = "torch"
        try:
            fps = []
            for batch in tqdm.tqdm(dataloader, desc="Fingerprinting batches"):
                feats = self.get_fingerprints_for_batch(batch)
                fps.append(feats)
            fps = torch.cat(fps, dim=0)
        finally:
            self._out_type = original_out_type

        return self._convert_output_type(fps)

    def teardown(self):
        """Restore the network to its original state"""
        self.network._disable_readout_cache()
        if self._module_map_backup is not None:
            self.network._module_map = self._module_map_backup
        else:
            # No pre-existing map — remove the one added during setup
            # (presumably created by `_enable_readout_cache`; confirm).
            del self.network._module_map
        return self

    def __enter__(self):
        """Setup the network for fingerprinting"""
        return self.setup()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Revert the network to its original state.

        The teardown now always runs (the previous implementation re-raised a
        reconstructed exception *before* tearing down, leaving the network
        modified and losing the original traceback). Returning a falsy value
        lets Python propagate any in-flight exception unchanged.
        """
        self.teardown()
        return False

    def _convert_output_type(self, feats: torch.Tensor):
        """Small utility function to convert output types"""
        if self._out_type == "numpy":
            feats = feats.numpy()
        return feats
================================================
FILE: graphium/finetuning/utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from copy import deepcopy
from typing import Any, Dict, List, Union
from loguru import logger
from graphium.trainer import PredictorModule
import graphium
def filter_cfg_based_on_admet_benchmark_name(config: Dict[str, Any], names: Union[List[str], str]):
    """
    Filter a base config for the full TDC ADMET benchmarking group to only
    have settings related to a subset of the endpoints
    """
    # NOTE (cwognum): For now, this implies we only support the ADMET benchmark from TDC.
    #   It is easy to extend this in the future to support more datasets.
    if config["datamodule"]["module_type"] != "ADMETBenchmarkDataModule":
        raise ValueError("You can only use this method for the `ADMETBenchmarkDataModule`")

    if isinstance(names, str):
        names = [names]

    def _keep_selected(section):
        # Retain only the entries whose key is one of the requested endpoints
        return {key: section[key] for key in section if key in names}

    cfg = deepcopy(config)

    # Update the datamodule arguments
    cfg["datamodule"]["args"]["tdc_benchmark_names"] = names

    # Filter the relevant config sections
    architecture = cfg.get("architecture", {})
    if "task_heads" in architecture:
        architecture["task_heads"] = _keep_selected(architecture["task_heads"])

    predictor = cfg.get("predictor", {})
    if "metrics_on_progress_bar" in predictor:
        predictor["metrics_on_progress_bar"] = _keep_selected(predictor["metrics_on_progress_bar"])
    if "loss_fun" in predictor:
        predictor["loss_fun"] = _keep_selected(predictor["loss_fun"])

    if "metrics" in cfg:
        cfg["metrics"] = _keep_selected(cfg["metrics"])

    return cfg
def modify_cfg_for_finetuning(cfg: Dict[str, Any]):
    """
    Function combining information from configuration and pretrained model for finetuning.

    Parameters:
        cfg: The base configuration, which must contain a `finetuning` section
            describing the task, the pretrained model, and the module to finetune from.

    Returns:
        The updated configuration: the architecture is inherited from the pretrained
        model, truncated/modified at the finetuning module, and the `finetuning`
        section is split into `overwriting_kwargs` and `training_kwargs`.
    """
    task = cfg["finetuning"]["task"]

    # Filter the config based on the task name
    # NOTE (cwognum): This prevents the need for having many different files for each of the tasks
    #   with lots and lots of config repetition.
    cfg = filter_cfg_based_on_admet_benchmark_name(cfg, task)
    cfg_finetune = cfg["finetuning"]

    # Load pretrained model
    pretrained_model = cfg_finetune["pretrained_model"]
    pretrained_predictor = PredictorModule.load_pretrained_model(pretrained_model, device="cpu")

    # Inherit shared configuration from pretrained
    # Architecture: strip the "_kwargs" suffix from the model kwargs keys
    pretrained_architecture = pretrained_predictor.model_kwargs
    arch_keys = pretrained_architecture.keys()
    arch_keys = [key.replace("_kwargs", "") for key in arch_keys]
    cfg_arch = {arch_keys[idx]: value for idx, value in enumerate(pretrained_architecture.values())}
    cfg_arch_from_pretrained = deepcopy(cfg_arch)

    # Featurization
    cfg["datamodule"]["args"]["featurization"] = pretrained_predictor.featurization

    finetuning_module = cfg_finetune["finetuning_module"]
    sub_module_from_pretrained = cfg_finetune.get("sub_module_from_pretrained", None)
    new_sub_module = cfg_finetune.pop("new_sub_module", None)
    keep_modules_after_finetuning_module = cfg_finetune.pop("keep_modules_after_finetuning_module", None)

    # Find part of config of module to finetune from
    pretrained_predictor.model.create_module_map()
    module_map_from_pretrained = pretrained_predictor.model._module_map

    if not any([module.startswith(finetuning_module) for module in module_map_from_pretrained.keys()]):
        # BUG FIX: the message was previously a plain string missing the f-prefix,
        # so the module name was never interpolated.
        raise ValueError(f"Unknown module {finetuning_module}")
    elif sub_module_from_pretrained is None:
        new_module_kwargs = deepcopy(cfg_arch[finetuning_module])
    else:
        new_module_kwargs = deepcopy(cfg_arch[finetuning_module][sub_module_from_pretrained])

    # Modify config according to desired finetuning architecture
    out_dim = (
        cfg_arch[finetuning_module].get("out_dim")
        if sub_module_from_pretrained is None
        else cfg_arch[finetuning_module][sub_module_from_pretrained].get("out_dim")
    )

    # Derive the depth from the hidden dims when it is not explicitly set
    if new_module_kwargs["depth"] is None:
        new_module_kwargs["depth"] = len(new_module_kwargs["hidden_dims"]) + 1

    upd_kwargs = {
        "out_dim": cfg_finetune.pop("new_out_dim", out_dim),
        "depth": new_module_kwargs["depth"]
        + cfg_finetune.get("added_depth", 0)
        - cfg_finetune.pop("drop_depth", 0),
    }

    new_last_activation = cfg_finetune.pop("new_last_activation", None)
    if new_last_activation is not None:
        upd_kwargs["last_activation"] = new_last_activation

    # Update config
    new_module_kwargs.update(upd_kwargs)

    if sub_module_from_pretrained is None:
        cfg_arch[finetuning_module] = new_module_kwargs
    else:
        cfg_arch[finetuning_module] = {new_sub_module: new_module_kwargs}

    # Remove modules of pretrained model after module to finetune from unless specified differently
    module_list = list(module_map_from_pretrained.keys())
    super_module_list = []
    for module in module_list:
        if module.split("-")[0] not in super_module_list:  # Only add each supermodule once
            super_module_list.append(module.split("-")[0])

    # Set configuration of modules after finetuning module to None
    cutoff_idx = (
        super_module_list.index(finetuning_module) + 1
    )  # Index of module after module to finetune from

    for module in super_module_list[cutoff_idx:]:
        cfg_arch[module] = None

    # If desired, we can keep specific modules after the finetuning module (specified in cfg/finetuning/keep_modules_after_finetuning_module)
    if keep_modules_after_finetuning_module is not None:
        for module_name, updates in keep_modules_after_finetuning_module.items():
            cfg_arch = update_cfg_arch_for_module(cfg_arch, cfg_arch_from_pretrained, module_name, updates)

    # Change architecture to FullGraphFinetuningNetwork
    cfg_arch["model_type"] = "FullGraphFinetuningNetwork"

    cfg["architecture"] = cfg_arch

    # Split the finetuning section into the kwargs used to overwrite the
    # pretrained model and the kwargs used to configure training.
    pretrained_overwriting_kwargs = deepcopy(cfg["finetuning"])
    drop_keys = [
        "task",
        "level",
        "finetuning_head",
        "unfreeze_pretrained_depth",
        "epoch_unfreeze_all",
    ]
    for key in drop_keys:
        pretrained_overwriting_kwargs.pop(key, None)

    finetuning_training_kwargs = deepcopy(cfg["finetuning"])
    drop_keys = ["task", "level", "pretrained_model", "sub_module_from_pretrained", "finetuning_head"]
    for key in drop_keys:
        finetuning_training_kwargs.pop(key, None)

    cfg["finetuning"].update(
        {"overwriting_kwargs": pretrained_overwriting_kwargs, "training_kwargs": finetuning_training_kwargs}
    )

    return cfg
def update_cfg_arch_for_module(
    cfg_arch: Dict[str, Any],
    cfg_arch_from_pretrained: Dict[str, Any],
    module_name: str,
    updates: Dict[str, Any],
):
    """
    Function to modify the key-word arguments of modules after the finetuning module if they are kept.

    Parameters:
        cfg_arch: Configuration of the architecture of the model used for finetuning
        cfg_arch_from_pretrained: Configuration of the architecture of the loaded pretrained model
        module_name: Module of loaded pretrained model; a submodule is addressed
            with the `"module-submodule"` syntax
        updates: Changes to apply to key-word arguments of selected module.
            May contain a `new_sub_module` key to rename the submodule.

    Returns:
        The updated `cfg_arch` (also modified in place).
    """
    # We need to distinguish between modules with & without submodules
    if "-" not in module_name:
        if cfg_arch[module_name] is None:
            cfg_arch[module_name] = {}

        cfg_arch_from_pretrained[module_name].update(updates)
        # BUG FIX: was `cfg_arch.update({module_name, cfg_arch_from_pretrained})`,
        # which builds a *set* literal and makes `dict.update` raise a TypeError.
        # Map the module name to its updated pretrained configuration instead.
        cfg_arch[module_name] = cfg_arch_from_pretrained[module_name]

    else:
        module_name, sub_module = module_name.split("-")
        # Optionally rename the submodule in the finetuning architecture
        new_sub_module = updates.pop("new_sub_module", sub_module)

        if cfg_arch[module_name] is None:
            cfg_arch[module_name] = {}

        cfg_arch_from_pretrained[module_name][sub_module].update(updates)
        cfg_arch[module_name].update({new_sub_module: cfg_arch_from_pretrained[module_name][sub_module]})

    return cfg_arch
================================================
FILE: graphium/hyper_param_search/__init__.py
================================================
from .results import extract_main_metric_for_hparam_search
HYPER_PARAM_SEARCH_CONFIG_KEY = "hyper_param_search"
================================================
FILE: graphium/hyper_param_search/results.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
_OBJECTIVE_KEY = "objective"


def extract_main_metric_for_hparam_search(results: dict, cfg: dict):
    """Processes the results in the context of a hyper-parameter search."""

    # Normalize the objective spec to a list of metric names
    objective_keys = cfg[_OBJECTIVE_KEY]
    if isinstance(objective_keys, str):
        objective_keys = [objective_keys]

    # Look up each requested metric in the results
    values = [results[key] for key in objective_keys]

    # A single objective is returned as a scalar rather than a 1-element list
    return values[0] if len(values) == 1 else values
================================================
FILE: graphium/ipu/README.md
================================================
The Graph Of Life Library.
## What is in this folder?
Code for IPU acceleration support:

- `ipu_dataloader.py`: code for handling dataloaders on the IPU
- `ipu_losses.py`: code for computing losses on the IPU
- `ipu_simple_lightning.py`: code for PyTorch Lightning support on the IPU
- `ipu_utils.py`: utility functions for the IPU
- `ipu_wrapper.py`: wrapper code for IPU support
================================================
FILE: graphium/ipu/__init__.py
================================================
================================================
FILE: graphium/ipu/ipu_dataloader.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Callable, Iterable, Optional, List, Tuple, Dict, Any, Union
from copy import deepcopy
from dataclasses import dataclass
import numpy as np
from loguru import logger
from torch import Tensor
import torch
from torch_geometric.data import Data, Batch, Dataset
from torch_geometric.transforms import BaseTransform
from graphium.data.utils import get_keys
from graphium.ipu.ipu_utils import import_poptorch
from graphium.utils.packing import (
fast_packing,
hybrid_packing,
get_pack_sizes,
node_to_pack_indices_mask,
estimate_max_pack_node_size,
)
@dataclass
class IPUDataloaderOptions:
    r"""
    This data class stores the arguments necessary to instantiate an IPU dataloader
    (see `create_ipu_dataloader`).

    Parameters:
        batch_size: Mini-batch size (number of graphs packed together).
        max_num_nodes: Maximum number of nodes in the batched padded graph.
            Mutually exclusive with `max_num_nodes_per_graph`.
        max_num_nodes_per_graph: Maximum number of nodes per graph; multiplied by
            `batch_size` to derive `max_num_nodes`. Mutually exclusive with `max_num_nodes`.
        max_num_edges: Maximum number of edges in the batched padded graph.
            Mutually exclusive with `max_num_edges_per_graph`.
        max_num_edges_per_graph: Maximum number of edges per graph; multiplied by
            `batch_size` to derive `max_num_edges`. Mutually exclusive with `max_num_edges`.
        mode: A `poptorch.DataLoaderMode`, or one of the strings
            "Sync", "Async" or "AsyncRebatched" (case-insensitive).
    """

    batch_size: int
    max_num_nodes: Optional[int] = None
    max_num_nodes_per_graph: Optional[int] = None
    max_num_edges: Optional[int] = None
    max_num_edges_per_graph: Optional[int] = None
    mode: "poptorch.DataLoaderMode" = "Sync"

    def set_kwargs(self):
        """
        Resolve the node/edge limits and the poptorch `mode` in place.

        Raises:
            ValueError: If neither of the node limits, or neither of the edge
                limits, is provided; or if `mode` is an unrecognized string.
        """
        # Get the maximum number of nodes
        if self.max_num_nodes is not None:
            assert (
                self.max_num_nodes_per_graph is None
            ), "Cannot use `max_num_nodes` and `max_num_nodes_per_graph` simultaneously"
        elif self.max_num_nodes_per_graph is not None:
            assert (
                self.max_num_nodes is None
            ), "Cannot use `max_num_nodes` and `max_num_nodes_per_graph` simultaneously"
            self.max_num_nodes = self.max_num_nodes_per_graph * self.batch_size
        else:
            raise ValueError("Must provide either `max_num_nodes` or `max_num_nodes_per_graph`")

        # Get the maximum number of edges
        if self.max_num_edges is not None:
            assert (
                self.max_num_edges_per_graph is None
            ), "Cannot use `max_num_edges` and `max_num_edges_per_graph` simultaneously"
        elif self.max_num_edges_per_graph is not None:
            assert (
                self.max_num_edges is None
            ), "Cannot use `max_num_edges` and `max_num_edges_per_graph` simultaneously"
            self.max_num_edges = self.max_num_edges_per_graph * self.batch_size
        else:
            # BUG FIX: this error message previously referred to the *node* options
            raise ValueError("Must provide either `max_num_edges` or `max_num_edges_per_graph`")

        # poptorch mode: translate string names into the poptorch enum
        poptorch = import_poptorch()
        if isinstance(self.mode, str):
            if self.mode.lower() == "sync":
                self.mode = poptorch.DataLoaderMode.Sync
            elif self.mode.lower() == "async":
                self.mode = poptorch.DataLoaderMode.Async
            elif self.mode.lower() == "asyncrebatched":
                self.mode = poptorch.DataLoaderMode.AsyncRebatched
            else:
                raise ValueError(f"`{self.mode}` not a valid parameter.")
class CombinedBatchingCollator:
    """
    Collator object that manages the combined batch size defined as:

        combined_batch_size = batch_size * device_iterations
                              * replication_factor * gradient_accumulation

    This is intended to be used in combination with the poptorch.DataLoader
    """

    def __init__(
        self,
        batch_size: int,
        max_num_nodes: int,
        max_num_edges: int,
        dataset_max_nodes_per_graph: int,
        dataset_max_edges_per_graph: int,
        collate_fn: Optional[Callable] = None,
    ):
        """
        Parameters:
            batch_size: mini batch size used by the model
            max_num_nodes: Maximum number of nodes in the batched padded graph
            max_num_edges: Maximum number of edges in the batched padded graph
            dataset_max_nodes_per_graph: Maximum number of nodes per graph in the full dataset
            dataset_max_edges_per_graph: Maximum number of edges per graph in the full dataset
            collate_fn: Function used to collate (or batch) the single data or graphs together.
                NOTE(review): effectively required — when `None`, no pack is collated or
                padded and `__call__` fails on the empty `all_batches`; confirm intent.
        """
        super().__init__()
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.max_num_nodes = max_num_nodes
        self.max_num_edges = max_num_edges
        self.dataset_max_nodes_per_graph = dataset_max_nodes_per_graph
        self.dataset_max_edges_per_graph = dataset_max_edges_per_graph

    def __call__(
        self, batch: List[Dict[str, Union[Data, Dict[str, Tensor]]]]
    ) -> Dict[str, Union[Batch, Dict[str, Tensor], Any]]:
        """
        Stack tensors, batch the pyg graphs, and pad each tensor to be same size.

        Parameters:
            batch: The batch of data, including pyg-graphs `Data` and labels `Dict[str, Tensor]` to be padded

        Returns:
            out_batch: A dictionary where the graphs are batched and the labels or other Tensors are stacked
        """
        # Sort the batch such that large graphs are paired with small graphs
        num_nodes = [b["features"].num_nodes for b in batch]
        packed_indices = hybrid_packing(num_nodes, batch_size=self.batch_size)
        packs = [[batch[idx] for idx in pack] for pack in packed_indices]

        # Loop all mini-batches within the global batch: collate each pack,
        # then pad its features and labels to the fixed maximum sizes.
        all_batches = []
        for pack in packs:
            if self.collate_fn is not None:  # idiomatic identity check (was `!= None`)
                local_batch = self.collate_fn(pack)
                transform = Pad(
                    max_num_nodes=self.max_num_nodes,
                    max_num_edges=self.max_num_edges,
                    dataset_max_nodes_per_graph=self.dataset_max_nodes_per_graph,
                    dataset_max_edges_per_graph=self.dataset_max_edges_per_graph,
                )

                local_batch["features"] = transform(local_batch["features"])
                local_batch["labels"] = transform(local_batch["labels"])

                all_batches.append(local_batch)

        out_batch = {}

        # Stack tensors in the first dimension to allow IPUs to differentiate between local and global graph
        all_keys = get_keys(all_batches[0]["labels"])
        out_batch["labels"] = {
            key: torch.stack([this_batch["labels"][key] for this_batch in all_batches], 0) for key in all_keys
        }
        out_graphs = [this_batch["features"] for this_batch in all_batches]
        stacked_features = deepcopy(out_graphs[0])
        for key, val in out_graphs[0].items():
            if isinstance(val, torch.Tensor):
                stacked_features[key] = torch.stack([this_graph[key] for this_graph in out_graphs], dim=0)

        out_batch["features"] = stacked_features

        # Pass through any remaining keys as plain lists, one entry per mini-batch
        for key in all_batches[0].keys():
            if key not in ("features", "labels"):
                out_batch[key] = [this_batch[key] for this_batch in all_batches]

        # Downcast int64 tensors inside pyg Batches to int32
        # (presumably for IPU efficiency/compatibility — confirm)
        for data_key, data_val in out_batch.items():
            if isinstance(data_val, Batch):
                for sub_key, sub_val in data_val.items():
                    if isinstance(sub_val, Tensor) and sub_val.dtype == torch.int64:
                        out_batch[data_key][sub_key] = sub_val.to(torch.int32)

        return out_batch
def create_ipu_dataloader(
    dataset: Dataset,
    ipu_dataloader_options: IPUDataloaderOptions,
    ipu_options: Optional["poptorch.Options"] = None,
    batch_size: Optional[int] = 1,
    collate_fn=None,
    num_workers: Optional[int] = 0,
    **kwargs,
) -> "poptorch.DataLoader":
    """
    Creates a poptorch.DataLoader for graph datasets
    Applies the mini-batching method of concatenating multiple graphs into a
    single graph with multiple disconnected subgraphs. See:
    https://pytorch-geometric.readthedocs.io/en/2.0.2/notes/batching.html

    Parameters:
        dataset: The torch_geometric.data.Dataset instance from which to
            load the graph examples for the IPU.
        ipu_dataloader_options: The options to initialize the Dataloader for IPU
        ipu_options: The poptorch.Options used by the
            poptorch.DataLoader. Will use the default options if not provided.
        batch_size: How many graph examples to load in each batch
            (default: 1).
        collate_fn: The function used to collate batches
        num_workers: Number of dataloader worker processes; capped at the
            number of combined batches in the dataset (default: 0).
        **kwargs (optional): Additional arguments of :class:`poptorch.DataLoader`.

    Returns:
        The dataloader
    """
    poptorch = import_poptorch()

    if ipu_options is None:
        # Create IPU default options
        ipu_options = poptorch.Options()

    # Define the collater function
    collater = CombinedBatchingCollator(
        batch_size,
        collate_fn=collate_fn,
        max_num_nodes=ipu_dataloader_options.max_num_nodes,
        max_num_edges=ipu_dataloader_options.max_num_edges,
        dataset_max_nodes_per_graph=dataset.max_num_nodes_per_graph,
        dataset_max_edges_per_graph=dataset.max_num_edges_per_graph,
    )

    # Get the global batch size
    num_nodes = np.asarray(dataset.num_nodes_list)
    accum = ipu_options.Training.gradient_accumulation
    # NOTE(review): `_values` is a private poptorch attribute and may break
    # across poptorch versions — confirm there is no public accessor.
    repli = ipu_options._values["replication_factor"]
    device_iter = ipu_options._values["device_iterations"]
    combined_batch_size = batch_size * accum * repli * device_iter
    num_batches = len(dataset) // combined_batch_size
    num_workers = min(num_batches, num_workers)
    # Each async worker gets an equal share of the batches; fall back to a
    # buffer of 3 when no workers are used.
    buffer_size = num_batches // num_workers if num_workers > 0 else None
    buffer_size = 3 if buffer_size is None else buffer_size
    async_options = {
        "sharing_strategy": poptorch.SharingStrategy.ForkServer,
        "early_preload": True,
        "buffer_size": buffer_size,
        "load_indefinitely": True,
        "miss_sleep_time_in_ms": 0,
    }

    # Estimate the packing size needed. Repeated 4 times keeping the max —
    # presumably because the estimate is stochastic; see
    # `estimate_max_pack_node_size` to confirm.
    max_pack_size, max_pack_size_per_graph = 0, 0
    for _ in range(4):
        this_max_pack_size, this_max_pack_size_per_graph = estimate_max_pack_node_size(
            num_nodes=num_nodes,
            batch_size=batch_size,
            combined_batch_size=combined_batch_size,
        )
        max_pack_size = max(max_pack_size, this_max_pack_size)
        max_pack_size_per_graph = max(max_pack_size_per_graph, this_max_pack_size_per_graph)

    max_num_nodes = collater.max_num_nodes

    # Log the estimated pack size, with warnings if too big or too small
    logger.info(
        f"Estimating pack max_pack_size={max_pack_size} or max_pack_size_per_graph={max_pack_size_per_graph}"
    )
    logger.info(f"Provided `max_num_nodes={max_num_nodes}`")
    if max_pack_size > max_num_nodes - 10:
        logger.warning(
            f"The value of `max_num_nodes={max_num_nodes}` seems to be insufficient compared to `max_pack_size={max_pack_size}` and will likely crash"
        )
    elif max_pack_size < max_num_nodes - 20:
        logger.warning(
            f"The value of `max_num_nodes={max_num_nodes}` seems to be large compared to `max_pack_size={max_pack_size}` and will likely waste memory"
        )

    return poptorch.DataLoader(
        options=deepcopy(ipu_options),
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collater,
        async_options=async_options,
        **kwargs,
    )
class Pad(BaseTransform):
    """
    Data transform that applies padding to enforce consistent tensor shapes.
    """

    def __init__(
        self,
        max_num_nodes: int,
        dataset_max_nodes_per_graph,
        dataset_max_edges_per_graph,
        max_num_edges: Optional[int] = None,
        node_value: float = 0,
        edge_value: float = 0,
    ):
        """
        Parameters:
            max_num_nodes: The maximum number of nodes for the total padded graph
            dataset_max_nodes_per_graph: the maximum number of nodes per graph in the dataset
            dataset_max_edges_per_graph: the maximum number of edges per graph in the dataset
            max_num_edges: The maximum number of edges for the total padded graph.
                If not provided (or falsy), defaults to a fully-connected graph
                with `max_num_nodes` nodes.
            node_value: Value to add to the node padding
            edge_value: Value to add to the edge padding
        """
        super().__init__()
        self.max_num_nodes = max_num_nodes
        self.dataset_max_nodes_per_graph = dataset_max_nodes_per_graph
        self.dataset_max_edges_per_graph = dataset_max_edges_per_graph

        if max_num_edges:
            self.max_num_edges = max_num_edges
        else:
            # Assume fully connected graph
            self.max_num_edges = max_num_nodes * (max_num_nodes - 1)

        self.node_value = node_value
        self.edge_value = edge_value

    def validate(self, data):
        """
        Validates that the input graph does not exceed the constraints that:

          * the number of nodes must be <= max_num_nodes
          * the number of edges must be <= max_num_edges

        Returns:
            Tuple containing the number nodes and the number of edges
        """
        num_nodes = data.num_nodes
        num_edges = data.num_edges

        assert num_nodes <= self.max_num_nodes, (
            f"Too many nodes. Graph has {num_nodes} nodes " f"and max_num_nodes is {self.max_num_nodes}."
        )

        assert num_edges <= self.max_num_edges, (
            f"Too many edges. Graph has {num_edges} edges defined "
            f"and max_num_edges is {self.max_num_edges}."
        )

        return num_nodes, num_edges

    def __call__(self, batch: Batch) -> Batch:
        # Delegates to `_call`; same entry point as `forward`
        return self._call(batch)

    def forward(self, batch: Batch) -> Batch:
        # Delegates to `_call`; same entry point as `__call__`
        return self._call(batch)

    def _call(self, batch: Batch) -> Batch:
        """
        Pad the batch with a fake graphs that has the desired
        number of nodes and edges.
        """
        num_nodes, num_edges = self.validate(batch)
        # Size of the single fake graph needed to reach the fixed maxima
        num_pad_nodes = self.max_num_nodes - num_nodes
        num_pad_edges = self.max_num_edges - num_edges

        # Create a copy to update with padded features
        new_batch = deepcopy(batch)

        real_graphs = new_batch.to_data_list()

        # Mark every real graph/node/edge so the padding can be masked out later
        for g in real_graphs:
            g.graph_is_true = torch.tensor([1], dtype=bool)
            g.node_is_true = torch.full([g.num_nodes], True, dtype=bool)
            g.edge_is_true = torch.full([g.num_edges], True, dtype=bool)

        # create fake graph with the needed # of nodes and edges
        fake = Data()
        fake.num_nodes = num_pad_nodes
        fake.num_edges = num_pad_edges
        fake.graph_is_true = torch.tensor([False], dtype=bool)
        fake.node_is_true = torch.full([num_pad_nodes], False, dtype=bool)
        fake.edge_is_true = torch.full([num_pad_edges], False, dtype=bool)

        # Mirror every tensor attribute of the first real graph onto the fake
        # graph, padded to the missing node/edge counts
        for key, value in real_graphs[0]:
            if not torch.is_tensor(value):
                continue
            if key == "graph_is_true" or key == "node_is_true" or key == "edge_is_true":
                continue
            # Which dimension pyg concatenates this attribute along
            dim = real_graphs[0].__cat_dim__(key, value)
            pad_shape = list(value.shape)

            if batch.is_node_attr(key):
                pad_shape[dim] = num_pad_nodes
                pad_value = self.node_value
            elif batch.is_edge_attr(key):
                pad_shape[dim] = num_pad_edges
                if key == "edge_index":
                    # Padding edges are self-loops on the first padding node
                    pad_value = 0
                else:
                    pad_value = self.edge_value
            # identify graph attributes, pad nan label for the fake graph
            elif key.startswith("graph_"):
                num_pad_graphs = 1  # we pad with one big fake graph
                pad_shape[dim] = num_pad_graphs
                pad_value = float("nan")
            else:
                continue

            pad_value = value.new_full(pad_shape, pad_value)
            fake[key] = torch.cat([pad_value], dim=dim)
        real_graphs.append(fake)
        new_batch = Batch.from_data_list(real_graphs)

        # Force the padded total so downstream code sees a fixed node count
        if "num_nodes" in new_batch:
            new_batch.num_nodes = self.max_num_nodes

        return new_batch

    def __repr__(self) -> str:
        """String representation listing the padding limits and fill values."""
        s = f"{self.__class__.__name__}("
        s += f"max_num_nodes={self.max_num_nodes}, "
        s += f"max_num_edges={self.max_num_edges}, "
        s += f"node_value={self.node_value}, "
        s += f"edge_value={self.edge_value})"
        return s
================================================
FILE: graphium/ipu/ipu_losses.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import torch
from torch import Tensor
from torch.nn import BCELoss, BCEWithLogitsLoss, MSELoss, L1Loss
from torch._C import _infer_size
from loguru import logger
from graphium.trainer.losses import HybridCELoss
class BCEWithLogitsLossIPU(BCEWithLogitsLoss):
    """
    NaN-tolerant drop-in for `torch.nn.BCEWithLogitsLoss`: NaN targets are
    masked out through a zero entry in the `weight` tensor and the mean is
    rescaled to only count real targets. Tensor shapes never change, which
    keeps the loss compatible with compilation on IPUs.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Work on a copy of the targets, matched to the input dtype
        target = target.clone().to(input.dtype)

        # Expand the configured element weights, or default to all-ones
        saved_weight = None
        if self.weight is None:
            masked_weight = torch.ones(target.shape, dtype=input.dtype, device=input.device)
        else:
            saved_weight = self.weight.clone()
            expanded_shape = _infer_size(target.size(), self.weight.size())
            masked_weight = self.weight.expand(expanded_shape).clone()

        # Replace each NaN target by the label (0/1) closest to the prediction
        # so the per-element loss stays finite, then zero out its weight.
        nan_mask = target.isnan()
        target[(input < 0.5) & nan_mask] = 0.0
        target[(input >= 0.5) & nan_mask] = 1.0
        masked_weight[nan_mask] = 0.0

        # Temporarily install the masked weight, compute, then restore it
        self.weight = masked_weight
        loss = super().forward(input, target)
        self.weight = saved_weight

        # Rescale the weighted mean so it averages over real targets only;
        # collapse to exactly 0 when every target is NaN.
        n_real = (~nan_mask).sum()
        has_real = torch.where(n_real > 0, 1, 0)
        guard = torch.where(n_real > 0, 0, 1)  # avoids division by zero
        return has_real * loss * nan_mask.numel() / (n_real + guard)
class BCELossIPU(BCELoss):
    """
    NaN-tolerant drop-in for `torch.nn.BCELoss`: NaN targets are masked out
    through a zero entry in the `weight` tensor and the mean is rescaled to
    only count real targets. Tensor shapes never change, which keeps the loss
    compatible with compilation on IPUs.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Work on a copy of the targets, matched to the input dtype
        target = target.clone().to(input.dtype)

        # Expand the configured element weights, or default to all-ones
        saved_weight = None
        if self.weight is None:
            masked_weight = torch.ones(target.shape, dtype=input.dtype, device=input.device)
        else:
            saved_weight = self.weight.clone()
            masked_weight = self.weight.expand(_infer_size(target.size(), self.weight.size())).clone()

        # Replace each NaN target by the label (0/1) closest to the predicted
        # probability so the per-element loss stays finite, then zero its weight.
        nan_mask = target.isnan()
        target[(input < 0.5) & nan_mask] = 0.0
        target[(input >= 0.5) & nan_mask] = 1.0
        masked_weight[nan_mask] = 0.0

        # Temporarily install the masked weight, compute, then restore it
        self.weight = masked_weight
        loss = super().forward(input, target)
        self.weight = saved_weight

        # Rescale the weighted mean so it averages over real targets only;
        # collapse to exactly 0 when every target is NaN.
        n_real = (~nan_mask).sum()
        has_real = torch.where(n_real > 0, 1, 0)
        guard = torch.where(n_real > 0, 0, 1)  # avoids division by zero
        return has_real * loss * nan_mask.numel() / (n_real + guard)
class MSELossIPU(MSELoss):
    """
    NaN-tolerant drop-in for `torch.nn.MSELoss`: entries with a NaN target
    contribute zero to the loss, and the mean is rescaled to only count real
    targets. Tensor shapes never change, which keeps the loss compatible with
    compilation on IPUs.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Operate on copies so the caller's tensors stay untouched
        target = target.clone().to(input.dtype)
        input = input.clone()

        # Where the target is NaN, zero both sides so the squared error
        # vanishes without changing the tensor shape.
        nan_mask = target.isnan()
        input[nan_mask] = 0.0
        target[nan_mask] = 0.0

        raw_loss = super().forward(input, target)

        # Rescale the mean so it averages over real targets only;
        # collapse to exactly 0 when every target is NaN.
        n_real = (~nan_mask).sum()
        has_real = torch.where(n_real > 0, 1, 0)
        guard = torch.where(n_real > 0, 0, 1)  # avoids division by zero
        return has_real * raw_loss * nan_mask.numel() / (n_real + guard)
class L1LossIPU(L1Loss):
    """
    NaN-tolerant variant of `torch.nn.L1Loss`.

    Entries whose target is NaN are given identical values in both `input` and
    `target` (zero), so their absolute error is zero. Shapes are never changed,
    which keeps the loss compatible with compilation and IPU execution. The mean
    is rescaled afterwards so it is taken over the real targets only.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Parameters:
            input: predicted values.
            target: ground-truth values, possibly containing NaNs.

        Returns:
            Scalar MAE over the non-NaN targets (0 if all targets are NaN).
        """
        # Work on copies so the caller's tensors are left untouched.
        target = target.clone().to(input.dtype)
        input = input.clone()
        # Zero out both tensors wherever the target is missing.
        missing = target.isnan()
        input[missing] = 0.0
        target[missing] = 0.0
        # Mean over every element, then rescale to a mean over real targets.
        loss = super().forward(input, target)
        n_valid = (~missing).sum()
        # Branch-free guard against the all-NaN case (loss -> 0, no div-by-zero).
        has_valid = torch.where(n_valid > 0, 1, 0)
        pad = torch.where(n_valid > 0, 0, 1)
        return has_valid * loss * missing.numel() / (n_valid + pad)
class HybridCELossIPU(HybridCELoss):
    """
    IPU-compatible wrapper around `HybridCELoss` that forwards a NaN mask to the
    parent loss instead of reshaping the tensors, so compiled/IPU execution works
    with static shapes.
    """

    def __init__(
        self,
        n_brackets,
        alpha: float = 0.5,
    ) -> None:
        """
        Parameters:
            n_brackets: the number of brackets that will be used to group the regression targets.
                Expected to have the same size as the number of classes in the transformed regression task.
            alpha: weighting between the loss components, passed through to `HybridCELoss`.
        """
        super().__init__(n_brackets=n_brackets, alpha=alpha)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Parameters:
            input: (batch_size x n_classes) tensor of logits predicted for each bracket.
            target: (batch_size) or (batch_size, 1) tensor of target brackets in {0, 1, ..., self.n_brackets}.

        Returns:
            The loss computed by the parent class with NaN targets masked out.
        """
        # Copies keep the caller's tensors untouched; targets match the input dtype.
        target = target.clone().to(input.dtype)
        input = input.clone()
        # Locate the missing targets; the parent loss uses this mask to ignore them.
        missing = target.isnan()
        return super().forward(input, target, missing)
================================================
FILE: graphium/ipu/ipu_metrics.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Optional, Tuple, Sequence, Literal
import torch
from torch import BoolTensor, IntTensor, Tensor
from torchmetrics.functional import auroc, average_precision, pearson_corrcoef, r2_score
from torchmetrics.utilities.checks import _input_squeeze
from torchmetrics.functional.classification.accuracy import (
_mode,
_check_subset_validity,
_accuracy_compute,
_accuracy_update,
)
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
from torchmetrics.functional.classification.f_beta import _fbeta_compute
from torchmetrics.functional import mean_squared_error, mean_absolute_error
from torchmetrics.utilities.checks import _input_squeeze
from torchmetrics.utilities.enums import AverageMethod
from graphium.utils.tensor import nan_mean
from graphium.ipu.ipu_utils import import_poptorch
def auroc_ipu(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
):
    """
    NaN-tolerant wrapper around `torchmetrics.functional.auroc`.

    NaN targets are set to 0 in both `preds` and `target` and given a sample
    weight of 0, so they do not influence the score while tensor shapes stay
    fixed (required for compilation and IPU execution).
    """
    # Copies keep the caller's tensors untouched.
    preds = preds.clone()
    target = target.clone()
    # Neutralize the missing targets in both tensors.
    nans = target.isnan()
    preds[nans] = 0.0
    target[nans] = 0.0
    # Default to unit weights when none are provided, then zero the NaN rows.
    if sample_weights is None:
        sample_weights = torch.ones(target.shape[0], dtype=preds.dtype, device=preds.device)
    sample_weights[nans] = 0.0
    # Delegate to torchmetrics on the cleaned tensors.
    return auroc(
        preds=preds,
        target=target.to(int),
        num_classes=num_classes,
        task=task,
        pos_label=pos_label,
        average=average,
        max_fpr=max_fpr,
        sample_weights=sample_weights,
    )
def average_precision_ipu(
    preds: Tensor,
    target: Tensor,
    num_classes: Optional[int] = None,
    task: Optional[Literal["binary", "multiclass", "multilabel"]] = None,
    ignore_index: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = "macro",
    sample_weights: Optional[Sequence] = None,
):
    """
    NaN-tolerant wrapper around `torchmetrics.functional.average_precision`.

    NaN targets are set to 0 in both `preds` and `target`, which is harmless here
    because average precision is not sensitive to true negatives. Shapes are
    never changed, keeping the metric compatible with compilation and IPUs.

    Note: `sample_weights` is accepted for API compatibility but is not forwarded,
    since torchmetrics >= 0.10 no longer supports it for this metric.
    """
    # Copies keep the caller's tensors untouched.
    preds = preds.clone()
    target = target.clone()
    # Neutralize the missing targets in both tensors.
    nans = target.isnan()
    preds[nans] = 0.0
    target[nans] = 0.0
    # Delegate to torchmetrics on the cleaned tensors.
    return average_precision(
        preds=preds,
        target=target,
        num_classes=num_classes,
        task=task,
        ignore_index=ignore_index,
        pos_label=pos_label,
        average=average,
    )
def precision_ipu(
    preds: Tensor,
    target: Tensor,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
):
    """
    NaN-tolerant version of `torchmetrics.functional.precision`.

    NaN targets are folded into the confusion matrix by `get_confusion_matrix`
    (and subtracted back out), so tensor shapes never change — a requirement for
    compilation and IPU execution.
    """
    # Build the NaN-corrected confusion matrix; the mode is not needed here.
    (tp, fp, tn, fn), _ = get_confusion_matrix(
        preds=preds,
        target=target,
        average=average,
        mdmc_average=mdmc_average,
        threshold=threshold,
        top_k=top_k,
        subset_accuracy=False,
        num_classes=num_classes,
        multiclass=multiclass,
        ignore_index=ignore_index,
    )
    # Reduce with the stock torchmetrics compute step.
    return _precision_compute(tp, fp, fn, average, mdmc_average)
def recall_ipu(
    preds: Tensor,
    target: Tensor,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
):
    """
    NaN-tolerant version of `torchmetrics.functional.recall`.

    NaN targets are folded into the confusion matrix by `get_confusion_matrix`
    (and subtracted back out), so tensor shapes never change — a requirement for
    compilation and IPU execution.
    """
    # Build the NaN-corrected confusion matrix; the mode is not needed here.
    (tp, fp, tn, fn), _ = get_confusion_matrix(
        preds=preds,
        target=target,
        average=average,
        mdmc_average=mdmc_average,
        threshold=threshold,
        top_k=top_k,
        num_classes=num_classes,
        multiclass=multiclass,
        ignore_index=ignore_index,
    )
    # Reduce with the stock torchmetrics compute step.
    return _recall_compute(tp, fp, fn, average, mdmc_average)
def accuracy_ipu(
    preds: Tensor,
    target: Tensor,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = "global",
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    subset_accuracy: bool = False,
    num_classes: Optional[int] = None,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tensor:
    """
    NaN-tolerant version of `torchmetrics.functional.accuracy`.

    NaN targets are folded into the confusion matrix by `get_confusion_matrix`
    (and subtracted back out), so tensor shapes never change — a requirement for
    compilation and IPU execution.

    Args:
        preds: Predictions from model (probabilities, logits or labels).
        target: Ground truth labels, possibly containing NaNs.
        average: Reduction applied across classes/samples. One of ``'micro'``
            (default, global over all samples and classes), ``'macro'`` (per-class
            then unweighted mean), ``'weighted'`` (per-class then support-weighted
            mean), ``'none'``/``None`` (per-class scores returned as-is) or
            ``'samples'`` (per-sample then mean).
        mdmc_average: How multi-dimensional multi-class inputs are handled on top
            of ``average``: ``None``, ``'samplewise'`` (statistics per sample on the
            ``N`` axis, then averaged) or ``'global'`` (extra dimensions flattened
            into the sample axis).
        threshold: Probability/logit threshold for binarizing predictions in the
            binary and multi-label cases (0.5 assumes probabilities).
        top_k: Number of highest-probability predictions considered correct for
            (multi-dimensional) multi-class inputs; ``None`` means 1.
        subset_accuracy: Whether to compute subset accuracy for multi-label and
            multi-dimensional multi-class inputs (currently raises
            ``NotImplementedError`` inside `get_confusion_matrix` when applicable).
        num_classes: Number of classes; required for ``'macro'``, ``'weighted'``
            and ``None`` averaging.
        multiclass: Force inputs to be treated as a different classification type
            than they appear to be.
        ignore_index: Target class index excluded from the score.

    Raises:
        ValueError: on invalid ``average``/``mdmc_average``/``top_k``/``ignore_index``
            combinations (validated inside `get_confusion_matrix`).

    Returns:
        The accuracy score computed from the NaN-corrected confusion matrix.
    """
    # Confusion-matrix statistics with the NaN correction already applied.
    stats, mode = get_confusion_matrix(
        preds=preds,
        target=target,
        average=average,
        mdmc_average=mdmc_average,
        threshold=threshold,
        top_k=top_k,
        subset_accuracy=subset_accuracy,
        num_classes=num_classes,
        multiclass=multiclass,
        ignore_index=ignore_index,
    )
    tp, fp, tn, fn = stats
    # Reduce with the stock torchmetrics compute step.
    return _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode)
def get_confusion_matrix(
    preds: Tensor,
    target: Tensor,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = "global",
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    subset_accuracy: bool = False,
    num_classes: Optional[int] = None,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tuple[Tensor], Tensor]:
    """
    Calculates the confusion matrix according to the specified average method,
    while tolerating NaN targets without changing tensor shapes (required for
    compilation and IPU execution).

    NaN targets are temporarily assigned to class 0 (with a matching prediction
    so they count as "true positives" of class 0), then the resulting spurious
    counts are subtracted from ``tp``/``tn`` afterwards.

    Args:
        preds: Predictions from model (probabilities, logits or labels).
        target: Ground truth labels, possibly containing NaNs.
        average: Reduction that will later be applied. One of ``'micro'``,
            ``'macro'``, ``'weighted'``, ``'samples'``, ``'none'`` or ``None``.
        mdmc_average: Handling of multi-dimensional multi-class inputs:
            ``None``, ``'samplewise'`` or ``'global'``.
        threshold: Probability/logit threshold for binarizing predictions in the
            binary and multi-label cases (0.5 assumes probabilities).
        top_k: Number of highest-probability predictions considered correct for
            (multi-dimensional) multi-class inputs; ``None`` means 1.
        subset_accuracy: Not supported; raises ``NotImplementedError`` when it
            would apply.
        num_classes: Number of classes; required for ``'macro'``, ``'weighted'``
            and ``None`` averaging.
        multiclass: Force inputs to be treated as a different classification type
            than they appear to be.
        ignore_index: Target class index excluded from the score.

    Raises:
        ValueError: on invalid ``average``/``mdmc_average``/``top_k``/``ignore_index``.
        NotImplementedError: when ``subset_accuracy`` applies to the input mode.

    Returns:
        ``((tp, fp, tn, fn), mode)`` — the NaN-corrected statistics and the
        torchmetrics data mode.
    """
    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
    if average not in allowed_average:
        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")
    allowed_mdmc_average = [None, "samplewise", "global"]
    if mdmc_average not in allowed_mdmc_average:
        raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")
    if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1):
        raise ValueError(
            f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes"
        )
    if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
        raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

    #### ADDED ####
    # Work on copies: the NaN masking below must not mutate the caller's tensors
    # (every other *_ipu helper in this module clones before masking).
    target = target.clone()
    preds = preds.clone()
    # Put all the NaNs as the 0-class
    nans = torch.isnan(target)
    target[nans] = 0
    preds[nans] = 0
    if (preds.ndim > 1) and (preds.shape[1] > 1):
        # For multi-class/multi-label scores, force class 0 to be predicted on NaN rows.
        preds[nans, 0] = 1
    target = target.to(int)
    #### END ADDED ####

    preds, target = _input_squeeze(preds, target)
    mode = _mode(preds, target, threshold, top_k, num_classes, multiclass, ignore_index)
    reduce = "macro" if average in ["weighted", "none", None] else average
    if subset_accuracy and _check_subset_validity(mode):
        # correct, total = _subset_accuracy_update(preds, target, threshold, top_k, ignore_index)
        # return _subset_accuracy_compute(correct, total)
        raise NotImplementedError("subset_accuracy not implemented")
    tp, fp, tn, fn = _accuracy_update(
        preds, target, reduce, mdmc_average, threshold, num_classes, top_k, multiclass, ignore_index, mode
    )

    #### ADDED ####
    # Undo the counts contributed by the NaN rows that were mapped to class 0.
    num_nans = nans.sum(0)
    if tp.numel() > 1:
        # Per-class statistics: NaN rows inflated class-0 TPs and other classes' TNs.
        tp[0] = tp[0] - num_nans
        tn[1:] = tn[1:] - num_nans
    else:
        tn = tn - num_nans
        if (preds.ndim > 1) and (preds.shape[1] > 1):
            tp = tp - num_nans
    #### END ADDED ####
    return (tp, fp, tn, fn), mode
class NaNTensor(Tensor):
    """
    Class to create and manage a NaN tensor along its properties.

    The goal of the class is to override the regular tensor such that the basic
    operations (sum, mean, max, etc) ignore the NaNs in the input.
    It also supports NaNs in integer tensors (as the lowest integer possible).
    """

    @property
    def get_nans(self) -> BoolTensor:
        """
        Gets the boolean Tensor containing the location of NaNs.
        In the case of an integer tensor, this returns where the tensor is equal to its minimal value.
        In the case of a boolean (or other unsigned) tensor, this returns a Tensor filled with `False`.
        """
        if self.is_floating_point():
            return self.isnan()
        elif self.is_signed():
            # Integer tensors cannot hold NaN; the minimal value is used as the sentinel.
            return self == torch.iinfo(self.dtype).min
        else:
            return torch.zeros(self.shape, device=self.device, dtype=bool)

    def sum(self, *args, **kwargs) -> Tensor:
        """
        Overloads the traditional sum to ignore the NaNs.

        Works by casting to float (so integer NaN-sentinels can become real NaNs),
        then using `nansum`, and casting back to the original (or int64) dtype.
        """
        tensor = self.to(float)
        tensor[self.get_nans] = float("nan")
        if self.is_floating_point():
            dtype = self.dtype
        else:
            dtype = torch.int64
        return tensor.nansum(*args, **kwargs).to(dtype)

    def mean(self, *args, **kwargs) -> Tensor:
        """
        Overloads the traditional mean to ignore the NaNs.
        """
        tensor = self.to(float)
        tensor[self.get_nans] = float("nan")
        return nan_mean(tensor, *args, **kwargs).to(self.dtype)

    def numel(self) -> int:
        """
        Returns the number of non-NaN elements.

        Note: this deliberately overrides the usual `numel` semantics; callers
        relying on the total element count should not use `NaNTensor`.
        """
        # Calls the plain Tensor.sum on the inverted NaN mask (bypassing this
        # class's overridden `sum`).
        return super(NaNTensor, ~self.get_nans).sum()

    def min(self, *args, **kwargs) -> Tensor:
        """
        Returns the min value of a tensor without NaNs.
        """
        tensor = self
        # Dropping NaNs changes the shape here; this method is not IPU-static.
        tensor = tensor[~self.get_nans]
        return super(NaNTensor, tensor).min(*args, **kwargs)

    def max(self, *args, **kwargs) -> Tensor:
        """
        Returns the max value of a tensor without NaNs.
        """
        tensor = self
        # Dropping NaNs changes the shape here; this method is not IPU-static.
        tensor = tensor[~self.get_nans]
        return super(NaNTensor, tensor).max(*args, **kwargs)

    def argsort(self, dim=-1, descending=False) -> IntTensor:
        """
        Return the indices that sort the tensor, while putting all the NaNs to the end of the sorting.

        NOTE(review): this writes +/-inf into `tensor`, which aliases `self`, so the
        NaN entries of this tensor are overwritten in place — confirm callers expect that.
        """
        tensor = self
        if descending:
            tensor[tensor.get_nans] = float("-inf")
        else:
            tensor[tensor.get_nans] = float("inf")
        return super(NaNTensor, tensor).argsort(dim=dim, descending=descending)

    def size(self, dim) -> Tensor:
        """
        Instead of returning the size, return the number of non-NaN elements in
        a specific dimension. Useful for the `r2_score` metric.
        """
        return (~self.get_nans).sum(dim=dim)

    def __lt__(self, other) -> Tensor:
        """
        Stupid fix that allows the code to work with `r2_score`,
        since it requires the size to be > 2. But since `self.size` now returns
        a Tensor instead of a value, we check that all elements are > 2.
        """
        if (not isinstance(other, Tensor)) and (other == 2):
            return super().__lt__(other).all()
        else:
            return super().__lt__(other)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        This __torch_function__ implementation wraps subclasses such that
        methods called on subclasses return a subclass instance instead of
        a ``torch.Tensor`` instance.

        One corollary to this is that you need coverage for torch.Tensor
        methods if implementing __torch_function__ for subclasses.

        Affects the call torch.sum() as to behave the same way as NaNTensor.sum().

        We recommend always calling ``super().__torch_function__`` as the base
        case when doing the above.
        While not mandatory, we recommend making `__torch_function__` a classmethod.
        """
        if func.__name__ == "sum":
            # Redirect torch.sum(nan_tensor, ...) to the NaN-aware sum above.
            kwargs = {} if kwargs is None else kwargs
            return args[0].sum(*args[1:], **kwargs)
        else:
            return super().__torch_function__(func, types, args=args, kwargs=kwargs)
def pearson_ipu(preds, target):
    """
    Pearson correlation coefficient that tolerates NaNs in the target.

    Both tensors are wrapped in `NaNTensor` and the predictions matching a NaN
    target are themselves set to NaN, so the NaN-aware reductions of `NaNTensor`
    exclude those entries. Shapes are never changed, which keeps the metric
    compatible with IPU execution.

    Args:
        preds: estimated scores
        target: ground truth scores
    """
    preds = NaNTensor(preds)
    target = NaNTensor(target)
    # Mirror the target's NaNs onto the predictions so both sides skip them.
    preds[target.get_nans] = float("nan")
    return Tensor(pearson_corrcoef(preds, target.to(preds.dtype)))
def spearman_ipu(preds, target):
    """Computes spearman rank correlation coefficient.
    Handles NaNs in the target without reshaping tensors in order to work on IPU.

    Args:
        preds: estimated scores
        target: ground truth scores

    Returns:
        Tensor with the Spearman correlation over the non-NaN entries.
    """
    # Clone first: the masking below must not mutate the caller's tensors
    # (every other *_ipu helper in this module clones before masking).
    preds = preds.clone()
    target = target.clone()
    nans = target.isnan()
    dtype = preds.dtype
    # Map NaN entries to +inf so they end up at the tail of the sort order.
    preds[nans] = float("inf")
    target[nans] = float("inf")
    preds_sort = _rank_data(preds).to(dtype=dtype)
    target_sort = _rank_data(target).to(dtype=dtype)
    # Restore NaNs in the ranked targets so pearson_ipu ignores those entries.
    target_sort[nans] = float("nan")
    spearman = pearson_ipu(preds_sort, target_sort)
    return Tensor(spearman)
def _rank_data(data: Tensor) -> Tensor:
"""Calculate the rank for each element of a tensor.
The rank refers to the indices of an element in the corresponding sorted tensor (starting from 1).
Duplicates of the same value will be assigned the mean of their rank.
Adopted from `Rank of element tensor`_
"""
n = data.numel()
rank = torch.empty_like(data)
idx = data.argsort()
rank[idx] = torch.arange(1, n + 1, dtype=data.dtype, device=data.device)
# TODO: Repeats not yet supported
# repeats = _find_repeats(data)
# for r in repeats:
# condition = data == r
# rank[condition] = rank[condition].mean()
return rank
def r2_score_ipu(preds, target, *args, **kwargs) -> Tensor:
    r"""
    Computes r2 score also known as `R2 Score_Coefficient Determination`_:

    .. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

    where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
    :math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
    adjusted r2 score given by

    .. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

    where the parameter :math:`k` (the number of independent regressors) should
    be provided as the ``adjusted`` argument.
    Handles NaNs without reshaping tensors in order to work on IPU.

    Args:
        preds: estimated labels
        target: ground truth labels
        adjusted: number of independent regressors for calculating adjusted r2 score.
        multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings:

            * ``'raw_values'`` returns full set of scores
            * ``'uniform_average'`` scores are uniformly averaged
            * ``'variance_weighted'`` scores are weighted by their individual variances
    """
    preds = NaNTensor(preds)
    target = NaNTensor(target)
    # Mirror the target's NaNs onto the predictions so both sides skip them.
    preds[target.get_nans] = float("nan")
    return Tensor(r2_score(preds, target, *args, **kwargs))
def fbeta_score_ipu(
    preds: Tensor,
    target: Tensor,
    beta: float = 1.0,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
):
    """
    NaN-tolerant F-beta score.

    NaN targets are folded into the confusion matrix by `get_confusion_matrix`
    (and subtracted back out), so tensor shapes never change — a requirement for
    compilation and IPU execution. The statistics are then combined with the
    standard F-beta formula.

    Args:
        preds: Predictions from model (probabilities, logits or labels).
        target: Ground truth labels, possibly containing NaNs.
        beta: Weight of recall in the combined score.
        average: Reduction applied across classes. Supported: ``'micro'``
            (default, statistics are already global), ``'macro'`` (unweighted mean
            of per-class scores), ``'weighted'`` (support-weighted mean),
            ``'none'``/``None`` (per-class scores returned as-is). Anything else
            raises a ``ValueError``.
        mdmc_average: How multi-dimensional multi-class inputs are handled on top
            of ``average``: ``None``, ``'samplewise'`` or ``'global'``.
        ignore_index: Target class index excluded from the score.
        num_classes: Number of classes; required for ``'macro'``, ``'weighted'``
            and ``None`` averaging.
        threshold: Probability/logit threshold for binarizing predictions in the
            binary and multi-label cases (0.5 assumes probabilities).
        top_k: Number of highest-probability predictions considered correct for
            (multi-dimensional) multi-class inputs; ``None`` means 1.
        multiclass: Force inputs to be treated as a different classification type
            than they appear to be.

    Raises:
        ValueError: on invalid arguments (see `get_confusion_matrix`) or an
            unsupported ``average``.
    """
    # Confusion-matrix statistics with the NaN correction already applied.
    (tp, fp, tn, fn), _ = get_confusion_matrix(
        preds=preds,
        target=target,
        average=average,
        mdmc_average=mdmc_average,
        ignore_index=ignore_index,
        num_classes=num_classes,
        threshold=threshold,
        top_k=top_k,
        multiclass=multiclass,
    )

    # F-beta from the raw statistics: (1+b^2)*tp / ((1+b^2)*tp + b^2*fn + fp).
    beta_sq = beta**2
    fbeta = ((1 + beta_sq) * tp) / ((1 + beta_sq) * tp + beta_sq * fn + fp)

    if average in (None, "none", AverageMethod.NONE):
        # Return the per-class scores unreduced.
        pass
    elif average == AverageMethod.MICRO:
        # Statistics are already aggregated globally; nothing to reduce.
        pass
    elif average == AverageMethod.MACRO:
        fbeta = fbeta.mean()
    elif average == AverageMethod.WEIGHTED:
        # Weight each class by its support (tp + fn).
        weights = tp + fn
        fbeta = (weights * fbeta).sum() / weights.sum()
    else:
        raise ValueError(
            f"`average={average}` not yet supported. Chose between None, Micro, Macro, or Weighted"
        )

    return fbeta
def f1_score_ipu(
    preds: Tensor,
    target: Tensor,
    beta: float = 1.0,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
):
    """
    NaN-tolerant F1 score for IPU: a thin wrapper over `fbeta_score_ipu` with
    ``beta`` defaulting to 1.0. Tensor shapes are never modified, which keeps it
    compatible with compilation and IPU execution.

    Args:
        preds: Predictions from model (probabilities, logits or labels).
        target: Ground truth labels, possibly containing NaNs.
        beta: Weight of recall in the combined score (1.0 gives the F1 score).
        average: Reduction applied across classes (see `fbeta_score_ipu`).
        mdmc_average: Handling of multi-dimensional multi-class inputs on top of
            ``average``.
        ignore_index: Target class index excluded from the score.
        num_classes: Number of classes; required for non-micro averaging.
        threshold: Probability/logit threshold for binarizing predictions.
        top_k: Number of highest-probability predictions considered correct.
        multiclass: Force inputs to be treated as a different classification type.
    """
    # Delegate everything to the F-beta implementation.
    return fbeta_score_ipu(
        preds,
        target,
        beta=beta,
        average=average,
        mdmc_average=mdmc_average,
        ignore_index=ignore_index,
        num_classes=num_classes,
        threshold=threshold,
        top_k=top_k,
        multiclass=multiclass,
    )
def mean_squared_error_ipu(preds: Tensor, target: Tensor, squared: bool) -> Tensor:
    """Computes mean squared error.

    Handles NaNs without reshaping tensors in order to work on IPU: NaN targets
    are zeroed-out in both `preds` and `target` (so they contribute nothing to
    the squared-error sum), and the resulting mean is rescaled by the ratio of
    total to valid elements.

    Args:
        preds: estimated labels
        target: ground truth labels, possibly containing NaNs
        squared: returns RMSE value if set to False
    Return:
        Tensor with MSE
    """
    preds = preds.clone()
    target = target.clone()

    # Zero-out the NaN targets in both tensors so they cancel in the difference
    invalid_mask = target.isnan()
    preds[invalid_mask] = 0.0
    target[invalid_mask] = 0.0

    # Rescale the mean from "all elements" to "valid elements only".
    # For RMSE the correction factor goes under the square root.
    rescale = invalid_mask.numel() / ((~invalid_mask).sum())
    if not squared:
        rescale = rescale.sqrt()

    return mean_squared_error(preds, target, squared) * rescale
def mean_absolute_error_ipu(preds: Tensor, target: Tensor) -> Tensor:
    """Computes mean absolute error.

    Handles NaNs without reshaping tensors in order to work on IPU: NaN targets
    are zeroed-out in both `preds` and `target` so they contribute nothing to
    the absolute-error sum, then the mean is rescaled by the ratio of total to
    valid elements.

    Args:
        preds: estimated labels
        target: ground truth labels, possibly containing NaNs
    Return:
        Tensor with MAE
    """
    preds = preds.clone()
    target = target.clone()

    # Zero-out the NaN targets in both tensors so they cancel in the difference
    invalid_mask = target.isnan()
    preds[invalid_mask] = 0.0
    target[invalid_mask] = 0.0

    # Rescale the mean from "all elements" to "valid elements only"
    return mean_absolute_error(preds, target) * invalid_mask.numel() / ((~invalid_mask).sum())
================================================
FILE: graphium/ipu/ipu_simple_lightning.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import lightning
from lightning_graphcore import IPUStrategy
from lightning.pytorch.loggers import WandbLogger
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import mup
from graphium.nn.base_layers import FCLayer
from graphium.utils.mup import set_base_shapes
ON_IPU = True  # Change this line to run on CPU
SEED = 42  # Seed used for `torch.manual_seed` and the poptorch `randomSeed` options below
# The simple PyTorch model used in each of these examples
class SimpleTorchModel(torch.nn.Module):
    """A small convolutional classifier: one conv block (conv -> batch-norm ->
    ReLU -> two max-pools) followed by flatten, three `FCLayer`s and a
    log-softmax readout. Also exposes `make_mup_base_kwargs` so a muP base
    model with reduced hidden width can be built from it.
    """

    def __init__(self, in_dim, hidden_dim, kernel_size, num_classes):
        super().__init__()
        # Keep the constructor arguments around for `make_mup_base_kwargs`
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_classes = num_classes

        conv_stage = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=kernel_size),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size),
            nn.MaxPool2d(kernel_size),
        )

        # NOTE: the sequential structure (conv block nested as the first child)
        # is kept as-is so parameter/state-dict names stay stable.
        self.the_network = nn.Sequential(
            conv_stage,
            torch.nn.Flatten(),
            FCLayer(4 * hidden_dim, hidden_dim),
            FCLayer(hidden_dim, hidden_dim),
            FCLayer(hidden_dim, num_classes, activation=None, is_readout_layer=True),
            nn.LogSoftmax(1),
        )

    def make_mup_base_kwargs(self, divide_factor: float = 2.0):
        """Return constructor kwargs for a muP base model whose hidden width is
        divided by `divide_factor` (other dimensions unchanged)."""
        return {
            "in_dim": self.in_dim,
            "hidden_dim": round(self.hidden_dim / divide_factor),
            "kernel_size": self.kernel_size,
            "num_classes": self.num_classes,
        }

    def forward(self, x):
        """Run the full network on `x` and return log-probabilities."""
        return self.the_network(x)
# This class shows a minimal lightning example. This example uses our own
# SimpleTorchModel which is a basic 2 conv, 2 FC torch network. It can be
# found in simple_torch_model.py.
class SimpleLightning(lightning.LightningModule):
    """Minimal LightningModule wrapping `SimpleTorchModel`, with muP-compatible
    optimizer configuration and IPU-friendly logging via callback hooks."""

    def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu):
        super().__init__()
        self.model = SimpleTorchModel(
            in_dim=in_dim, hidden_dim=hidden_dim, kernel_size=kernel_size, num_classes=num_classes
        )
        self.on_ipu = on_ipu

    def training_step(self, batch, _):
        """Return the NLL loss for one training batch."""
        x, label = batch
        output = self.model(x)
        return torch.nn.functional.nll_loss(output, label)

    def validation_step(self, batch, _):
        """Return `(loss, accuracy)` for one validation batch."""
        x, label = batch
        output = self.model(x)
        predicted = torch.argmax(output, dim=1)
        accuracy = torch.sum(predicted == label).float() / len(label)
        return torch.nn.functional.nll_loss(output, label), accuracy

    # PopTorch doesn't currently support logging within steps. Use the Lightning
    # callback hooks instead.
    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.log("StepLoss", outputs["loss"])

    def validation_epoch_end(self, outputs):
        # `outputs` is a list of (loss, accuracy) tuples from `validation_step`
        losses = torch.stack([loss for loss, _ in outputs])
        self.log("val_loss", losses.mean(), prog_bar=True)
        accuracies = torch.stack([acc for _, acc in outputs])
        self.log("val_acc", accuracies.mean(), prog_bar=True)

    def configure_optimizers(self):
        """Build a muP Adam; backed by poptorch's Adam when running on IPU."""
        adam_impl = torch.optim.Adam
        if self.on_ipu:
            import poptorch

            adam_impl = poptorch.optim.Adam
        return mup.MuAdam(self.parameters(), lr=0.01, impl=adam_impl)
if __name__ == "__main__":
    torch.manual_seed(SEED)

    # Create the model as usual, then set the muP base shapes from a copy
    # whose hidden width is divided by 2.
    predictor = SimpleLightning(in_dim=1, hidden_dim=32, kernel_size=3, num_classes=10, on_ipu=ON_IPU)
    model = predictor.model
    base = model.__class__(**model.make_mup_base_kwargs(divide_factor=2))
    predictor.model = set_base_shapes(model, base, rescale_params=False)

    torch.manual_seed(SEED)

    # Normal PyTorch dataset.
    train_set = torchvision.datasets.FashionMNIST(
        "out/FashionMNIST", train=True, download=True, transform=transforms.Compose([transforms.ToTensor()])
    )
    val_set = torchvision.datasets.FashionMNIST(
        "out/FashionMNIST", train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])
    )

    # Normal PyTorch dataloader.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=16, shuffle=False)

    torch.manual_seed(SEED)

    # BUGFIX: the IPUStrategy used to be created but never handed to the
    # Trainer, while a non-existent `ipus=` argument and a dead `plugins=None`
    # were passed instead. Build the strategy when running on IPU and pass it
    # through the `strategy=` argument.
    strategy = "auto"
    if ON_IPU:
        import poptorch

        training_opts = poptorch.Options()
        inference_opts = poptorch.Options()

        # Set the seeds so IPU runs are reproducible
        training_opts.randomSeed(SEED)
        inference_opts.randomSeed(SEED)
        strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)

    trainer = lightning.Trainer(
        logger=WandbLogger(),
        strategy=strategy,
        max_epochs=3,
        log_every_n_steps=1,
    )

    # When fit is called the model will be compiled for IPU and will run on the available IPU devices.
    trainer.fit(predictor, train_dataloaders=train_loader, val_dataloaders=val_loader)
================================================
FILE: graphium/ipu/ipu_utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import os
import tempfile
from datetime import datetime
from copy import deepcopy
from types import ModuleType
from typing import Optional, Tuple, List
import torch
def import_poptorch(raise_error=True) -> Optional[ModuleType]:
    """
    Import poptorch and returns it.
    It is wrapped in a function to avoid breaking the code
    for non-IPU devices which did not install poptorch.

    Parameters:
        raise_error: Whether to raise an error if poptorch is unavailable.
            If `False`, return `None`
    Returns:
        The poptorch module, or `None` when unavailable and `raise_error` is `False`
    """
    try:
        import poptorch
    except ImportError:
        # poptorch is not installed: either propagate or signal with None
        if raise_error:
            raise
        return None
    return poptorch
def is_running_on_ipu() -> bool:
    """
    Returns whether the current module is running on ipu.
    Needs to be used in the `forward` or `backward` pass.
    """
    pt = import_poptorch(raise_error=False)
    if pt is None:
        # poptorch not installed, so we cannot be on an IPU
        return False
    return pt.isRunningOnIpu()
def load_ipu_options(
    ipu_opts: List[str],
    seed: Optional[int] = None,
    model_name: Optional[str] = None,
    gradient_accumulation: Optional[int] = None,
    precision: Optional[str] = None,
    ipu_inference_opts: Optional[List[str]] = None,
) -> Tuple["poptorch.Options", "poptorch.Options"]:
    """
    Load the IPU options from the config file.

    Parameters:
        ipu_opts: The list configurations for the IPU, written as a list of strings to make use of `poptorch.Options.loadFromFile`
            (a temporary config file is written and then read; see `Options.loadFromFile`).

            #? see the tutorial for IPU options here
            # https://github.com/graphcore/tutorials/tree/sdk-release-2.6/tutorials/pytorch/efficient_data_loading
            #? see the full documentation for ipu options here
            # https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/reference.html?highlight=options#poptorch.Options

            ***minibatch size***: The number of samples processed by one simple fwd/bwd pass.
            = # of samples in a minibatch

            ***device iterations***: A device iteration corresponds to one iteration of the training loop executed on the IPU, starting with data-loading and ending with a weight update.
            In this simple case, when we set n deviceIterations, the host will prepare n mini-batches in an infeed queue so the IPU can perform efficiently n iterations.
            = # of minibatches to be processed at a time
            = # of training / backward pass in this call

            ***gradient accumulation factor***: After each backward pass the gradients are accumulated together for K mini-batches. set K in the argument
            = # of minibatches to accumulate gradients from

            ***replication factor***: Replication describes the process of running multiple instances of the same model simultaneously on different IPUs to achieve data parallelism.
            If the model requires N IPUs and the replication factor is M, N x M IPUs will be necessary.
            = # of times the model is copied to speed up computation, each replica of the model is sent a different subset of the dataset

            ***global batch size***: In a single device iteration, many mini-batches may be processed and the resulting gradients accumulated.
            We call this total number of samples processed for one optimiser step the global batch size.
            = total number of samples processed for *one optimiser step*
            = (minibatch size x Gradient accumulation factor) x Number of replicas

        seed: random seed for the IPU
        model_name: Name of the model, to be used for ipu profiling
        gradient_accumulation: Number of mini-batches over which gradients are accumulated before
            an optimizer step. Must not conflict with a value already loaded from `ipu_opts`
            (an assertion enforces this).
        precision: Trainer precision string; only `"16-true"` enables half-precision partials.
            (Annotated as `str` since it is compared against the string `"16-true"` below.)
        ipu_inference_opts: optional IPU configuration overrides for inference.
            If this is provided, options in this file override those in `ipu_opts` for inference.

    Returns:
        training_opts: IPU options for the training set.
        inference_opts: IPU options for inference.
            It differs from the `training_opts` by enforcing `gradientAccumulation` to 1
    """
    poptorch = import_poptorch()
    ipu_options = poptorch.Options()
    # Round-trip the options through a temporary file so `loadFromFile` can parse them
    ipu_opts_file = ipu_options_list_to_file(ipu_opts)
    ipu_options.loadFromFile(ipu_opts_file.name)
    ipu_opts_file.close()
    ipu_options.outputMode(poptorch.OutputMode.All)
    if seed is not None:
        ipu_options.randomSeed(seed)
    if model_name is not None:
        ipu_options.modelName(f"{model_name}_train")
    if gradient_accumulation is not None:
        # Guard against a conflicting gradient-accumulation value already loaded from `ipu_opts`
        current = ipu_options.Training.gradient_accumulation
        assert (current == 1) or (
            current == gradient_accumulation
        ), f"Received inconsistent gradient accumulation `{current}` and `{gradient_accumulation}"
        ipu_options.Training.gradientAccumulation(gradient_accumulation)
    if precision == "16-true":
        # IPUOptions.loadFromFile currently doesn't support setting half partials, doing it here
        ipu_options.Precision.setPartialsType(torch.half)
    training_opts = ipu_options
    # Change the inference options to remove gradient accumulation
    inference_opts = deepcopy(ipu_options)
    inference_opts.Training.gradientAccumulation(1)
    if ipu_inference_opts is not None:
        ipu_inference_opts_file = ipu_options_list_to_file(ipu_inference_opts)
        inference_opts.loadFromFile(ipu_inference_opts_file.name)
        ipu_inference_opts_file.close()
    return training_opts, inference_opts
def ipu_options_list_to_file(ipu_opts: Optional[List[str]]) -> tempfile._TemporaryFileWrapper:
    """
    Create a temporary file from a list of ipu configs, such that it can be read by `poptorch.Options.loadFromFile`

    Parameters:
        ipu_opts: The list configurations for the IPU, written as a list of strings to make use of `poptorch.Options.loadFromFile`
    Returns:
        tmp_file: The temporary file of ipu configs (one option per line), or `None` when `ipu_opts` is `None`
    """
    if ipu_opts is None:
        return

    # Keep the file open (delete-on-close) so the caller can read it by name before closing
    tmp_file = tempfile.NamedTemporaryFile("w", delete=True)
    for opt in ipu_opts:
        tmp_file.write(f"{opt}\n")
    tmp_file.flush()
    return tmp_file
================================================
FILE: graphium/ipu/ipu_wrapper.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Dict, Any, Optional, Callable, Union, Type, Tuple, Iterable
from torch_geometric.data import Batch
from torch import Tensor
from lightning_graphcore import IPUStrategy
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning.pytorch.trainer.states import RunningStage
from graphium.trainer.predictor import PredictorModule
from graphium.ipu.ipu_utils import import_poptorch
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.data.data import BaseData
from loguru import logger
import functools
import collections
from graphium.data.utils import get_keys
poptorch = import_poptorch()
class PyGArgsParser(poptorch.ICustomArgParser):
    """
    Custom poptorch argument parser that flattens a PyG `Data`/`Batch` into a
    deterministic sequence of tensors and rebuilds an equivalent structure
    afterwards. This allows PyG Batch to be used as inputs to IPU programs.
    Copied from poppyg repo, in the future import from the repo directly.
    """

    @staticmethod
    def sortedTensorKeys(struct: BaseData) -> Iterable[str]:
        """
        Find all the keys that map to a tensor value in struct. The keys
        are returned in sorted order.
        """
        return (key for key in sorted(get_keys(struct)) if isinstance(struct[key], torch.Tensor))

    def yieldTensors(self, struct: BaseData):
        """
        yield every torch.Tensor in struct in sorted order
        """
        for key in self.sortedTensorKeys(struct):
            yield struct[key]

    def reconstruct(self, original_structure: BaseData, tensor_iterator: Iterable[Tensor]):
        """
        Create a new instance with the same class type as the
        original_structure. This new instance will be initialized with tensors
        from the provided iterator and uses the same sorted keys from the
        yieldTensors() implementation.
        """
        # Consume tensors in the exact order `yieldTensors` produced them
        kwargs = {}
        for key in self.sortedTensorKeys(original_structure):
            kwargs[key] = next(tensor_iterator)

        for key in get_keys(original_structure):
            if key not in kwargs:
                # copy non-tensor properties to the new instance
                kwargs[key] = original_structure[key]

        cls = original_structure.__class__
        if issubclass(cls, Batch):
            kwargs["_base_cls"] = Data
            return Batch(**kwargs)

        return cls(**kwargs)


# PyG uses the BaseData object as the root for data and batch objects
poptorch.registerCustomArgParser(BaseData, PyGArgsParser())
class PredictorModuleIPU(PredictorModule):
    """
    This class wraps around the `PredictorModule` to make it work with IPU and the `IPUPluginGraphium`.
    """

    def __init__(self, *args, **kwargs):
        # Import poptorch in a safe way that will work when working with cpu/gpu
        self.poptorch = import_poptorch()
        super().__init__(*args, **kwargs)

    @staticmethod
    def compute_loss(
        preds: Dict[str, Tensor],
        targets: Dict[str, Tensor],
        weights: Optional[Tensor],
        loss_fun: Dict[str, Callable],
        target_nan_mask: Union[Type, str] = "ignore",
        multitask_handling: Optional[str] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Delegate the loss computation to `PredictorModule.compute_loss` unchanged."""
        return PredictorModule.compute_loss(
            preds, targets, weights, loss_fun, target_nan_mask, multitask_handling
        )

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Outputs coming back from the IPU program may be FP16; upcast first
        outputs = self.convert_from_fp16(outputs)
        # NOTE(review): zero loss entries are filtered out before averaging —
        # presumably these correspond to padded/unused slots in the IPU output;
        # confirm against the IPU output mode used.
        outputs["loss"] = outputs["loss"][outputs["loss"] != 0].mean()
        super().on_train_batch_end(outputs, batch, batch_idx)

    def training_step(self, batch, batch_idx) -> Dict[str, Any]:
        # Build a dictionary from the tuples
        features, labels = batch["features"], batch["labels"]
        features, labels = self.squeeze_input_dims(features, labels)
        dict_input = {"features": features, "labels": labels}
        step_dict = super().training_step(dict_input, to_cpu=False)

        # Mark the loss with `identity_loss` so PopTorch knows which tensor to backprop
        loss = step_dict.pop("loss")
        step_dict["loss"] = self.poptorch.identity_loss(loss, reduction="mean")
        return step_dict

    def validation_step(self, batch, batch_idx) -> Dict[str, Any]:
        # Build a dictionary from the tuples
        features, labels = batch["features"], batch["labels"]
        features, labels = self.squeeze_input_dims(features, labels)
        dict_input = {"features": features, "labels": labels}
        step_dict = super().validation_step(dict_input, to_cpu=False)
        return step_dict

    def test_step(self, batch, batch_idx) -> Dict[str, Any]:
        # Build a dictionary from the tuples
        features, labels = batch["features"], batch["labels"]
        features, labels = self.squeeze_input_dims(features, labels)
        dict_input = {"features": features, "labels": labels}
        step_dict = super().test_step(dict_input, to_cpu=False)
        return step_dict

    def predict_step(self, **inputs) -> Dict[str, Any]:
        # Build a dictionary from the tuples
        dict_input = inputs
        step_dict = super().predict_step(dict_input, to_cpu=False)
        return step_dict

    def on_validation_batch_end(
        self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        # convert data that will be tracked
        outputs = self.convert_from_fp16(outputs)
        super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

    def evaluation_epoch_end(self, outputs: Any):
        outputs = self.convert_from_fp16(outputs)
        super().evaluation_epoch_end(outputs)

    def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        outputs = self.convert_from_fp16(outputs)
        super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

    def configure_optimizers(self, impl=None):
        """Configure the optimizers, defaulting to a poptorch Adam whose
        accumulation dtypes follow the trainer precision."""
        if impl is None:
            dtype = self.precision_to_dtype(self.trainer.precision)
            impl = functools.partial(
                self.poptorch.optim.Adam,
                accum_type=dtype,
                first_order_momentum_accum_type=dtype,
                second_order_momentum_accum_type=torch.float,
            )
        return super().configure_optimizers(impl=impl)

    def squeeze_input_dims(self, features, labels):
        """Remove the leading dimension added by the IPU dataloader from all tensors."""
        # NOTE(review): iterating `features` is assumed to yield (key, value)
        # pairs (PyG Data/Batch iteration protocol) — confirm against the caller.
        for key, tensor in features:
            if isinstance(tensor, torch.Tensor):
                features[key] = features[key].squeeze(0)

        for key in labels:
            labels[key] = labels[key].squeeze(0)

        return features, labels

    def convert_from_fp16(self, data: Any) -> Any:
        """
        Converts tensors from FP16 to FP32. Useful to convert the IPU program output data
        """
        # BUGFIX: `collections.Sequence` / `collections.Mapping` were removed in
        # Python 3.10; the abstract base classes live in `collections.abc`.
        import collections.abc

        if isinstance(data, collections.abc.Sequence):
            for idx in range(len(data)):
                data[idx] = self.convert_from_fp16(data[idx])
        elif isinstance(data, collections.abc.Mapping):
            for key in data:
                data[key] = self.convert_from_fp16(data[key])
        elif isinstance(data, torch.Tensor) and data.dtype == torch.float16:
            data = data.float()
        return data

    def _convert_features_dtype(self, feats):
        """
        Converts features to trainer precision rather than model precision.
        Necessary to run IPU on FP16.
        """
        dtype = self.precision_to_dtype(self.trainer.precision)

        # Convert features to dtype
        if isinstance(feats, torch.Tensor):
            feats = feats.to(dtype)
        elif isinstance(feats, (Data, Batch, dict)):
            # Only floating-point tensors are converted; integer tensors (e.g. indices) are kept
            for key, val in feats.items():
                if isinstance(val, torch.Tensor) and (val.is_floating_point()):
                    feats[key] = val.to(dtype=dtype)
        else:
            raise ValueError(f"Unsupported feats type `{type(feats)}` : {feats}")

        return feats

    def precision_to_dtype(self, precision):
        # Only the "16-true" precision string maps to half; everything else is float32
        return torch.half if precision == "16-true" else torch.float

    def get_num_graphs(self, data: Batch):
        """
        IPU specific method to compute the number of graphs in a Batch,
        that considers gradient accumulation, multiple IPUs and multiple
        device iterations. Essential to estimate throughput in graphs/s.
        """
        # NOTE(review): this sums the maximum batch index over the last
        # dimension — for a batch vector with graphs indexed 0..G-1 this is the
        # max index, not the count; confirm whether an off-by-one is intended
        # (e.g. compensated by padding graphs) before changing.
        num_graphs = torch.max(data.batch, dim=-1).values
        num_graphs = torch.sum(num_graphs)
        return num_graphs
================================================
FILE: graphium/ipu/to_dense_batch.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch_scatter import scatter_add
def to_sparse_batch(x: Tensor, mask_idx: Tensor):
    """
    Reverse function of `to_dense_batch`: flattens the batch and node
    dimensions of `x`, then keeps only the rows addressed by `mask_idx`.
    """
    flattened = x.reshape(-1, x.shape[-1])
    return torch.index_select(flattened, 0, mask_idx)
def to_sparse_batch_from_packed(x: Tensor, pack_from_node_idx: Tensor):
    """
    Reverse function of `to_packed_dense_batch`: gathers each node back from
    its (pack, position-in-pack) location stored in `pack_from_node_idx`.
    """
    pack_idx = pack_from_node_idx[:, 0]
    pos_in_pack = pack_from_node_idx[:, 1]
    return x[pack_idx, pos_in_pack]
def to_dense_batch(
    x: Tensor,
    batch: Optional[Tensor] = None,
    fill_value: float = 0.0,
    max_num_nodes_per_graph: Optional[int] = None,
    batch_size: Optional[int] = None,
    drop_nodes_last_graph=False,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times
    N_{\max}}` is returned, holding information about the existence of
    fake-nodes in the dense representation.

    Parameters:
        x: Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch: Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered. (default: :obj:`None`)
        fill_value: The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes_per_graph: The size of the output node dimension.
            (default: :obj:`None`)
        batch_size: The batch size. (default: :obj:`None`)
        drop_nodes_last_graph: Whether to drop the nodes of the last graphs that exceed
            the `max_num_nodes_per_graph`. Useful when the last graph is a padding.

    :rtype: (:class:`Tensor`, :class:`BoolTensor`, :class:`Tensor`) — the dense
        node features, the validity mask, and the flat scatter indices of each
        node (the index return is an addition over the PyG implementation).
    """
    if batch is None and max_num_nodes_per_graph is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        # BUGFIX: this early exit used to return only `(out, mask)` while the
        # main path returns `(out, mask, idx)`. Return identity indices so
        # callers can always unpack three values.
        idx = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        return x.unsqueeze(0), mask, idx

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_size is None:
        assert x.device.type != "ipu", (
            "When using the IPU the batch size must be "
            "provided during compilation instead of determined at runtime"
        )
        batch_size = int(batch.max()) + 1

    # BUGFIX: compare the device *type*, not the `torch.device` object itself.
    # `x.device not in ["ipu", "xla"]` compares a device to strings and is True
    # even on indexed IPU devices (e.g. "ipu:0"), wrongly selecting the
    # scatter_add path that the comment below flags as broken under PopTorch.
    if x.device.type not in ("ipu", "xla"):
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0, dim_size=batch_size)
    else:
        # Can't use scatter_add here due to PopTorch bug, will be fixed in SDK 3.3
        arange = torch.arange(batch_size).unsqueeze(-1)
        num_nodes = batch.eq(arange).sum(dim=-1)

    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    if max_num_nodes_per_graph is None:  # Must be provided on IPU
        max_num_nodes_per_graph = int(num_nodes.max())

    # Flat destination index of each node in the [batch_size * max_num_nodes] layout
    idx = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes_per_graph)

    size = [batch_size * max_num_nodes_per_graph] + list(x.size())[1:]
    out = x.new_full(size, fill_value)

    ##### CHANGES FROM PYG #####
    # In case the last graph represents padding. Drop the overflowing nodes.
    if drop_nodes_last_graph:
        num_nodes = num_nodes[:-1]
        idx[idx >= size[0]] = size[0] - 1

    # Raise error if num_nodes > max_num_nodes
    if x.device.type != "ipu":
        assert (
            num_nodes <= max_num_nodes_per_graph
        ).all(), f"Encountered graphs with {num_nodes.max()} nodes, greater than `max_num_nodes = {max_num_nodes_per_graph}`"
    out[idx] = x
    out = out.view([batch_size, max_num_nodes_per_graph] + list(x.size())[1:])

    # Create a zero-mask on the right device
    mask_sz = batch_size * max_num_nodes_per_graph
    if x.device.type in ("ipu", "xla"):
        mask = torch.zeros(mask_sz, dtype=torch.int32, device="cpu")
        mask = mask.to(x.device)
        # Can't use mask[idx] here due to PopTorch bug, will be fixed in SDK 3.3
        # mask[idx] = 1
        # mask = mask.bool()
        if drop_nodes_last_graph:
            num_nodes_with_padding = torch.cat((num_nodes, torch.tensor([0], dtype=torch.int32)), dim=0)
        else:
            num_nodes_with_padding = num_nodes
        # The first `num_nodes[i]` positions of row `i` are valid
        arange = torch.arange(max_num_nodes_per_graph)
        mask = num_nodes_with_padding.unsqueeze(-1).gt(arange).flatten()
    else:
        mask = torch.zeros(mask_sz, dtype=torch.bool, device=x.device)
        mask[idx] = 1

    ##### END CHANGES FROM PYG #####
    mask = mask.view(batch_size, max_num_nodes_per_graph)

    return out, mask, idx  # Added `idx` as a return
def to_packed_dense_batch(
    x: Tensor,
    pack_from_node_idx: Tensor,
    pack_attn_mask: Tensor,
    fill_value: float = 0.0,
    max_num_nodes_per_pack: Optional[int] = None,
) -> Tensor:
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense *packed* node feature tensor of shape
    :math:`[\text{num\_packs}, N_{\max}^{pack}, F]`, where each node is placed
    at the (pack, position) location given by `pack_from_node_idx`.

    Parameters:
        x: Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        pack_from_node_idx: Tensor of shape `[num_nodes, 2]` where column 0 is
            the pack index and column 1 is the position within the pack.
        pack_attn_mask: Attention mask whose leading dimension is the number of
            packs and whose last dimension is the maximum pack size; only its
            shape is used here.
        fill_value: The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes_per_pack: The size of the output node dimension; defaults
            to `pack_attn_mask.shape[-1]`. (default: :obj:`None`)

    :rtype: :class:`Tensor` of shape `[num_packs, max_num_nodes_per_pack, F]`
    """
    if max_num_nodes_per_pack is None:  # Must be provided on IPU
        max_num_nodes_per_pack = pack_attn_mask.shape[-1]

    # BUGFIX: the output shape must use the *number of packs*
    # (`pack_attn_mask.shape[0]`), not the tensor row `pack_attn_mask[0]`,
    # which is not a valid size element.
    size = [pack_attn_mask.shape[0], max_num_nodes_per_pack] + list(x.size())[1:]
    out = x.new_full(size, fill_value)
    out[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = x
    return out
================================================
FILE: graphium/nn/README.md
================================================
The Graph Of LIfe Library.
## What is in this folder?
code for base graph layer classes,
in subfolders there are different GNN layers, positional encoding layers and the graph architecture.
- ✅ `base_graph_layer.py`: contains `BaseGraphStructure` and `BaseGraphModule` classes, which should be inherited by all GNN layers
- ✅ `base_layers.py`: contains base classes for the different layers, notably `FCLayer` for feedforward layers
- `residual_connections.py`: code for residual connections
- `utils.py`: util for mixed precision
================================================
FILE: graphium/nn/__init__.py
================================================
================================================
FILE: graphium/nn/architectures/README.md
================================================
The Graph Of LIfe Library.
## What is in this folder?
- ✅ `encoder_manager.py`: encoder manager to manage the positional encoders and pool them at the correct level as input to the gnns
- ✅ `global_architectures.py`: `FullGraphNetwork`, architecture to run all the gnn layers
- `pyg_architectures.py`: `FeedForwardPyg`, base class for different pyg layers
================================================
FILE: graphium/nn/architectures/__init__.py
================================================
from .global_architectures import FeedForwardNN
from .global_architectures import FullGraphMultiTaskNetwork
from .global_architectures import TaskHeads
from .global_architectures import GraphOutputNN
from .pyg_architectures import FeedForwardPyg
from .global_architectures import EnsembleFeedForwardNN
================================================
FILE: graphium/nn/architectures/encoder_manager.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Iterable, Dict, Any, Optional
from torch_geometric.data import Batch
# Misc imports
import inspect
from copy import deepcopy
# Torch imports
from torch import Tensor, nn
import torch
from graphium.data.utils import get_keys
from graphium.nn.encoders import (
laplace_pos_encoder,
mlp_encoder,
signnet_pos_encoder,
gaussian_kernel_pos_encoder,
)
# Registry of positional-encoding encoder classes, keyed by the
# `encoder_type` string used in the encoder configuration (looked up by
# `EncoderManager._initialize_positional_encoders`).
PE_ENCODERS_DICT = {
    "laplacian_pe": laplace_pos_encoder.LapPENodeEncoder,
    "mlp": mlp_encoder.MLPEncoder,
    "cat_mlp": mlp_encoder.CatMLPEncoder,
    "signnet": signnet_pos_encoder.SignNetNodeEncoder,
    "gaussian_kernel": gaussian_kernel_pos_encoder.GaussianKernelPosEncoder,
}
class EncoderManager(nn.Module):
    def __init__(
        self,
        pe_encoders_kwargs: Optional[Dict[str, Any]] = None,
        max_num_nodes_per_graph: Optional[int] = None,
        name: str = "encoder_manager",
    ):
        r"""
        Class that runs multiple positional encoders in parallel and concatenates / pools their outputs.

        Parameters:
            pe_encoders_kwargs:
                key-word arguments to use for the initialization of all positional encoding encoders
                can use the class PE_ENCODERS_DICT: "la_encoder"(tested) , "mlp_encoder" (not tested), "signnet_encoder" (not tested)
            max_num_nodes_per_graph:
                Maximum number of nodes per graph; must agree with the value inside
                `pe_encoders_kwargs` if both are provided (an assertion checks this).
            name:
                Name attributed to the current network, for display and printing
                purposes.
        """
        super().__init__()
        self.name = name
        self.max_num_nodes_per_graph = max_num_nodes_per_graph
        if pe_encoders_kwargs is not None:
            # NOTE: `pop` mutates the caller's dict in-place — the key is removed
            # before the deepcopy below is taken.
            max_nodes = pe_encoders_kwargs.pop("max_num_nodes_per_graph", None)
            if max_nodes is not None:
                if self.max_num_nodes_per_graph is not None:
                    # Both sources provided a value: they must agree
                    assert (
                        self.max_num_nodes_per_graph == max_nodes
                    ), f"max_num_nodes_per_graph mismatch {self.max_num_nodes_per_graph}!={max_nodes}"
                self.max_num_nodes_per_graph = max_nodes
        self.pe_encoders_kwargs = deepcopy(pe_encoders_kwargs)
        self.pe_encoders = self._initialize_positional_encoders(pe_encoders_kwargs)
def _initialize_positional_encoders(self, pe_encoders_kwargs: Dict[str, Any]) -> Optional[nn.ModuleDict]:
r"""Initialize the positional encoders for each positional/structural encodings.
Parameters:
pe_encoders_kwargs: key-word arguments to use for the initialization of all positional encoding encoders
Returns:
pe_encoders: a nn.ModuleDict containing all positional encoders specified by encoder_name in pe_encoders_kwargs["encoders"]
"""
if (pe_encoders_kwargs is None) or (len(pe_encoders_kwargs) == 0):
return
pe_encoders = nn.ModuleDict()
# Pooling options here for pe encoders
self.pe_pool = pe_encoders_kwargs["pool"]
pe_out_dim = pe_encoders_kwargs.get("out_dim", None)
edge_pe_out_dim = pe_encoders_kwargs.get("edge_out_dim", None)
in_dim_dict = pe_encoders_kwargs["in_dims"]
# Loop every positional encoding to assign it
for encoder_name, encoder_kwargs in pe_encoders_kwargs["encoders"].items():
encoder_kwargs = deepcopy(encoder_kwargs)
encoder_type = encoder_kwargs.pop("encoder_type")
output_keys = encoder_kwargs["output_keys"]
encoder = PE_ENCODERS_DICT[encoder_type]
# Get the keys associated to in_dim. First check if there's a key that starts with `encoder_name`
# Then check for the exact key
this_in_dims = {}
for key in encoder_kwargs.get("input_keys", []):
if key in in_dim_dict:
this_in_dims[key] = in_dim_dict[key]
else:
raise ValueError(
f"Key '{key}' not found in `in_dim_dict`. Encoder '{encoder_name}' is also not found.\n Available keys: {in_dim_dict.keys()}"
)
# Parse the in_dims based on Encoder's signature
accepted_keys = inspect.signature(encoder).parameters.keys()
if all([key in accepted_keys for key in this_in_dims.keys()]):
pass
elif "in_dim" in accepted_keys:
if len(this_in_dims) == 1 or encoder_type == "laplacian_pe":
this_in_dims = {"in_dim": list(this_in_dims.values())[0]}
elif len(this_in_dims) > 1 and encoder_type == "cat_mlp":
this_in_dims = {"in_dim": list(this_in_dims.values())}
else:
raise ValueError(
f"All `in_dims` must be equal for encoder {encoder_name} unless edge_mlp is used. Provided: {this_in_dims}, {encoder_type}"
)
else:
raise ValueError(
f"`in_dim` not understood for encoder {encoder_name}. Provided: {this_in_dims}. Accepted keys are: {accepted_keys}"
)
# Add the max_num_nodes_per_graph if it's in the accepted input keys
if "max_num_nodes_per_graph" in accepted_keys:
encoder_kwargs["max_num_nodes_per_graph"] = self.max_num_nodes_per_graph
# Initialize the pe_encoder layer
if output_keys[0] == "feat":
pe_out_dim2 = encoder_kwargs.pop("out_dim", None)
if pe_out_dim2 is not None:
assert pe_out_dim == pe_out_dim2, f"values mismatch {pe_out_dim}!={pe_out_dim2}"
pe_encoders[encoder_name] = encoder(out_dim=pe_out_dim, **this_in_dims, **encoder_kwargs)
elif output_keys[0] == "edge_feat":
pe_out_dim2 = encoder_kwargs.pop("out_dim", None)
if pe_out_dim2 is not None:
assert edge_pe_out_dim == pe_out_dim2, f"values mismatch {pe_out_dim}!={pe_out_dim2}"
pe_encoders[encoder_name] = encoder(out_dim=edge_pe_out_dim, **this_in_dims, **encoder_kwargs)
else:
pe_encoders[encoder_name] = encoder(**this_in_dims, **encoder_kwargs)
return pe_encoders
def forward(self, g: Batch) -> Batch:
r"""
forward pass of the pe encoders and pooling
Parameters:
g:
ptg Batch on which the convolution is done.
Must contain the following elements:
- Node key `"feat"`: `torch.Tensor[..., N, Din]`.
Input node feature tensor, before the network.
`N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim``
- Edge key `"edge_feat"`: `torch.Tensor[..., N, Ein]` **Optional**.
The edge features to use. It will be ignored if the
model doesn't supporte edge features or if
`self.in_dim_edges==0`.
- Other keys related to positional encodings `"pos_enc_feats_sign_flip"`,
`"pos_enc_feats_no_flip"`.
Returns:
g:
pyg Batch with the positional encodings added to the graph
"""
# Apply the positional encoders
pe_pooled = self.forward_positional_encoding(g)
# Add the processed positional encodings to the graphs.
# If the key is already present, concatenate the pe_pooled to the pre-existing feature.
for pe_key, this_pe in pe_pooled.items():
feat = this_pe
if pe_key in get_keys(g):
feat = torch.cat((feat, g[pe_key]), dim=-1)
g[pe_key] = feat
return g
def forward_positional_encoding(self, g: Batch) -> Dict[str, Tensor]:
"""
Forward pass for the positional encodings (PE),
with each PE having it's own encoder defined in `self.pe_encoders`.
All the positional encodings with the same keys are pooled together
using `self.pe_pooling`.
Parameters:
g: pyg Batch containing the node positional encodings
Returns:
pe_node_pooled: The positional / structural encodings go through
encoders, then are pooled together according to their keys.
"""
# Return None if no positional encoders
if (self.pe_encoders is None) or len(self.pe_encoders) == 0:
return {}
encoder_outs = []
# Run every node and edge positional-encoder
for encoder_name, encoder in self.pe_encoders.items():
encoder_outs.append(encoder(g, key_prefix=encoder_name))
# list of dict to dict of list, with concatenation of the tensors
pe_cats = {
key: torch.stack([d[key] for d in encoder_outs if key in d], dim=-1)
for key in set().union(*encoder_outs)
}
# Pool the node and edge positional encodings
pe_pooled = {}
for key, pe_cat in pe_cats.items():
pe_pooled[key] = self.forward_simple_pooling(pe_cat, pooling=self.pe_pool, dim=-1)
return pe_pooled
def forward_simple_pooling(self, h: Tensor, pooling: str, dim: int) -> Tensor:
"""
Apply sum, mean, or max pooling on a Tensor.
Parameters:
h: the Tensor to pool
pooling: string specifiying the pooling method
dim: the dimension to pool over
Returns:
pooled: the pooled Tensor
"""
if pooling == "sum":
pooled = torch.sum(h, dim=dim)
elif pooling == "mean":
pooled = torch.mean(h, dim=dim)
elif pooling == "max":
pooled = torch.max(h, dim=dim).values
else:
raise Exception(f"Pooling method `{self.pe_pool}` is not defined")
return pooled
def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]:
"""
Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
The base model is usually identical to the regular model, but with the
layers width divided by a given factor (2 by default)
Parameter:
divide_factor: Factor by which to divide the width.
Returns:
pe_kw: the model kwargs where the dimensions are divided by the factor
"""
# For the pe-encoders, don't factor the in_dim and in_dim_edges
if self.pe_encoders is not None:
pe_kw = deepcopy(self.pe_encoders_kwargs)
new_pe_kw = {
key: encoder.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=False)
for key, encoder in self.pe_encoders.items()
}
pe_kw["out_dim"] = round(pe_kw.get("out_dim", 0) / divide_factor)
pe_kw["edge_out_dim"] = round(pe_kw.get("edge_out_dim", 0) / divide_factor)
for key, enc in pe_kw["encoders"].items():
new_pe_kw[key].pop("in_dim", None)
new_pe_kw[key].pop("in_dim_edges", None)
enc.update(new_pe_kw[key])
return pe_kw
@property
def input_keys(self) -> Iterable[str]:
r"""
Returns the input keys for all pe-encoders
Returns:
input_keys: the input keys for all pe-encoders
"""
if self.pe_encoders is not None:
return self.pe_encoders_kwargs["input_keys"]
else:
raise ValueError("pe_encoders is not initialized, so there are no input keys.")
@property
def in_dims(self) -> Iterable[int]:
r"""
Returns the input dimensions for all pe-encoders
Returns:
in_dims: the input dimensions for all pe-encoders
"""
if self.pe_encoders is not None:
return self.pe_encoders_kwargs["in_dims"]
else:
raise ValueError("pe_encoders is not initialized, so there are no input dimensions.")
@property
def out_dim(self) -> int:
r"""
Returns the output dimension of the pooled embedding from all the pe encoders
Returns:
out_dim: the output dimension of the pooled embedding from all the pe encoders
"""
if self.pe_encoders is not None:
return self.pe_encoders_kwargs["out_dim"]
else:
raise ValueError("pe_encoders is not initialized, so there is no output dimension.")
================================================
FILE: graphium/nn/architectures/global_architectures.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Iterable, List, Dict, Literal, Tuple, Union, Callable, Any, Optional, Type
from torch_geometric.data import Batch
from graphium.ipu.to_dense_batch import to_dense_batch
from loguru import logger
# Misc imports
import inspect
from copy import deepcopy
from collections import OrderedDict
# Torch imports
from torch import Tensor, nn
import torch
from torch_geometric.data import Data
from omegaconf import DictConfig, OmegaConf
# graphium imports
from graphium.data.utils import get_keys
from graphium.nn.base_layers import FCLayer, get_activation, get_norm
from graphium.nn.architectures.encoder_manager import EncoderManager
from graphium.nn.pyg_layers import VirtualNodePyg, parse_pooling_layer_pyg
from graphium.nn.base_graph_layer import BaseGraphModule, BaseGraphStructure
from graphium.nn.residual_connections import (
ResidualConnectionBase,
ResidualConnectionWeighted,
ResidualConnectionRandom,
)
from graphium.nn.utils import MupMixin
from graphium.ipu.ipu_utils import import_poptorch, is_running_on_ipu
# Optional IPU dependency: `poptorch` is `None` when the package is not installed,
# since `import_poptorch` is called with `raise_error=False`.
poptorch = import_poptorch(raise_error=False)
import collections
class FeedForwardNN(nn.Module, MupMixin):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Union[List[int], int],
        depth: Optional[int] = None,
        activation: Union[str, Callable] = "relu",
        last_activation: Union[str, Callable] = "none",
        dropout: float = 0.0,
        last_dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        first_normalization: Union[str, Callable] = "none",
        last_normalization: Union[str, Callable] = "none",
        residual_type: str = "none",
        residual_skip_steps: int = 1,
        name: str = "LNN",
        layer_type: Union[str, nn.Module] = "fc",
        layer_kwargs: Optional[Dict] = None,
        last_layer_is_readout: bool = False,
    ):
        r"""
        A flexible neural network architecture, with variable hidden dimensions,
        support for multiple layer types, and support for different residual
        connections.

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            hidden_dims:
                Either an integer specifying all the hidden dimensions,
                or a list of dimensions in the hidden layers.
                Be careful, the "simple" residual type only supports
                hidden dimensions of the same value.
            depth:
                If `hidden_dims` is an integer, `depth` is 1 + the number of
                hidden layers to use.
                If `hidden_dims` is a list, then
                `depth` must be `None` or equal to `len(hidden_dims) + 1`
            activation:
                activation function to use in the hidden layers.
            last_activation:
                activation function to use in the last layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            last_dropout:
                The ratio of units to dropout for the last_layer. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            first_normalization:
                Whether to use batch normalization **before** the first layer
            last_normalization:
                Whether to use batch normalization in the last layer
            residual_type:
                - "none": No residual connection
                - "simple": Residual connection similar to the ResNet architecture.
                  See class `ResidualConnectionSimple`
                - "weighted": Residual connection similar to the Resnet architecture,
                  but with weights applied before the summation. See class `ResidualConnectionWeighted`
                - "concat": Residual connection where the residual is concatenated instead
                  of being added.
                - "densenet": Residual connection where the residual of all previous layers
                  are concatenated. This leads to a strong increase in the number of parameters
                  if there are multiple hidden layers.
            residual_skip_steps:
                The number of steps to skip between each residual connection.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
            name:
                Name attributed to the current network, for display and printing
                purposes.
            layer_type:
                The type of layers to use in the network.
                Either "fc" as the `FCLayer`, or a class representing the `nn.Module`
                to use.
            layer_kwargs:
                The arguments to be used in the initialization of the layer provided by `layer_type`
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
        """
        super().__init__()

        # Set the class attributes
        self.in_dim = in_dim
        self.out_dim = out_dim
        if isinstance(hidden_dims, int):
            # BUGFIX: previously `depth=None` with an integer `hidden_dims` crashed
            # with an opaque `TypeError` on `depth - 1`; fail with a clear message.
            if depth is None:
                raise ValueError("`depth` must be provided when `hidden_dims` is an integer")
            self.hidden_dims = [hidden_dims] * (depth - 1)
        else:
            self.hidden_dims = list(hidden_dims)
            assert (depth is None) or (
                depth == len(self.hidden_dims) + 1
            ), "Mismatch between the provided network depth from `hidden_dims` and `depth`"
        self.depth = len(self.hidden_dims) + 1
        self.activation = get_activation(activation)
        self.last_activation = get_activation(last_activation)
        self.dropout = dropout
        self.last_dropout = last_dropout
        self.normalization = normalization
        self.first_normalization = get_norm(first_normalization, dim=in_dim)
        self.last_normalization = last_normalization
        self.residual_type = None if residual_type is None else residual_type.lower()
        self.residual_skip_steps = residual_skip_steps
        self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {}
        self.name = name
        self.last_layer_is_readout = last_layer_is_readout
        # Cache of intermediate layer outputs; `None` means caching is disabled
        self._readout_cache = None
        self.full_dims = [self.in_dim] + self.hidden_dims + [self.out_dim]

        self._parse_layers(layer_type=layer_type, residual_type=residual_type)
        self._create_layers()
        self._check_bad_arguments()

    def _parse_layers(self, layer_type, residual_type):
        # Parse the layer and residuals
        from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT

        self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT)
        self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)

    def _check_bad_arguments(self):
        r"""
        Raise comprehensive errors if the arguments seem wrong
        """
        if (self.residual_type == "simple") and not (self.hidden_dims[:-1] == self.hidden_dims[1:]):
            raise ValueError(
                f"When using the residual_type={self.residual_type}"
                + f", all elements in the hidden_dims must be equal. Provided:{self.hidden_dims}"
            )

    def _parse_class_from_dict(
        self, name_or_class: Union[type, str], class_dict: Dict[str, type]
    ) -> Tuple[type, str]:
        r"""
        Resolve a class from either a string key into `class_dict` (case-insensitive)
        or a callable passed directly.

        Parameters:
            name_or_class: a string key of `class_dict`, or the class/callable itself
            class_dict: mapping from lowercase names to classes

        Returns:
            obj_class: the resolved class
            obj_name: the lowercase name (or `str(class)` for callables)
        """
        if isinstance(name_or_class, str):
            obj_name = name_or_class.lower()
            obj_class = class_dict[obj_name]
        elif callable(name_or_class):
            obj_name = str(name_or_class)
            obj_class = name_or_class
        else:
            raise TypeError(f"`name_or_class` must be str or callable, provided: {type(name_or_class)}")
        return obj_class, obj_name

    def _create_residual_connection(self, out_dims: List[int]) -> Tuple["ResidualConnectionBase", List[int]]:
        r"""
        Create the residual connection classes.
        The out_dims is only used if the residual classes requires weights
        """
        if self.residual_class == ResidualConnectionWeighted:
            # Weighted residuals need the layer widths to build their weight matrices
            residual_layer = self.residual_class(
                skip_steps=self.residual_skip_steps,
                out_dims=out_dims,
                dropout=self.dropout,
                activation=self.activation,
                normalization=self.normalization,
                bias=False,
            )
        elif self.residual_class == ResidualConnectionRandom:
            residual_layer = self.residual_class(
                out_dims=out_dims,
                skip_steps=self.residual_skip_steps,
            )
        else:
            residual_layer = self.residual_class(skip_steps=self.residual_skip_steps)

        # "concat"/"densenet" residuals change the effective input dims of later layers
        residual_out_dims = residual_layer.get_true_out_dims(self.full_dims[1:])

        return residual_layer, residual_out_dims

    def _create_layers(self):
        r"""
        Create all the necessary layers for the network.
        It's a bit complicated to explain what's going on in this function,
        but it must manage the varying features sizes caused by:

        - The presence of different types of residual connections
        """
        self.residual_layer, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:])

        # Create a ModuleList of the GNN layers
        self.layers = nn.ModuleList()
        this_in_dim = self.full_dims[0]
        this_activation = self.activation
        this_norm = self.normalization
        this_dropout = self.dropout

        for ii in range(self.depth):
            this_out_dim = self.full_dims[ii + 1]
            other_kwargs = {}
            sig = inspect.signature(self.layer_class)
            key_args = [p.name for p in sig.parameters.values()]
            if ii == self.depth - 1:
                # Last layer uses the dedicated "last_*" hyper-parameters
                this_activation = self.last_activation
                this_norm = self.last_normalization
                this_dropout = self.last_dropout
                if self.last_layer_is_readout and ("is_readout_layer" in key_args):
                    other_kwargs["is_readout_layer"] = self.last_layer_is_readout

            # Create the layer
            self.layers.append(
                self.layer_class(
                    in_dim=this_in_dim,
                    out_dim=this_out_dim,
                    activation=this_activation,
                    dropout=this_dropout,
                    normalization=this_norm,
                    **self.layer_kwargs,
                    **other_kwargs,
                )
            )
            # The next input dim accounts for the residual connection (e.g. concat)
            if ii < len(residual_out_dims):
                this_in_dim = residual_out_dims[ii]

    @property
    def cache_readouts(self) -> bool:
        """Whether the readout cache is enabled"""
        return isinstance(self._readout_cache, dict)

    def _enable_readout_cache(self):
        """
        Enable the readout cache.
        Due to the usage of a dict, it only saves readouts for a single batch at a time
        """
        if not self.cache_readouts:
            self._readout_cache = {}

    def _disable_readout_cache(self):
        """Disable the readout cache"""
        self._readout_cache = None

    def drop_layers(self, depth: int) -> None:
        r"""
        Remove the last `depth` layers of the model part.
        """
        assert depth >= 0
        assert depth <= len(self.layers)

        if depth > 0:
            self.layers = self.layers[:-depth]

    def add_layers(self, layers: nn.ModuleList) -> None:
        r"""
        Add layers to the end of the model.
        (BUGFIX: the annotation previously said `int`, but a `nn.ModuleList` is required.)
        """
        assert isinstance(layers, nn.ModuleList)
        assert len(layers) > 0
        if len(self.layers) > 0:
            # The new layers must chain onto the current output width
            assert layers[0].in_dim == self.layers[-1].out_dim
        self.layers.extend(layers)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the neural network on the input features.

        Parameters:
            h: `torch.Tensor[..., Din]`:
                Input feature tensor, before the network.
                `Din` is the number of input features

        Returns:
            `torch.Tensor[..., Dout]`:
                Output feature tensor, after the network.
                `Dout` is the number of output features
        """
        feat_prev = None

        # Apply a normalization before the first layer
        if self.first_normalization is not None:
            h = self.first_normalization(h)

        # Apply all neural network layers
        for ii, layer in enumerate(self.layers):
            h = layer.forward(h)
            if ii < len(self.layers) - 1:
                h, feat_prev = self.residual_layer.forward(h, feat_prev, step_idx=ii)

            if self.cache_readouts:
                self._readout_cache[ii] = h

        return h

    def get_init_kwargs(self) -> Dict[str, Any]:
        """
        Get a dictionary that can be used to instanciate a new object with identical parameters.

        NOTE(review): `activation` / `first_normalization` are stored in their resolved
        form (callable / module), not the original constructor strings — confirm that
        downstream re-instantiation accepts them.
        """
        return deepcopy(
            dict(
                in_dim=self.in_dim,
                out_dim=self.out_dim,
                hidden_dims=self.hidden_dims,
                depth=None,
                activation=self.activation,
                last_activation=self.last_activation,
                dropout=self.dropout,
                last_dropout=self.last_dropout,
                normalization=self.normalization,
                first_normalization=self.first_normalization,
                last_normalization=self.last_normalization,
                residual_type=self.residual_type,
                residual_skip_steps=self.residual_skip_steps,
                name=self.name,
                layer_type=self.layer_class,
                layer_kwargs=self.layer_kwargs,
                last_layer_is_readout=self.last_layer_is_readout,
            )
        )

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            kwargs: the model kwargs where the dimensions are divided by the factor
        """
        kwargs = self.get_init_kwargs()
        kwargs["hidden_dims"] = [round(dim / divide_factor) for dim in kwargs["hidden_dims"]]
        if factor_in_dim:
            kwargs["in_dim"] = round(kwargs["in_dim"] / divide_factor)
        # Readout layers keep their output width for muTransfer
        if not self.last_layer_is_readout:
            kwargs["out_dim"] = round(kwargs["out_dim"] / divide_factor)
        return kwargs

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n    "
        # BUGFIX: the previous format string had an unbalanced extra "[" prefix
        layer_str = f"{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]"

        return class_str + layer_str
class EnsembleFeedForwardNN(FeedForwardNN):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Union[List[int], int],
        num_ensemble: int,
        reduction: Union[str, Callable],
        subset_in_dim: Union[float, int] = 1.0,
        depth: Optional[int] = None,
        activation: Union[str, Callable] = "relu",
        last_activation: Union[str, Callable] = "none",
        dropout: float = 0.0,
        last_dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        first_normalization: Union[str, Callable] = "none",
        last_normalization: Union[str, Callable] = "none",
        residual_type: str = "none",
        residual_skip_steps: int = 1,
        name: str = "LNN",
        layer_type: Union[str, nn.Module] = "ens-fc",
        layer_kwargs: Optional[Dict] = None,
        last_layer_is_readout: bool = False,
    ):
        r"""
        An ensemble of flexible neural network architecture, with variable hidden dimensions,
        support for multiple layer types, and support for different residual
        connections.

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            hidden_dims:
                Either an integer specifying all the hidden dimensions,
                or a list of dimensions in the hidden layers.
                Be careful, the "simple" residual type only supports
                hidden dimensions of the same value.
            num_ensemble:
                Number of MLPs that run in parallel.
            reduction:
                Reduction to use at the end of the MLP. Choices:

                - "none" or `None`: No reduction
                - "mean": Mean reduction
                - "sum": Sum reduction
                - "max": Max reduction
                - "min": Min reduction
                - "median": Median reduction
                - `Callable`: Any callable function. Must take `dim` as a keyword argument.
            subset_in_dim:
                If float, ratio of the subset of the ensemble to use. Must be between 0 and 1.
                If int, number of elements to subset from in_dim.
                If `None`, the subset_in_dim is set to `1.0`.
                A different subset is used for each ensemble.
                Only valid if the input shape is `[B, Din]`.
            depth:
                If `hidden_dims` is an integer, `depth` is 1 + the number of
                hidden layers to use.
                If `hidden_dims` is a list, then
                `depth` must be `None` or equal to `len(hidden_dims) + 1`
            activation:
                activation function to use in the hidden layers.
            last_activation:
                activation function to use in the last layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            last_dropout:
                The ratio of units to dropout for the last_layer. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            first_normalization:
                Whether to use batch normalization **before** the first layer
            last_normalization:
                Whether to use batch normalization in the last layer
            residual_type:
                - "none": No residual connection
                - "simple": Residual connection similar to the ResNet architecture.
                  See class `ResidualConnectionSimple`
                - "weighted": Residual connection similar to the Resnet architecture,
                  but with weights applied before the summation. See class `ResidualConnectionWeighted`
                - "concat": Residual connection where the residual is concatenated instead
                  of being added.
                - "densenet": Residual connection where the residual of all previous layers
                  are concatenated. This leads to a strong increase in the number of parameters
                  if there are multiple hidden layers.
            residual_skip_steps:
                The number of steps to skip between each residual connection.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
            name:
                Name attributed to the current network, for display and printing
                purposes.
            layer_type:
                The type of layers to use in the network.
                Either "ens-fc" as the `EnsembleFCLayer`, or a class representing the `nn.Module`
                to use.
            layer_kwargs:
                The arguments to be used in the initialization of the layer provided by `layer_type`
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
        """
        # Parse the ensemble arguments
        if layer_kwargs is None:
            layer_kwargs = {}
        layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs)
        # BUGFIX: `self.num_ensemble` was read by `get_init_kwargs` and `__repr__`
        # but never assigned, raising AttributeError.
        self.num_ensemble = layer_kwargs["num_ensemble"]

        # Parse the sample input dimension
        self.subset_in_dim, self.subset_idx = self._parse_subset_in_dim(in_dim, subset_in_dim, num_ensemble)

        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            hidden_dims=hidden_dims,
            depth=depth,
            activation=activation,
            last_activation=last_activation,
            dropout=dropout,
            last_dropout=last_dropout,
            normalization=normalization,
            first_normalization=first_normalization,
            last_normalization=last_normalization,
            residual_type=residual_type,
            residual_skip_steps=residual_skip_steps,
            name=name,
            layer_type=layer_type,
            layer_kwargs=layer_kwargs,
            last_layer_is_readout=last_layer_is_readout,
        )

        # Parse the reduction
        self.reduction = reduction
        self.reduction_fn = self._parse_reduction(reduction)

    def _create_layers(self):
        # Each ensemble member only sees its feature subset as input
        self.full_dims[0] = self.subset_in_dim
        super()._create_layers()

    def _parse_num_ensemble(self, num_ensemble: Optional[int], layer_kwargs) -> int:
        r"""
        Parse the num_ensemble argument, reconciling it with `layer_kwargs["num_ensemble"]`.
        """
        num_ensemble_out = num_ensemble

        # Get the num_ensemble from the layer_kwargs if it exists
        num_ensemble_2 = None
        if layer_kwargs is None:
            layer_kwargs = {}
        else:
            num_ensemble_2 = layer_kwargs.get("num_ensemble", None)

        if num_ensemble is None:
            num_ensemble_out = num_ensemble_2

        # Check that the num_ensemble is consistent.
        # BUGFIX: only compare when both values are provided; previously a
        # `num_ensemble=None` with a value in `layer_kwargs` failed the assertion.
        if (num_ensemble is not None) and (num_ensemble_2 is not None):
            assert (
                num_ensemble_2 == num_ensemble
            ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}"

        # Check that `num_ensemble_out` is not None
        assert (
            num_ensemble_out is not None
        ), f"num_ensemble={num_ensemble} and num_ensemble_2={num_ensemble_2}"

        return num_ensemble_out

    def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]:
        r"""
        Parse the reduction argument into a callable, or `None` for no reduction.
        """
        if isinstance(reduction, str):
            reduction = reduction.lower()
        if reduction is None or reduction == "none":
            return None
        elif reduction == "mean":
            return torch.mean
        elif reduction == "sum":
            return torch.sum
        elif reduction == "max":
            # torch.max/min/median return (values, indices); keep only the values

            def max_vals(x, dim):
                return torch.max(x, dim=dim).values

            return max_vals
        elif reduction == "min":

            def min_vals(x, dim):
                return torch.min(x, dim=dim).values

            return min_vals
        elif reduction == "median":

            def median_vals(x, dim):
                return torch.median(x, dim=dim).values

            return median_vals
        elif callable(reduction):
            return reduction
        else:
            raise ValueError(f"Unknown reduction {reduction}")

    def _parse_subset_in_dim(
        self, in_dim: int, subset_in_dim: Union[float, int], num_ensemble: int
    ) -> Tuple[int, Optional[Tensor]]:
        r"""
        Parse the subset_in_dim argument and the subset_in_dim.

        The subset_in_dim is the ratio of the hidden features to use by each MLP of the ensemble.
        The subset_in_dim is the number of input features to use by each MLP of the ensemble.

        Parameters:
            in_dim: The number of input features, before subsampling
            subset_in_dim:
                Ratio of the subset of features to use by each MLP of the ensemble.
                Must be between 0 and 1. A different subset is used for each ensemble.
                Only valid if the input shape is `[B, Din]`.

                If None, the subset_in_dim is set to 1.0.
            num_ensemble:
                Number of MLPs that run in parallel.

        Returns:
            subset_in_dim: The number of input features to use by each MLP of the ensemble.
            subset_idx: The indices of the features to use by each MLP of the ensemble,
                or `None` when the full input is used.
        """
        # Parse the subset_in_dim, make sure value is between 0 and 1
        subset_idx = None
        if subset_in_dim is None:
            return 1.0, None
        if isinstance(subset_in_dim, int):
            assert (
                subset_in_dim > 0 and subset_in_dim <= in_dim
            ), f"subset_in_dim={subset_in_dim}, in_dim={in_dim}"
        elif isinstance(subset_in_dim, float):
            assert subset_in_dim > 0.0 and subset_in_dim <= 1.0, f"subset_in_dim={subset_in_dim}"

            # Convert to integer value
            subset_in_dim = int(in_dim * subset_in_dim)
            if subset_in_dim == 0:
                subset_in_dim = 1
        else:
            # BUGFIX: previously an unsupported type passed through silently and
            # crashed later in `_create_layers` with an obscure error.
            raise TypeError(f"`subset_in_dim` must be int, float or None, got {type(subset_in_dim)}")

        # Create the subset_idx, which is a list of indices to use for each ensemble
        if subset_in_dim != in_dim:
            subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)])

        return subset_in_dim, subset_idx

    def _parse_layers(self, layer_type, residual_type):
        # Parse the layer and residuals from the ensemble-specific registry
        from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT

        self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT)
        self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Subset the hidden dimension for each MLP,
        forward the ensemble MLP on the input features,
        then reduce the output if specified.

        Parameters:
            h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`:
                Input feature tensor, before the MLP.
                `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles.

        Returns:
            `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`:
                Output feature tensor, after the MLP.
                `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles.
                `L` is removed if a reduction is specified.
        """
        # Subset the input features for each MLP in the ensemble
        if self.subset_idx is not None:
            if len(h.shape) != 2:
                assert (
                    h.shape[-3] == 1
                ), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}."
            # Fancy-indexing with [L, subset] indices yields [..., B, L, subset];
            # transpose to put the ensemble dimension at -3.
            h = h[..., self.subset_idx].transpose(-2, -3)

        # Run the standard forward pass
        h = super().forward(h)

        # Reduce the output if specified
        if self.reduction_fn is not None:
            h = self.reduction_fn(h, dim=-3)

        return h

    def get_init_kwargs(self) -> Dict[str, Any]:
        """
        Get a dictionary that can be used to instanciate a new object with identical parameters.

        NOTE(review): `subset_in_dim` is not included, so re-instantiation uses the
        default full input — confirm whether this matters for muTransfer base models.
        """
        kw = super().get_init_kwargs()
        kw["num_ensemble"] = self.num_ensemble
        kw["reduction"] = self.reduction
        return kw

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n    , num_ensemble={self.num_ensemble}, reduction={self.reduction}\n    "
        # BUGFIX: the previous format string had an unbalanced extra "[" prefix
        layer_str = f"{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]"

        return class_str + layer_str
class FeedForwardGraph(FeedForwardNN):
    r"""
    Feed-forward graph neural network: a stack of GNN layers with configurable
    residual connections, optional edge-feature processing, and optional virtual nodes.
    Inherits the generic depth/dimension/activation handling from `FeedForwardNN`
    and adds the edge-dimension and virtual-node bookkeeping.
    """
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Union[List[int], int],
        layer_type: Union[str, nn.Module],
        depth: Optional[int] = None,
        activation: Union[str, Callable] = "relu",
        last_activation: Union[str, Callable] = "none",
        dropout: float = 0.0,
        last_dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        first_normalization: Union[str, Callable] = "none",
        last_normalization: Union[str, Callable] = "none",
        residual_type: str = "none",
        residual_skip_steps: int = 1,
        in_dim_edges: int = 0,
        hidden_dims_edges: List[int] = [],
        out_dim_edges: Optional[int] = None,
        name: str = "GNN",
        layer_kwargs: Optional[Dict] = None,
        virtual_node: str = "none",
        use_virtual_edges: bool = False,
        last_layer_is_readout: bool = False,
    ):
        r"""
        A flexible neural network architecture, with variable hidden dimensions,
        support for multiple layer types, and support for different residual
        connections.
        This class is meant to work with different graph neural networks
        layers. Any layer must inherit from `graphium.nn.base_graph_layer.BaseGraphStructure`
        or `graphium.nn.base_graph_layer.BaseGraphLayer`.

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            hidden_dims:
                List of dimensions in the hidden layers.
                Be careful, the "simple" residual type only supports
                hidden dimensions of the same value.
            layer_type:
                Type of layer to use. Can be a string or nn.Module.
            depth:
                If `hidden_dims` is an integer, `depth` is 1 + the number of
                hidden layers to use. If `hidden_dims` is a `list`, `depth` must
                be `None`.
            activation:
                activation function to use in the hidden layers.
            last_activation:
                activation function to use in the last layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            last_dropout:
                The ratio of units to dropout for the last layer. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:
                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            first_normalization:
                Whether to use batch normalization **before** the first layer
            last_normalization:
                Whether to use batch normalization in the last layer
            residual_type:
                - "none": No residual connection
                - "simple": Residual connection similar to the ResNet architecture.
                  See class `ResidualConnectionSimple`
                - "weighted": Residual connection similar to the Resnet architecture,
                  but with weights applied before the summation. See class `ResidualConnectionWeighted`
                - "concat": Residual connection where the residual is concatenated instead
                  of being added.
                - "densenet": Residual connection where the residual of all previous layers
                  are concatenated. This leads to a strong increase in the number of parameters
                  if there are multiple hidden layers.
            residual_skip_steps:
                The number of steps to skip between each residual connection.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
            in_dim_edges:
                Input edge-feature dimensions of the network. Keep at 0 if not using
                edge features, or if the layer doesn't support edges.
            hidden_dims_edges:
                Hidden dimensions for the edges. Most models don't support it, so it
                should only be used for those that do, i.e. `GatedGCNLayer`
            out_dim_edges:
                Output edge-feature dimensions of the network. Keep at 0 if not using
                edge features, or if the layer doesn't support edges. Defaults to the
                last value of hidden_dims_edges.
            name:
                Name attributed to the current network, for display and printing
                purposes.
            layer_type:
                The type of layers to use in the network.
                A class that inherits from `graphium.nn.base_graph_layer.BaseGraphStructure`,
                or one of the following strings
                - "pyg:gin": GINConvPyg
                - "pyg:gine": GINEConvPyg
                - "pyg:gated-gcn": GatedGCNPyg
                - "pyg:pna-msgpass": PNAMessagePassingPyg
            layer_kwargs:
                The arguments to be used in the initialization of the layer provided by `layer_type`
            virtual_node:
                A string associated to the type of virtual node to use,
                either `None`, "none", "mean", "sum", "max", "logsum".
                See `graphium.nn.pooling_pyg.VirtualNode`.
                The virtual node will not use any residual connection if `residual_type`
                is "none". Otherwise, it will use a simple ResNet like residual
                connection.
            use_virtual_edges:
                A bool flag used to select if the virtual node should use the edges or not
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
        """
        # Initialize the additional attributes
        self.in_dim_edges = in_dim_edges
        if isinstance(hidden_dims_edges, int):
            # A single int is expanded to one hidden edge-dim per hidden layer
            self.hidden_dims_edges = [hidden_dims_edges] * (depth - 1)
        elif len(hidden_dims_edges) == 0:
            self.hidden_dims_edges = []
        else:
            self.hidden_dims_edges = list(hidden_dims_edges)
            # When an explicit list of edge hidden dims is given, the depth must be
            # inferred from `hidden_dims` instead of being passed explicitly
            assert depth is None
        # Default `out_dim_edges` to the last hidden edge dim, or 0 if edges are unused
        self.out_dim_edges = (
            out_dim_edges
            if out_dim_edges is not None
            else self.hidden_dims_edges[-1] if self.hidden_dims_edges else 0
        )
        self.full_dims_edges = None
        if len(self.hidden_dims_edges) or self.out_dim_edges > 0:
            assert self.out_dim_edges > 0, self.out_dim_edges
            self.full_dims_edges = [self.in_dim_edges] + self.hidden_dims_edges + [self.out_dim_edges]
        self.virtual_node = virtual_node.lower() if virtual_node is not None else "none"
        self.use_virtual_edges = use_virtual_edges
        self.virtual_node_class = self._parse_virtual_node_class()
        # Initialize the parent `FeedForwardNN`
        # NOTE: the parent __init__ calls `_create_layers`, so all edge/virtual-node
        # attributes above must be set before this call.
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            hidden_dims=hidden_dims,
            depth=depth,
            activation=activation,
            last_activation=last_activation,
            normalization=normalization,
            first_normalization=first_normalization,
            last_normalization=last_normalization,
            residual_type=residual_type,
            residual_skip_steps=residual_skip_steps,
            name=name,
            layer_type=layer_type,
            dropout=dropout,
            last_dropout=last_dropout,
            layer_kwargs=layer_kwargs,
            last_layer_is_readout=last_layer_is_readout,
        )
        # Separate normalization applied to the raw edge features before the first layer
        self.first_normalization_edges = get_norm(first_normalization, dim=in_dim_edges)
    def _check_bad_arguments(self):
        r"""
        Raise comprehensive errors if the arguments seem wrong
        """
        super()._check_bad_arguments()
        # Edge features can only be used if the chosen layer class declares support for them
        if (
            (self.in_dim_edges > 0) or (self.full_dims_edges is not None)
        ) and not self.layer_class.layer_supports_edges:
            raise ValueError(f"Cannot use edge features with class `{self.layer_class}`")
    def get_nested_key(self, d, target_key):
        """
        Get the value associated with a key in a nested dictionary.

        Parameters:
            - d: The dictionary to search in
            - target_key: The key to search for
        Returns:
            - The value associated with the key if found, None otherwise
        """
        if target_key in d:
            return d[target_key]
        # Depth-first search through nested dict / DictConfig values
        for key, value in d.items():
            if isinstance(value, (dict, DictConfig)):
                nested_result = self.get_nested_key(value, target_key)
                if nested_result is not None:
                    return nested_result
        return None
    def _create_layers(self):
        r"""
        Create all the necessary layers for the network.
        It's a bit complicated to explain what's going on in this function,
        but it must manage the varying features sizes caused by:

        - The presence of different types of residual connections
        - The presence or absence of edges
        - The output dimensions varying for different networks i.e. `GatLayer` outputs different feature sizes according to the number of heads
        - The presence or absence of virtual nodes
        - The different possible pooling, and the concatenation of multiple pooling together.
        """
        residual_layer_temp, residual_out_dims = self._create_residual_connection(out_dims=self.full_dims[1:])
        # Create a ModuleList of the GNN layers
        self.layers = nn.ModuleList()
        self.virtual_node_layers = nn.ModuleList()
        this_in_dim = self.full_dims[0]
        this_activation = self.activation
        this_norm = self.normalization
        this_dropout = self.dropout
        # Find the appropriate edge dimensions, depending if edges are used,
        # And if the residual is required for the edges
        this_in_dim_edges, this_out_dim_edges = None, None
        if self.full_dims_edges is not None:
            this_in_dim_edges, this_out_dim_edges = self.full_dims_edges[0:2]
            residual_out_dims_edges = residual_layer_temp.get_true_out_dims(self.full_dims_edges[1:])
        elif self.in_dim_edges > 0:
            this_in_dim_edges = self.in_dim_edges
        layer_out_dims_edges = []
        # Create all the layers in a loop
        for ii in range(self.depth):
            this_out_dim = self.full_dims[ii + 1]
            # Find the edge key-word arguments depending on the layer type and residual connection
            this_edge_kwargs = {}
            if self.layer_class.layer_supports_edges and self.in_dim_edges > 0:
                this_edge_kwargs["in_dim_edges"] = this_in_dim_edges
                # Only pass `out_dim_edges` if the layer's constructor accepts it
                if "out_dim_edges" in inspect.signature(self.layer_class.__init__).parameters.keys():
                    if self.full_dims_edges is not None:
                        this_out_dim_edges = self.full_dims_edges[ii + 1]
                        this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
                    else:
                        # Fall back to an `out_dim_edges` buried anywhere inside `layer_kwargs`
                        this_out_dim_edges = self.get_nested_key(self.layer_kwargs, "out_dim_edges")
                        this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
                    layer_out_dims_edges.append(this_out_dim_edges)
            # Create the GNN layer
            self.layers.append(
                self.layer_class(
                    in_dim=this_in_dim,
                    out_dim=this_out_dim,
                    activation=this_activation,
                    dropout=this_dropout,
                    normalization=this_norm,
                    layer_idx=ii,
                    layer_depth=self.depth,
                    **self.layer_kwargs,
                    **this_edge_kwargs,
                )
            )
            # Create the Virtual Node layer, except at the last layer
            if ii < len(residual_out_dims):
                self.virtual_node_layers.append(
                    self.virtual_node_class(
                        in_dim=this_out_dim * self.layers[-1].out_dim_factor,
                        out_dim=this_out_dim * self.layers[-1].out_dim_factor,
                        in_dim_edges=this_out_dim_edges,
                        out_dim_edges=this_out_dim_edges,
                        activation=this_activation,
                        dropout=this_dropout,
                        normalization=this_norm,
                        bias=True,
                        vn_type=self.virtual_node,
                        residual=self.residual_type is not None,
                        use_edges=self.use_virtual_edges,
                    )
                )
            # Get the true input dimension of the next layer,
            # by factoring both the residual connection and GNN layer type
            if ii < len(residual_out_dims):
                # NOTE(review): this indexes `self.layers[ii - 1]` (the *previous* layer),
                # while the virtual-node block above uses `self.layers[-1]` (the layer just
                # appended, i.e. index `ii`). Confirm whether this asymmetry is intentional;
                # it is only observable when `out_dim_factor` differs between layers.
                this_in_dim = residual_out_dims[ii] * self.layers[ii - 1].out_dim_factor
                if self.full_dims_edges is not None:
                    this_in_dim_edges = residual_out_dims_edges[ii] * self.layers[ii - 1].out_dim_factor
        layer_out_dims = [layer.out_dim_factor * layer.out_dim for layer in self.layers]
        # Initialize residual and pooling layers
        self.residual_layer, _ = self._create_residual_connection(out_dims=layer_out_dims)
        if len(layer_out_dims_edges) > 0:
            self.residual_edges_layer, _ = self._create_residual_connection(out_dims=layer_out_dims_edges)
        else:
            self.residual_edges_layer = None
    def _graph_layer_forward(
        self,
        layer: BaseGraphModule,
        g: Batch,
        feat: Tensor,
        edge_feat: Optional[Tensor],
        feat_prev: Optional[Tensor],
        edge_feat_prev: Optional[Tensor],
        step_idx: int,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        r"""
        Apply the *i-th* PyG graph layer, where *i* is the index given by `step_idx`.
        The layer is applied differently depending if there are edge features or not.
        Then, the residual is also applied on both the features and the edges (if applicable)

        Parameters:
            layer:
                The PyG layer used for the convolution
            g:
                graph on which the convolution is done
            feat (torch.Tensor[..., N, Din]):
                Node feature tensor, before convolution.
                `N` is the number of nodes, `Din` is the input features
            edge_feat (torch.Tensor[..., N, Ein]):
                Edge feature tensor, before convolution.
                `N` is the number of nodes, `Ein` is the input edge features
            feat_prev:
                Node feature of the previous residual connection, or `None`
            edge_feat_prev:
                Edge feature of the previous residual connection, or `None`
            step_idx:
                The current step idx in the forward loop

        Returns:
            feat (torch.Tensor[..., N, Dout]):
                Node feature tensor, after convolution and residual.
                `N` is the number of nodes, `Dout` is the output features of the layer and residual
            edge_feat:
                Edge feature tensor, after convolution and residual.
                `N` is the number of nodes, `Ein` is the input edge features
            feat_prev:
                Node feature tensor to be used at the next residual connection, or `None`
            edge_feat_prev:
                Edge feature tensor to be used at the next residual connection, or `None`
        """
        # Set node / edge features into the graph
        g["feat"] = feat
        g["edge_feat"] = edge_feat
        # Apply the GNN layer
        g = layer(g)
        # Get the node / edge features from the graph
        feat = g["feat"]
        edge_feat = g["edge_feat"]
        # Apply the residual layers on the features and edges (if applicable);
        # the residual is skipped after the last layer
        if step_idx < len(self.layers) - 1:
            feat, feat_prev = self.residual_layer.forward(feat, feat_prev, step_idx=step_idx)
            # Edge residual only applies when the layer actually outputs edge features
            if (self.residual_edges_layer is not None) and (layer.layer_outputs_edges):
                edge_feat, edge_feat_prev = self.residual_edges_layer.forward(
                    edge_feat, edge_feat_prev, step_idx=step_idx
                )
        return feat, edge_feat, feat_prev, edge_feat_prev
    def _parse_virtual_node_class(self) -> type:
        # The PyG implementation of the virtual node is the only one supported here
        return VirtualNodePyg
    def _virtual_node_forward(
        self,
        g: Data,
        feat: torch.Tensor,
        vn_feat: torch.Tensor,
        step_idx: int,
        edge_feat: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Apply the *i-th* virtual node layer, where *i* is the index given by `step_idx`.

        Parameters:
            g:
                graph on which the convolution is done
            feat (torch.Tensor[..., N, Din]):
                Node feature tensor, before convolution.
                `N` is the number of nodes, `Din` is the input features
            vn_feat (torch.Tensor[..., M, Din]):
                Graph feature of the previous virtual node, or `None`
                `M` is the number of graphs, `Din` is the input features
                It is added to the result after the MLP, as a residual connection
            step_idx:
                The current step idx in the forward loop
            edge_feat: torch.Tensor
                Edge feature tensor, passed through the virtual node layer

        Returns:
            `feat = torch.Tensor[..., N, Dout]`:
                Node feature tensor, after convolution and residual.
                `N` is the number of nodes, `Dout` is the output features of the layer and residual
            `vn_feat = torch.Tensor[..., M, Dout]`:
                Graph feature tensor to be used at the next virtual node, or `None`
                `M` is the number of graphs, `Dout` is the output features
            `edge_feat`:
                Edge feature tensor, as returned by the virtual node layer
        """
        # There is one fewer virtual-node layer than GNN layers, so the last step is a no-op
        if step_idx < len(self.virtual_node_layers):
            feat, vn_feat, edge_feat = self.virtual_node_layers[step_idx].forward(
                g=g, feat=feat, vn_feat=vn_feat, edge_feat=edge_feat
            )
        return feat, vn_feat, edge_feat
    def forward(self, g: Batch) -> Batch:
        r"""
        Apply the full graph neural network on the input graph and node features.

        Parameters:
            g:
                pyg Batch graph on which the convolution is done with the keys:

                - `"feat"`: torch.Tensor[..., N, Din]
                  Node feature tensor, before convolution.
                  `N` is the number of nodes, `Din` is the input features

                - `"edge_feat"` (torch.Tensor[..., N, Ein]):
                  Edge feature tensor, before convolution.
                  `N` is the number of nodes, `Ein` is the input edge features

        Returns:
            The same pyg Batch `g`, with `g["feat"]` and `g["edge_feat"]` replaced by the
            outputs of the network.
        """
        # Initialize values of the residuals and virtual node
        feat_prev = None
        edge_feat_prev = None
        vn_feat = 0.0
        feat = g["feat"]
        edge_feat = g["edge_feat"]
        # Apply the normalization before the first network layers
        if self.first_normalization is not None:
            feat = self.first_normalization(feat)
        if (self.first_normalization_edges is not None) and (self.in_dim_edges > 0):
            edge_feat = self.first_normalization_edges(edge_feat)
        # Apply the forward loop of the layers, residuals and virtual nodes
        for ii, layer in enumerate(self.layers):
            feat, edge_feat, feat_prev, edge_feat_prev = self._graph_layer_forward(
                layer=layer,
                g=g,
                feat=feat,
                edge_feat=edge_feat,
                feat_prev=feat_prev,
                edge_feat_prev=edge_feat_prev,
                step_idx=ii,
            )
            feat, vn_feat, edge_feat = self._virtual_node_forward(
                g=g, feat=feat, edge_feat=edge_feat, vn_feat=vn_feat, step_idx=ii
            )
            # `cache_readouts` / `_readout_cache` are presumably provided by the parent
            # `FeedForwardNN` readout-cache machinery — not visible in this file chunk
            if self.cache_readouts:
                self._readout_cache[ii] = feat
        g["feat"], g["edge_feat"] = feat, edge_feat
        return g
    def get_init_kwargs(self) -> Dict[str, Any]:
        """
        Get a dictionary that can be used to instanciate a new object with identical parameters.
        """
        kwargs = super().get_init_kwargs()
        new_kwargs = dict(
            in_dim_edges=self.in_dim_edges,
            hidden_dims_edges=self.hidden_dims_edges,
            out_dim_edges=self.out_dim_edges,
            virtual_node=self.virtual_node,
            use_virtual_edges=self.use_virtual_edges,
        )
        kwargs.update(new_kwargs)
        # Deep-copied so the caller can mutate the kwargs without affecting this instance
        return deepcopy(kwargs)
    def make_mup_base_kwargs(
        self,
        divide_factor: float = 2.0,
        factor_in_dim: bool = False,
    ) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension for the nodes

        Returns:
            kwargs: Dictionary of parameters to be used to instanciate the base model divided by the factor
        """
        kwargs = self.get_init_kwargs()
        kwargs["hidden_dims"] = [round(dim / divide_factor) for dim in kwargs["hidden_dims"]]
        kwargs["hidden_dims_edges"] = [round(dim / divide_factor) for dim in kwargs["hidden_dims_edges"]]
        if factor_in_dim:
            # Node AND edge input dims are divided together when factoring the input
            kwargs["in_dim"] = round(kwargs["in_dim"] / divide_factor)
            kwargs["in_dim_edges"] = round(kwargs["in_dim_edges"] / divide_factor)
        if not self.last_layer_is_readout:
            # Readout layers keep their width per the muP recipe
            kwargs["out_dim"] = round(kwargs["out_dim"] / divide_factor)
            kwargs["out_dim_edges"] = round(kwargs["out_dim_edges"] / divide_factor)
        def _recursive_divide_dim(x: collections.abc.Mapping):
            # Divide every known dimension key found anywhere inside `layer_kwargs`
            for k, v in x.items():
                if isinstance(v, collections.abc.Mapping):
                    _recursive_divide_dim(v)
                elif k in ["in_dim", "out_dim", "in_dim_edges", "out_dim_edges"]:
                    x[k] = round(v / divide_factor)
                elif k in ["embed_dim"]:
                    x[k] = round(v / divide_factor)
        _recursive_divide_dim(kwargs["layer_kwargs"])
        return kwargs
    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        class_str = f"{self.name}(depth={self.depth}, {self.residual_layer})\n    "
        layer_str = f"{self.layer_class.__name__}[{' -> '.join(map(str, self.full_dims))}]\n    "
        return class_str + layer_str
class FullGraphMultiTaskNetwork(nn.Module, MupMixin):
    def __init__(
        self,
        gnn_kwargs: Dict[str, Any],
        pre_nn_kwargs: Optional[Dict[str, Any]] = None,
        pre_nn_edges_kwargs: Optional[Dict[str, Any]] = None,
        pe_encoders_kwargs: Optional[Dict[str, Any]] = None,
        task_heads_kwargs: Optional[Dict[str, Any]] = None,
        graph_output_nn_kwargs: Optional[Dict[str, Any]] = None,
        accelerator_kwargs: Optional[Dict[str, Any]] = None,
        num_inference_to_average: int = 1,
        last_layer_is_readout: bool = False,
        name: str = "FullGNN",
    ):
        r"""
        Class that allows to implement a full graph neural network architecture,
        including the pre-processing MLP and the post processing MLP.

        Parameters:
            gnn_kwargs:
                key-word arguments to use for the initialization of the pre-processing
                GNN network using the class `FeedForwardGraph`.
                It must respect the following criteria:

                - gnn_kwargs["in_dim"] must be equal to pre_nn_kwargs["out_dim"]
                - gnn_kwargs["out_dim"] must be equal to graph_output_nn_kwargs["in_dim"]
            pe_encoders_kwargs:
                key-word arguments to use for the initialization of all positional encoding encoders.
                See the class `EncoderManager` for more details.
            pre_nn_kwargs:
                key-word arguments to use for the initialization of the pre-processing
                MLP network of the node features before the GNN, using the class `FeedForwardNN`.
                If `None`, there won't be a pre-processing MLP.
            pre_nn_edges_kwargs:
                key-word arguments to use for the initialization of the pre-processing
                MLP network of the edge features before the GNN, using the class `FeedForwardNN`.
                If `None`, there won't be a pre-processing MLP.
            task_heads_kwargs:
                This argument is a list of dictionaries containing the arguments for task heads. Each argument is used to
                initialize a task-specific MLP.
            graph_output_nn_kwargs:
                This argument is a list of dictionaries corresponding to the arguments for a FeedForwardNN.
                Each dict of arguments is used to initialize a shared MLP.
            accelerator_kwargs:
                key-word arguments specific to the accelerator being used,
                e.g. pipeline split points
            num_inference_to_average:
                Number of inferences to average at val/test time. This is used to avoid the noise introduced
                by positional encodings with sign-flips. In case no such encoding is given,
                this parameter is ignored.
                NOTE: The inference time will be slowed-down proportionaly to this parameter.
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            name:
                Name attributed to the current network, for display and printing
                purposes.
        """
        super().__init__()
        self.name = name
        self.num_inference_to_average = num_inference_to_average
        self.last_layer_is_readout = last_layer_is_readout
        self._concat_last_layers = None
        self.pre_nn, self.pre_nn_edges, self.task_heads = None, None, None
        self.pe_encoders_kwargs = deepcopy(pe_encoders_kwargs)
        self.graph_output_nn_kwargs = graph_output_nn_kwargs
        self.encoder_manager = EncoderManager(pe_encoders_kwargs)
        self.max_num_nodes_per_graph = None
        self.max_num_edges_per_graph = None
        self._cache_readouts = False
        # Initialize the pre-processing neural net for nodes (applied directly on node features)
        # NOTE: `pre_nn_kwargs` and `gnn_kwargs` are mutated in place below (pop/setdefault)
        if pre_nn_kwargs is not None:
            name = pre_nn_kwargs.pop("name", "pre-NN")
            self.pre_nn = FeedForwardNN(**pre_nn_kwargs, name=name)
            next_in_dim = self.pre_nn.out_dim
            # Default the GNN input dim to the pre-NN output dim, then check consistency
            gnn_kwargs.setdefault("in_dim", next_in_dim)
            assert (
                next_in_dim == gnn_kwargs["in_dim"]
            ), f"Inconsistent dimensions between pre-NN output ({next_in_dim}) and GNN input ({gnn_kwargs['in_dim']})"
        # Initialize the pre-processing neural net for edges (applied directly on edge features)
        if pre_nn_edges_kwargs is not None:
            name = pre_nn_edges_kwargs.pop("name", "pre-NN-edges")
            self.pre_nn_edges = FeedForwardNN(**pre_nn_edges_kwargs, name=name)
            next_in_dim = self.pre_nn_edges.out_dim
            gnn_kwargs.setdefault("in_dim_edges", next_in_dim)
            assert (
                next_in_dim == gnn_kwargs["in_dim_edges"]
            ), f"Inconsistent dimensions between pre-NN-edges output ({next_in_dim}) and GNN input ({gnn_kwargs['in_dim_edges']})"
        # Initialize the graph neural net (applied after the pre_nn)
        name = gnn_kwargs.pop("name", "GNN")
        gnn_class = self._parse_feed_forward_gnn(gnn_kwargs)
        # The GNN's last layer is only the network readout when there are no task heads
        gnn_kwargs.setdefault(
            "last_layer_is_readout", self.last_layer_is_readout and (task_heads_kwargs is None)
        )
        self.gnn = gnn_class(**gnn_kwargs, name=name)
        next_in_dim = self.gnn.out_dim
        if task_heads_kwargs is not None:
            # `self.out_dim` / `self.out_dim_edges` are properties reading `self.gnn`,
            # so the GNN must be built before this point
            self.task_heads = TaskHeads(
                in_dim=self.out_dim,
                in_dim_edges=self.out_dim_edges,
                task_heads_kwargs=task_heads_kwargs,
                graph_output_nn_kwargs=graph_output_nn_kwargs,
            )
            self._task_heads_kwargs = task_heads_kwargs
        if accelerator_kwargs is not None:
            accelerator = accelerator_kwargs["_accelerator"]
            if accelerator == "ipu":
                self._apply_ipu_options(accelerator_kwargs)
        self._check_bad_arguments()
@staticmethod
def _parse_feed_forward_gnn(
gnn_kwargs: Dict[str, Any],
):
"""
Parse the key-word arguments to determine which `FeedForward` class to use.
Parameters:
gnn_kwargs: key-word arguments to use for the initialization of the GNN network.
Returns:
`FeedForwardPyg` class
"""
return FeedForwardGraph
def _check_bad_arguments(self):
r"""
Raise comprehensive errors if the arguments seem wrong
"""
if self.pre_nn is not None:
if self.pre_nn.out_dim != self.gnn.in_dim:
raise ValueError(
f"`self.pre_nn.out_dim` must be equal to `self.gnn.in_dim`."
+ 'Provided" {self.pre_nn.out_dim} and {self.gnn.in_dim}'
)
if self.task_heads is not None:
edge_level_tasks = [
task_name
for task_name, head_kwargs in self._task_heads_kwargs.items()
if head_kwargs["task_level"] == "edge"
]
if len(edge_level_tasks) > 0 and (
not self.gnn.layer_class.layer_supports_edges or self.out_dim_edges < 1
):
raise ValueError(
f"Task heads have edge level tasks {', '.join(edge_level_tasks)}, but edge level tasks cannot be used with layer class `{self.gnn.layer_class}`"
)
graph_level_tasks = [
task_name
for task_name, head_kwargs in self._task_heads_kwargs.items()
if head_kwargs["task_level"] == "graph"
]
if len(graph_level_tasks) > 0 and self.graph_output_nn_kwargs["graph"]["pooling"] == ["none"]:
raise ValueError(
f"Task heads have graph level tasks {', '.join(graph_level_tasks)}, but pooling is none."
)
def _apply_ipu_options(self, ipu_kwargs):
gnn_layers_per_ipu = ipu_kwargs.get("gnn_layers_per_ipu")
self._apply_ipu_pipeline_split(gnn_layers_per_ipu)
def _apply_ipu_pipeline_split(self, gnn_layers_per_ipu):
r"""
Apply pipeline split from accelerator options if applicable
"""
if gnn_layers_per_ipu is None:
return
if not isinstance(gnn_layers_per_ipu, collections.abc.Sequence):
raise ValueError("gnn_layers_per_ipu must be a Sequence (e.g. a list)")
valid_ipu_pipeline_lengths = [1, 2, 4, 8, 16]
pipeline_length = len(gnn_layers_per_ipu)
if pipeline_length not in valid_ipu_pipeline_lengths:
raise ValueError(
f"Length of gnn_layers_per_ipu must be one of {valid_ipu_pipeline_lengths}, "
f"got {gnn_layers_per_ipu} of length {pipeline_length} instead"
)
model_depth = len(self.gnn.layers)
if sum(gnn_layers_per_ipu) != model_depth:
raise ValueError(
f"The values in gnn_layers_per_ipu must add up to the depth of the model, "
f"got {gnn_layers_per_ipu} with total {sum(gnn_layers_per_ipu)} vs model depth "
f"of {model_depth}"
)
begin_block_layer_indices = [sum(gnn_layers_per_ipu[:i]) for i in range(1, pipeline_length)]
for begin_block_layer_index, ipu_id in zip(begin_block_layer_indices, range(1, pipeline_length)):
self.gnn.layers[begin_block_layer_index] = poptorch.BeginBlock(
self.gnn.layers[begin_block_layer_index], ipu_id=ipu_id
)
def _enable_readout_cache(self, module_filter: Optional[Union[str, List[str]]]):
"""
Enable a single-batch readout cache for (a subset of) the modules.
This is used to extract hidden representations for fingerprinting.
"""
self.create_module_map(level="module")
for k in module_filter:
if k not in self._module_map:
raise ValueError(f"Module {k} not found in network, choose from {self._module_map.keys()}")
if module_filter is None:
module_filter = list(self._module_map.keys())
if isinstance(module_filter, str):
module_filter = [module_filter]
for module_name, module in self._module_map.items():
if module_name in module_filter:
if not isinstance(module, FeedForwardNN):
raise RuntimeError(
f"Readout cache can only be enabled for FeedForwardNN subclasses, not {type(module)}"
)
module._enable_readout_cache()
self._cache_readouts = True
def _disable_readout_cache(self):
"""Disable the readout cache"""
self.create_module_map(level="module")
for _, module in self._module_map.items():
if isinstance(module, FeedForwardNN):
module._disable_readout_cache()
self._cache_readouts = False
def create_module_map(self, level: Union[Literal["layers"], Literal["module"]] = "layers"):
"""
Function to create mapping between each (sub)module name and corresponding nn.ModuleList() (if possible);
Used for finetuning when (partially) loading or freezing specific modules of the pretrained model
Args:
level: Whether to map to the module object or the layers of the module object
"""
self._module_map = OrderedDict()
if self.encoder_manager is not None:
self._module_map.update(
{"pe_encoders": self.encoder_manager}
) # could be extended to submodules, e.g. pe_encoders/la_pos/linear_in/..., etc.; not necessary for current finetuning
if self.pre_nn is not None:
self._module_map.update({"pre_nn": self.pre_nn})
if self.pre_nn_edges is not None:
self._module_map.update({"pre_nn_edges": self.pre_nn_edges})
# No need to check for NoneType as GNN module is not optional in FullGraphMultitaskNetwork
self._module_map.update({"gnn": self.gnn})
if self.task_heads is not None:
self._module_map.update(
{
"graph_output_nn-"
+ output_level: self.task_heads.graph_output_nn[output_level].graph_output_nn
for output_level in self.task_heads.graph_output_nn.keys()
}
)
self._module_map.update(
{
"task_heads-" + task_head_name: self.task_heads.task_heads[task_head_name]
for task_head_name in self.task_heads.task_heads.keys()
}
)
if level == "layers":
for module_name, module in self._module_map.items():
if module_name != "pe_encoders":
self._module_map[module_name] = module.layers
return self._module_map
def forward(self, g: Batch) -> Tensor:
r"""
Apply the pre-processing neural network, the graph neural network,
and the post-processing neural network on the graph features.
Parameters:
g:
pyg Batch graph on which the convolution is done.
Must contain the following elements:
- Node key `"feat"`: `torch.Tensor[..., N, Din]`.
Input node feature tensor, before the network.
`N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim``
- Edge key `"edge_feat"`: `torch.Tensor[..., N, Ein]` **Optional**.
The edge features to use. It will be ignored if the
model doesn't supporte edge features or if
`self.in_dim_edges==0`.
- Other keys related to positional encodings `"pos_enc_feats_sign_flip"`,
`"pos_enc_feats_no_flip"`.
Returns:
`torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`:
Node or graph feature tensor, after the network.
`N` is the number of nodes, `M` is the number of graphs,
`Dout` is the output dimension ``self.graph_output_nn.out_dim``
If the `self.gnn.pooling` is [`None`], then it returns node features and the output dimension is `N`,
otherwise it returns graph features and the output dimension is `M`
"""
# Apply the positional encoders
g = self.encoder_manager(g)
e = None
# Run the pre-processing network on node features
if self.pre_nn is not None:
g["feat"] = self.pre_nn.forward(g["feat"])
# Run the pre-processing network on edge features
# If there are no edges, skip the forward and change the dimension of e
if self.pre_nn_edges is not None:
e = g["edge_feat"]
if torch.prod(torch.as_tensor(e.shape[:-1])) == 0:
e = torch.zeros(
list(e.shape[:-1]) + [self.pre_nn_edges.out_dim], device=e.device, dtype=e.dtype
)
else:
e = self.pre_nn_edges.forward(e)
g["edge_feat"] = e
# Run the graph neural network
g = self.gnn.forward(g)
if self.task_heads is not None:
return self.task_heads.forward(g)
return g
    def make_mup_base_kwargs(self, divide_factor: float = 2.0) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.

        Returns:
            Dictionary with the kwargs to create the base model.
        """
        kwargs = dict(
            gnn_kwargs=None,
            pre_nn_kwargs=None,
            pre_nn_edges_kwargs=None,
            pe_encoders_kwargs=None,
            num_inference_to_average=self.num_inference_to_average,
            last_layer_is_readout=self.last_layer_is_readout,
            name=self.name,
        )
        # For the pre-nn network, get the smaller dimensions.
        # For the input dim, only divide the features coming from the pe-encoders
        if self.pre_nn is not None:
            kwargs["pre_nn_kwargs"] = self.pre_nn.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=False
            )
            # The pre-NN input is (raw features + pe-encoder output); only the
            # pe-encoder part is divided since the raw feature size is fixed by the data
            pe_enc_outdim = 0 if self.encoder_manager is None else self.pe_encoders_kwargs.get("out_dim", 0)
            pre_nn_indim = kwargs["pre_nn_kwargs"]["in_dim"] - pe_enc_outdim
            kwargs["pre_nn_kwargs"]["in_dim"] = round(pre_nn_indim + (pe_enc_outdim / divide_factor))
        # For the pre-nn on the edges, factor all dimensions, except the in_dim
        if self.pre_nn_edges is not None:
            kwargs["pre_nn_edges_kwargs"] = self.pre_nn_edges.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=False
            )
            # Same logic as for the node pre-NN: only divide the pe-encoder edge output
            pe_enc_edge_outdim = (
                0 if self.encoder_manager is None else self.pe_encoders_kwargs.get("edge_out_dim", 0)
            )
            pre_nn_edge_indim = kwargs["pre_nn_edges_kwargs"]["in_dim"] - pe_enc_edge_outdim
            kwargs["pre_nn_edges_kwargs"]["in_dim"] = round(
                pre_nn_edge_indim + (pe_enc_edge_outdim / divide_factor)
            )
        # For the pe-encoders, don't factor the in_dim and in_dim_edges
        if self.encoder_manager is not None:
            kwargs["pe_encoders_kwargs"] = self.encoder_manager.make_mup_base_kwargs(
                divide_factor=divide_factor
            )
        if self.task_heads is not None:
            task_heads_kwargs = self.task_heads.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=True
            )
            kwargs["task_heads_kwargs"] = task_heads_kwargs["task_heads_kwargs"]
            kwargs["graph_output_nn_kwargs"] = task_heads_kwargs["graph_output_nn_kwargs"]
        # For the gnn network, all the dimension are divided, except the input dims if pre-nn are missing
        if self.gnn is not None:
            # Only factor the GNN input dim when a pre-NN feeds it (its output was divided)
            factor_in_dim = self.pre_nn is not None
            kwargs["gnn_kwargs"] = self.gnn.make_mup_base_kwargs(
                divide_factor=divide_factor,
                factor_in_dim=factor_in_dim,
            )
        return kwargs
def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges: Optional[int]) -> None:
"""
Set the maximum number of nodes and edges for all gnn layers and encoder layers
Parameters:
max_nodes: Maximum number of nodes in the dataset.
This will be useful for certain architecture, but ignored by others.
max_edges: Maximum number of edges in the dataset.
This will be useful for certain architecture, but ignored by others.
"""
self.max_num_nodes_per_graph = max_nodes
self.max_num_edges_per_graph = max_edges
if (self.encoder_manager is not None) and (self.encoder_manager.pe_encoders is not None):
for encoder in self.encoder_manager.pe_encoders.values():
encoder.max_num_nodes_per_graph = max_nodes
encoder.max_num_edges_per_graph = max_edges
if self.gnn is not None:
for layer in self.gnn.layers:
if isinstance(layer, BaseGraphStructure):
layer.max_num_nodes_per_graph = max_nodes
layer.max_num_edges_per_graph = max_edges
self.task_heads.set_max_num_nodes_edges_per_graph(max_nodes, max_edges)
def __repr__(self) -> str:
r"""
Controls how the class is printed
"""
pre_nn_str, pre_nn_edges_str = "", ""
if self.pre_nn is not None:
pre_nn_str = self.pre_nn.__repr__() + "\n\n"
if self.pre_nn_edges is not None:
pre_nn_edges_str = self.pre_nn_edges.__repr__() + "\n\n"
gnn_str = self.gnn.__repr__() + "\n\n"
if self.task_heads is not None:
task_str = self.task_heads.__repr__()
task_str = " Task heads:\n " + " ".join(task_str.splitlines(True))
child_str = " " + pre_nn_str + pre_nn_edges_str + gnn_str + task_str
child_str = " ".join(child_str.splitlines(True))
full_str = self.name + "\n" + "-" * (len(self.name) + 2) + "\n" + child_str
return full_str
@property
def in_dim(self) -> int:
r"""
Returns the input dimension of the network
"""
if self.pre_nn is not None:
return self.pre_nn.in_dim
else:
return self.gnn.in_dim
@property
def out_dim(self) -> int:
r"""
Returns the output dimension of the network
"""
return self.gnn.out_dim
@property
def out_dim_edges(self) -> int:
r"""
Returns the output dimension of the edges
of the network.
"""
if self.gnn.full_dims_edges is not None:
return self.gnn.full_dims_edges[-1]
return self.gnn.in_dim_edges
@property
def in_dim_edges(self) -> int:
r"""
Returns the input edge dimension of the network
"""
return self.gnn.in_dim_edges
class GraphOutputNN(nn.Module, MupMixin):
    """
    Output module applied after the GNN for a single task level
    (graph / node / edge / nodepair). It gathers the relevant features from the
    pyg ``Batch`` (pooling them for graph-level tasks, combining them pairwise
    for nodepair-level tasks) and runs them through a ``FeedForwardNN``.
    """

    def __init__(
        self,
        in_dim: int,
        in_dim_edges: int,
        task_level: str,
        graph_output_nn_kwargs: Dict[str, Any],
    ):
        r"""
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            in_dim_edges:
                Input edge feature dimensions of the layer
            task_level:
                graph/node/edge/nodepair depending on whether it is a graph/node/edge/nodepair level task
            graph_output_nn_kwargs:
                key-word arguments to use for the initialization of the post-processing
                MLP network after the GNN, using the class `FeedForwardNN`.
        """
        super().__init__()
        self.task_level = task_level
        self._concat_last_layers = None
        self.in_dim = in_dim
        self.in_dim_edges = in_dim_edges
        self.graph_output_nn_kwargs = graph_output_nn_kwargs
        # Set later via `set_max_num_nodes_edges_per_graph`; used for reshaping
        # with compiled models (IPU)
        self.max_num_nodes_per_graph = None
        self.max_num_edges_per_graph = None
        # Key of the pyg Batch that holds the features for each task level
        self.map_task_level = {
            "node": "feat",
            "nodepair": "nodepair_feat",
            "graph": "graph_feat",
            "edge": "edge_feat",
        }
        # The input dimension of the post-processing MLP depends on the task level
        if self.task_level == "nodepair":
            # Nodepair features concatenate (h_X + h_Y, |h_X - h_Y|), hence 2x the node dim
            level_in_dim = 2 * self.in_dim
        elif self.task_level == "edge":
            level_in_dim = self.in_dim_edges
        elif self.task_level == "graph":
            self.global_pool_layer, self.out_pool_dim = self._parse_pooling_layer(
                self.in_dim, graph_output_nn_kwargs[self.task_level]["pooling"]
            )
            level_in_dim = self.out_pool_dim
        else:
            level_in_dim = self.in_dim
        if level_in_dim == 0:
            raise ValueError(f"Task head has an input dimension of 0.")
        # Initialize the post-processing neural net (applied after the gnn)
        name = graph_output_nn_kwargs[self.task_level].pop("name", "post-NN")
        # Drop keys consumed here ("pooling") or recomputed ("in_dim")
        filtered_graph_output_nn_kwargs = {
            k: v for k, v in graph_output_nn_kwargs[self.task_level].items() if k not in ["pooling", "in_dim"]
        }
        self.graph_output_nn = FeedForwardNN(
            in_dim=level_in_dim, name=name, **filtered_graph_output_nn_kwargs
        )

    def forward(self, g: Batch):
        """
        Parameters:
            g: pyg Batch graph
        Returns:
            h: Output features after applying graph_output_nn
        """
        # Check if at least one nodepair task is present
        if self.task_level == "nodepair":
            g["nodepair_feat"] = self.compute_nodepairs(
                node_feats=g["feat"],
                batch=g.batch,
                max_num_nodes=self.max_num_nodes_per_graph,
                drop_nodes_last_graph=is_running_on_ipu(),
            )
        # Check if at least one graph-level task is present
        if self.task_level == "graph":
            # pool features if the level is graph
            g["graph_feat"] = self._pool_layer_forward(g, g["feat"])
        h = g[self.map_task_level[self.task_level]]
        # Run the output network
        if self.concat_last_layers is None:
            h = self.graph_output_nn.forward(h)
        else:
            # Concatenate the output of the last layers according to `self._concat_last_layers``.
            # Useful for generating fingerprints
            h = [h]
            for ii in range(len(self.graph_output_nn.layers)):
                h.insert(0, self.graph_output_nn.layers[ii].forward(h[0]))  # Append in reverse
        return h

    def _parse_pooling_layer(
        self, in_dim: int, pooling: Union[str, List[str]], **kwargs
    ) -> Tuple[nn.Module, int]:
        r"""
        Return the pooling layer
        **This function is virtual, so it needs to be implemented by the child class.**
        Parameters:
            in_dim:
                The dimension at the input layer of the pooling
            pooling:
                The list of pooling layers to use. The accepted strings are:

                - "sum": `SumPooling`
                - "mean": `MeanPooling`
                - "max": `MaxPooling`
                - "min": `MinPooling`
                - "std": `StdPooling`
                - "s2s": `Set2Set`
                - "dir{int}": `DirPooling`
            kwargs:
                Key-word arguments for the pooling layer initialization
        Return:
            pool_layer: Pooling layer module
            out_pool_dim: Output dimension of the pooling layer
        """
        return parse_pooling_layer_pyg(in_dim, pooling, **kwargs)

    def _pool_layer_forward(
        self,
        g: Batch,
        feat: torch.Tensor,
    ):
        r"""
        Apply the graph pooling layer.
        Parameters:
            g: pyg Batch graph on which the convolution is done
            feat (torch.Tensor[..., N, Din]):
                Node feature tensor, before pooling.
                `N` is the number of nodes, `Din` is the output size of the last Graph layer
        Returns:
            torch.Tensor[..., M, Din] or torch.Tensor[..., N, Din]:
                Pooled feature tensor.
                `N` is the number of nodes, `M` is the number of graphs.
                If the pooling is empty, the dimension is `N`, otherwise it is `M`
        """
        if len(self.global_pool_layer) > 0:
            pooled_feat = self.global_pool_layer(g, feat)
        else:
            pooled_feat = feat
        return pooled_feat

    def compute_nodepairs(
        self,
        node_feats: torch.Tensor,
        batch: torch.Tensor,
        max_num_nodes: int = None,
        fill_value: float = float("nan"),
        batch_size: int = None,
        drop_nodes_last_graph: bool = False,
    ) -> torch.Tensor:
        r"""
        Vectorized implementation of nodepair-level task:
        Parameters:
            node_feats: Node features
            batch: Batch vector
            max_num_nodes: The maximum number of nodes per graph
            fill_value: The value for invalid entries in the
                resulting dense output tensor. (default: :obj:`NaN`)
            batch_size: The batch size. (default: :obj:`None`)
            drop_nodes_last_graph: Whether to drop the nodes of the last graphs that exceed
                the `max_num_nodes_per_graph`. Useful when the last graph is a padding.
        Returns:
            result: concatenated node features of shape B * max_num_nodes * 2*h,
                where B is number of graphs, max_num_nodes is the chosen maximum number nodes, and h is the feature dim
        """
        dense_feat, mask, _ = to_dense_batch(
            node_feats,
            batch=batch,
            fill_value=fill_value,
            batch_size=batch_size,
            max_num_nodes_per_graph=max_num_nodes,
            drop_nodes_last_graph=drop_nodes_last_graph,
        )
        n = dense_feat.size(1)
        # Broadcast every node against every other node of the same graph
        h_X = dense_feat[:, :, None].repeat(1, 1, n, 1)
        h_Y = dense_feat[:, None, :, :].repeat(1, n, 1, 1)
        nodepair_h = torch.cat((h_X + h_Y, torch.abs(h_X - h_Y)), dim=-1)
        # Fix: build the mask on the same device as the features; the previous
        # `torch.ones(n, n)` lived on the default device and would fail when
        # `node_feats` is on an accelerator.
        upper_tri_mask = torch.triu(
            torch.ones(n, n, device=node_feats.device), diagonal=1
        ).bool()
        # Mask nodepair_h using upper_tri_mask, keeping each unordered pair once
        batched_result = nodepair_h[:, upper_tri_mask, :]
        return batched_result

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)
        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension
        Returns:
            Dictionary with the kwargs to create the base model.
        """
        # For the post-nn network, all the dimension are divided
        graph_output_nn_kwargs = self.graph_output_nn.make_mup_base_kwargs(
            divide_factor=divide_factor, factor_in_dim=factor_in_dim
        )
        kwargs = {
            "pooling": self.graph_output_nn_kwargs[self.task_level]["pooling"],
            **graph_output_nn_kwargs,
        }
        return kwargs

    def drop_graph_output_nn_layers(self, num_layers_to_drop: int) -> None:
        r"""
        Remove the last layers of the model. Useful for Transfer Learning.
        Parameters:
            num_layers_to_drop: The number of layers to drop from the `self.graph_output_nn` network.
        """
        assert num_layers_to_drop >= 0
        assert num_layers_to_drop <= len(self.graph_output_nn.layers)
        if num_layers_to_drop > 0:
            self.graph_output_nn.layers = self.graph_output_nn.layers[:-num_layers_to_drop]

    def extend_graph_output_nn_layers(self, layers: nn.ModuleList):
        r"""
        Add layers at the end of the model. Useful for Transfer Learning.
        Parameters:
            layers: A ModuleList of all the layers to extend
        """
        assert isinstance(layers, nn.ModuleList)
        if len(self.graph_output_nn.layers) > 0:
            # Fix: was `self.graph_output_nn.layers.out_dim[-1]`, but `layers` is an
            # `nn.ModuleList`, which has no `out_dim` attribute (AttributeError).
            # The intended check compares against the last layer's output dimension.
            assert layers[0].in_dim == self.graph_output_nn.layers[-1].out_dim
        # NOTE(review): delegates to `FeedForwardNN.extend`; confirm the wrapped
        # network updates its own bookkeeping (e.g. `out_dim`) accordingly.
        self.graph_output_nn.extend(layers)

    def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges: Optional[int]) -> None:
        """
        Set the maximum number of nodes and edges for all gnn layers and encoder layers
        Parameters:
            max_nodes: Maximum number of nodes in the dataset.
                This will be useful for certain architecture, but ignored by others.
            max_edges: Maximum number of edges in the dataset.
                This will be useful for certain architecture, but ignored by others.
        """
        self.max_num_nodes_per_graph = max_nodes
        self.max_num_edges_per_graph = max_edges

    @property
    def concat_last_layers(self) -> Optional[Iterable[int]]:
        """
        Property to control the output of the `self.forward`.
        If set to a list of integer, the `forward` function will
        concatenate the output of different layers.
        If set to `None`, the output of the last layer is returned.
        NOTE: The indexes are inverted. 0 is the last layer, 1 is the second last, etc.
        """
        return self._concat_last_layers

    @concat_last_layers.setter
    def concat_last_layers(self, value: Union[Type[None], int, Iterable[int]]) -> None:
        """
        Set the property to control the output of the `self.forward`.
        If set to a list of integer, the `forward` function will
        concatenate the output of different layers.
        If a single integer is provided, it will output that specific layer.
        If set to `None`, the output of the last layer is returned.
        NOTE: The indexes are inverted. 0 is the last layer, 1 is the second last, etc.
        Parameters:
            value: Output layers to concatenate, in reverse order (`0` is the last layer)
        """
        if (value is not None) and not isinstance(value, Iterable):
            value = [value]
        self._concat_last_layers = value

    @property
    def out_dim(self) -> int:
        r"""
        Returns the output dimension of the network
        """
        return self.graph_output_nn.out_dim
class TaskHeads(nn.Module, MupMixin):
    def __init__(
        self,
        in_dim: int,
        in_dim_edges: int,
        task_heads_kwargs: Dict[str, Any],
        graph_output_nn_kwargs: Dict[str, Any],
        last_layer_is_readout: bool = True,
    ):
        r"""
        Class that groups all multi-task output heads together to provide the task-specific outputs.
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            in_dim_edges:
                Input edge feature dimensions of the layer
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method
            task_heads_kwargs:
                This argument is a list of dictionaries corresponding to the arguments for a FeedForwardNN.
                Each dict of arguments is used to initialize a task-specific MLP.
            graph_output_nn_kwargs:
                key-word arguments to use for the initialization of the post-processing
                MLP network after the GNN, using the class `FeedForwardNN`.
        """
        super().__init__()
        self.last_layer_is_readout = last_layer_is_readout
        # Deep-copies so the `setdefault` calls below never mutate the caller's dicts
        self.task_heads_kwargs = deepcopy(task_heads_kwargs)
        self.graph_output_nn_kwargs = deepcopy(graph_output_nn_kwargs)
        # All distinct task levels ("graph", "node", ...) used by the heads
        self.task_levels = {head_kwargs["task_level"] for _, head_kwargs in self.task_heads_kwargs.items()}
        self.in_dim = in_dim
        self.in_dim_edges = in_dim_edges
        self.task_heads = nn.ModuleDict()
        self.graph_output_nn = nn.ModuleDict()
        self._check_bad_arguments()
        for task_name, head_kwargs in self.task_heads_kwargs.items():
            task_level = self.task_heads_kwargs[task_name].get("task_level")
            # One GraphOutputNN per task level, shared by every head of that level
            # (re-assigned with an identical configuration if several tasks share a level)
            self.graph_output_nn[task_level] = GraphOutputNN(
                in_dim=self.in_dim,
                in_dim_edges=self.in_dim_edges,
                task_level=task_level,
                graph_output_nn_kwargs=self.graph_output_nn_kwargs,
            )
            head_kwargs.setdefault("name", f"NN-{task_name}")
            head_kwargs.setdefault("last_layer_is_readout", last_layer_is_readout)
            # Create a new dictionary without the task_level key-value pair,
            # and pass it while initializing the FeedForwardNN instance for tasks
            filtered_kwargs = {k: v for k, v in head_kwargs.items() if k != "task_level"}
            filtered_kwargs["in_dim"] = self.graph_output_nn_kwargs[task_level]["out_dim"]
            self.task_heads[task_name] = FeedForwardNN(**filtered_kwargs)

    def forward(self, g: Batch) -> Dict[str, torch.Tensor]:
        r"""
        forward function of the task head
        Parameters:
            g: pyg Batch graph
        Returns:
            task_head_outputs: Return a dictionary: Dict[task_name, Tensor]
        """
        # Compute the level-specific features once, shared by all heads of that level
        features = {task_level: self.graph_output_nn[task_level](g) for task_level in self.task_levels}
        task_head_outputs = {}
        for task_name, head in self.task_heads.items():
            task_level = self.task_heads_kwargs[task_name].get(
                "task_level", None
            )  # Get task_level without modifying head_kwargs
            task_head_outputs[task_name] = head.forward(features[task_level])
        return task_head_outputs

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)
        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension
        Returns:
            kwargs: Dictionary of arguments to be used to initialize the base model
        """
        graph_output_nn_kwargs = {}
        for task_level, graph_output_nn in self.graph_output_nn.items():
            graph_output_nn: GraphOutputNN
            graph_output_nn_kwargs[task_level] = graph_output_nn.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=factor_in_dim
            )
        task_heads_kwargs = {}
        for task_name, task_nn in self.task_heads.items():
            task_nn: FeedForwardNN
            task_heads_kwargs[task_name] = task_nn.make_mup_base_kwargs(
                divide_factor=divide_factor, factor_in_dim=factor_in_dim
            )
            task_heads_kwargs[task_name]["task_level"] = self.task_heads_kwargs[task_name]["task_level"]
        # NOTE(review): `in_dim` is returned unfactored even when
        # `factor_in_dim=True`; confirm the caller divides it where needed.
        kwargs = dict(
            in_dim=self.in_dim,
            # Fix: `in_dim_edges` is a required argument of `TaskHeads.__init__`
            # but was missing from the returned kwargs, so the 'base' model
            # could not be re-instantiated from them.
            in_dim_edges=self.in_dim_edges,
            last_layer_is_readout=self.last_layer_is_readout,
            task_heads_kwargs=task_heads_kwargs,
            graph_output_nn_kwargs=graph_output_nn_kwargs,
        )
        return kwargs

    def set_max_num_nodes_edges_per_graph(self, max_nodes: Optional[int], max_edges: Optional[int]) -> None:
        """
        Set the maximum number of nodes and edges for all gnn layers and encoder layers
        Parameters:
            max_nodes: Maximum number of nodes in the dataset.
                This will be useful for certain architecture, but ignored by others.
            max_edges: Maximum number of edges in the dataset.
                This will be useful for certain architecture, but ignored by others.
        """
        for graph_output_nn in self.graph_output_nn.values():
            graph_output_nn: GraphOutputNN
            graph_output_nn.set_max_num_nodes_edges_per_graph(max_nodes, max_edges)
        for task_head in self.task_heads.values():
            task_head: FeedForwardNN
            for layer in task_head.layers:
                if isinstance(layer, BaseGraphStructure):
                    layer.max_num_nodes_per_graph = max_nodes
                    layer.max_num_edges_per_graph = max_edges

    @property
    def out_dim(self) -> Dict[str, int]:
        r"""
        Returns the output dimension of each task head
        """
        return {task_name: head.out_dim for task_name, head in self.task_heads.items()}

    def __repr__(self):
        r"""
        Returns a string representation of the task heads
        """
        task_repr = []
        for head, net in self.task_heads.items():
            task_repr.append(head + ": " + net.__repr__())
        return "\n".join(task_repr)

    def _check_bad_arguments(self):
        r"""
        Raise comprehensive errors if the arguments seem wrong
        """
        for task_name, head_kwargs in self.task_heads_kwargs.items():
            task_level = self.task_heads_kwargs[task_name].get("task_level", None)
            if task_level is None:
                raise ValueError("task_level must be specified for each task head.")
            if task_level not in ["node", "edge", "graph", "nodepair"]:
                raise ValueError(f"task_level {task_level} is not supported.")
================================================
FILE: graphium/nn/architectures/pyg_architectures.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from torch import Tensor
from torch.nn import Module
from typing import Tuple, Union, List, Optional
from torch_geometric.data import Data, Batch
from graphium.nn.base_graph_layer import BaseGraphModule
from graphium.nn.pyg_layers import VirtualNodePyg, parse_pooling_layer_pyg
from graphium.nn.architectures.global_architectures import FeedForwardGraph
class FeedForwardPyg(FeedForwardGraph):
    r"""
    A flexible neural network architecture, with variable hidden dimensions,
    support for multiple layer types, and support for different residual
    connections.
    This class is meant to work with different PyG-based graph neural networks
    layers. Any layer must inherit from `graphium.nn.base_graph_layer.BaseGraphStructure`
    or `graphium.nn.base_graph_layer.BaseGraphLayer`.
    """

    def _graph_layer_forward(
        self,
        layer: BaseGraphModule,
        g: Batch,
        feat: Tensor,
        edge_feat: Optional[Tensor],
        feat_prev: Optional[Tensor],
        edge_feat_prev: Optional[Tensor],
        step_idx: int,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        r"""
        Apply the *i-th* PyG graph layer, where *i* is the index given by `step_idx`.
        The layer is applied differently depending if there are edge features or not.
        Then, the residual is also applied on both the features and the edges (if applicable)
        Parameters:
            layer:
                The PyG layer used for the convolution
            g:
                graph on which the convolution is done
            feat (torch.Tensor[..., N, Din]):
                Node feature tensor, before convolution.
                `N` is the number of nodes, `Din` is the input features
            edge_feat (torch.Tensor[..., E, Ein]):
                Edge feature tensor, before convolution.
                `Ein` is the input edge features
                (assumes the leading dim indexes edges — TODO confirm; the
                original doc said `N` nodes here)
            feat_prev:
                Node feature of the previous residual connection, or `None`
            edge_feat_prev:
                Edge feature of the previous residual connection, or `None`
            step_idx:
                The current step idx in the forward loop
        Returns:
            feat (torch.Tensor[..., N, Dout]):
                Node feature tensor, after convolution and residual.
                `N` is the number of nodes, `Dout` is the output features of the layer and residual
            edge_feat (torch.Tensor[..., E, Eout]):
                Edge feature tensor, after convolution and residual
            feat_prev:
                Node feature tensor to be used at the next residual connection, or `None`
            edge_feat_prev:
                Edge feature tensor to be used at the next residual connection, or `None`
        """
        # Set node / edge features into the graph
        g["feat"] = feat
        g["edge_feat"] = edge_feat
        # Apply the GNN layer
        g = layer(g)
        # Get the node / edge features from the graph
        feat = g["feat"]
        edge_feat = g["edge_feat"]
        # Apply the residual layers on the features and edges (if applicable).
        # Skipped on the last layer: no residual is needed past the final output.
        if step_idx < len(self.layers) - 1:
            feat, feat_prev = self.residual_layer.forward(feat, feat_prev, step_idx=step_idx)
            # Edge residual only when the layer actually outputs edge features
            if (self.residual_edges_layer is not None) and (layer.layer_outputs_edges):
                edge_feat, edge_feat_prev = self.residual_edges_layer.forward(
                    edge_feat, edge_feat_prev, step_idx=step_idx
                )
        return feat, edge_feat, feat_prev, edge_feat_prev

    def _parse_virtual_node_class(self) -> type:
        r"""
        Return the class to use for virtual nodes (PyG implementation).
        """
        return VirtualNodePyg

    def _parse_pooling_layer(
        self, in_dim: int, pooling: Union[str, List[str]], **kwargs
    ) -> Tuple[Module, int]:
        r"""
        Return the PyG pooling layer(s) and their output dimension.
        Parameters:
            in_dim: The dimension at the input layer of the pooling
            pooling: The pooling layer name(s) to use
            kwargs: Key-word arguments for the pooling layer initialization
        Returns:
            The pooling module and its output dimension
        """
        return parse_pooling_layer_pyg(in_dim, pooling, **kwargs)
================================================
FILE: graphium/nn/base_graph_layer.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import abc
from typing import Union, Callable, List, Optional, Mapping
from copy import deepcopy
import torch
import torch.nn as nn
from torch import Tensor, IntTensor
from torch_sparse import SparseTensor
from graphium.nn.base_layers import get_activation, DropPath
from graphium.utils.decorators import classproperty
import torch_geometric as pyg
class BaseGraphStructure:
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        activation: Union[str, Callable] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        layer_idx: Optional[int] = None,
        layer_depth: Optional[int] = None,
        droppath_rate: float = 0.0,
    ):
        r"""
        Abstract class used to standardize the implementation of Pyg layers
        in the current library. It will allow a network to seamlessly swap between
        different GNN layers by better understanding the expected inputs
        and outputs.
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            layer_idx:
                The index of the current layer
            layer_depth:
                The total depth (number of layers) associated to this specific layer
            droppath_rate:
                stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
        """
        super().__init__()
        # Basic attributes
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.normalization = normalization
        self.dropout = dropout
        self.activation = activation
        self.layer_idx = layer_idx
        self.layer_depth = layer_depth
        self.droppath_rate = droppath_rate
        self._max_num_nodes_per_graph = None
        self._max_num_edges_per_graph = None

    def _initialize_activation_dropout_norm(self):
        # Must be mixed with `nn.Module` (e.g. via `BaseGraphModule`) so the
        # sub-layers below get registered properly
        if not isinstance(self, nn.Module):
            raise TypeError(
                "This function requires the current object to be an `nn.Module`. Use multi-inheritance or the class `BaseGraphModule` instead"
            )
        # Build the layers
        self.activation_layer = get_activation(self.activation)
        self.dropout_layer = self._parse_dropout(self.dropout)
        self.norm_layer = self._parse_norm(self.normalization)
        self.droppath_layer = self._parse_droppath(self.droppath_rate)

    def _parse_dropout(self, dropout):
        # A callable is used as-is (deep-copied); a positive float builds an
        # `nn.Dropout`; otherwise no dropout layer (returns None)
        if callable(dropout):
            return deepcopy(dropout)
        elif dropout > 0:
            return nn.Dropout(p=dropout)
        return

    def _parse_droppath(self, droppath_rate):
        # No droppath layer when the rate is 0; otherwise the rate is scaled
        # by the layer's depth position (stochastic depth)
        if droppath_rate == 0:
            return
        droppath_rate = DropPath.get_stochastic_drop_rate(droppath_rate, self.layer_idx, self.layer_depth)
        return DropPath(drop_rate=droppath_rate)

    def _parse_norm(self, normalization, dim=None):
        # Default dim accounts for layers whose output is multiplied (e.g. multi-head concat)
        if dim is None:
            dim = self.out_dim * self.out_dim_factor
        if normalization is None or normalization == "none":
            parsed_norm = None
        elif callable(normalization):
            parsed_norm = deepcopy(normalization)
        elif normalization == "batch_norm":
            parsed_norm = nn.BatchNorm1d(dim)
        elif normalization == "layer_norm":
            parsed_norm = nn.LayerNorm(dim)
        else:
            raise ValueError(
                f"Undefined normalization `{normalization}`, must be `None`, `Callable`, 'batch_norm', 'layer_norm', 'none'"
            )
        return parsed_norm

    def apply_norm_activation_dropout(
        self,
        feat: Tensor,
        normalization: bool = True,
        activation: bool = True,
        dropout: bool = True,
        droppath: bool = True,
        batch_idx: Optional[IntTensor] = None,
        batch_size: Optional[int] = None,
    ):
        r"""
        Apply the different normalization and the dropout to the
        output layer.
        Parameters:
            feat:
                Feature tensor, to be normalized
            normalization:
                Whether to apply the normalization
            activation:
                Whether to apply the activation layer
            dropout:
                Whether to apply the dropout layer
            droppath:
                Whether to apply the DropPath layer
            batch_idx:
                Batch assignment of each node, forwarded to the DropPath layer
            batch_size:
                Number of graphs in the batch, forwarded to the DropPath layer
        Returns:
            feat:
                Normalized and dropped-out features
        """
        if normalization and (self.norm_layer is not None):
            feat = self.norm_layer(feat)
        if activation and (self.activation_layer is not None):
            feat = self.activation_layer(feat)
        if dropout and (self.dropout_layer is not None):
            feat = self.dropout_layer(feat)
        if droppath and (self.droppath_layer is not None):
            feat = self.droppath_layer(feat, batch_idx=batch_idx, batch_size=batch_size)
        return feat

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Abstract method. Return a boolean specifying if the layer type
        supports edges or not.
        Returns:
            bool:
                Whether the layer supports the use of edges
        """
        ...

    @property
    @abc.abstractmethod
    def layer_inputs_edges(self) -> bool:
        r"""
        Abstract method. Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_input_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Whether the layer uses input edges in the forward pass
        """
        ...

    @property
    @abc.abstractmethod
    def layer_outputs_edges(self) -> bool:
        r"""
        Abstract method. Return a boolean specifying if the layer type
        outputs edges or not.
        It is different from ``layer_supports_output_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Whether the layer outputs edges in the forward pass
        """
        ...

    @property
    @abc.abstractmethod
    def out_dim_factor(self) -> int:
        r"""
        Abstract method.
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                The factor that multiplies the output dimensions
        """
        ...

    @property
    def max_num_nodes_per_graph(self) -> Optional[int]:
        """
        Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)
        """
        return self._max_num_nodes_per_graph

    @max_num_nodes_per_graph.setter
    def max_num_nodes_per_graph(self, value: Optional[int]):
        """
        Set the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)
        """
        if value is not None:
            # Fix: the message previously read "provided f{value}" — a stray `f`
            # from a malformed f-string.
            assert isinstance(value, int) and (
                value > 0
            ), f"Value should be a positive integer, provided {value} of type {type(value)}"
        self._max_num_nodes_per_graph = value

    @property
    def max_num_edges_per_graph(self) -> Optional[int]:
        """
        Get the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)
        """
        return self._max_num_edges_per_graph

    @max_num_edges_per_graph.setter
    def max_num_edges_per_graph(self, value: Optional[int]):
        """
        Set the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)
        """
        if value is not None:
            # Fix: same stray-`f` message as in the nodes setter.
            assert isinstance(value, int) and (
                value > 0
            ), f"Value should be a positive integer, provided {value} of type {type(value)}"
        self._max_num_edges_per_graph = value

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        f = self.out_dim_factor
        out_dim_f_print = "" if f == 1 else f" * {f}"
        return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}{out_dim_f_print}, activation={self.activation})"
class BaseGraphModule(BaseGraphStructure, nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        activation: Union[str, Callable] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        layer_idx: Optional[int] = None,
        layer_depth: Optional[int] = None,
        droppath_rate: float = 0.0,
    ):
        r"""
        Concrete `nn.Module` counterpart of `BaseGraphStructure`, standardizing
        the implementation of Pyg layers in the current library so a network can
        seamlessly swap between different GNN layers.
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            layer_idx:
                The index of the current layer
            layer_depth:
                The total depth (number of layers) associated to this specific layer
            droppath_rate:
                stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
        """
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            layer_idx=layer_idx,
            layer_depth=layer_depth,
            droppath_rate=droppath_rate,
        )
        # Now that `nn.Module.__init__` has run, the activation / dropout /
        # norm / droppath sub-layers can be instantiated and registered.
        self._initialize_activation_dropout_norm()
def check_intpus_allow_int(obj, edge_index, size):
    """
    Overwrite the __check_input__ to allow for int32 and int16
    TODO: Remove when PyG and pytorch supports int32.
    """
    the_size: List[Optional[int]] = [None, None]

    if isinstance(edge_index, Tensor):
        # Unlike the upstream PyG check, accept smaller integer dtypes and skip
        # the min/max value checks, which can overflow with narrow int types.
        assert edge_index.dtype in (torch.long, torch.int64, torch.int32, torch.int16)
        # assert edge_index.min() >= 0
        # assert edge_index.max() < torch.iinfo(edge_index.dtype).max
        assert edge_index.dim() == 2
        assert edge_index.size(0) == 2
        if size is not None:
            the_size[0], the_size[1] = size[0], size[1]
        return the_size

    if isinstance(edge_index, SparseTensor):
        if obj.flow == "target_to_source":
            raise ValueError(
                (
                    'Flow direction "target_to_source" is invalid for '
                    "message propagation via `torch_sparse.SparseTensor`. If "
                    "you really want to make use of a reverse message "
                    "passing flow, pass in the transposed sparse tensor to "
                    "the message passing module, e.g., `adj_t.t()`."
                )
            )
        the_size[0] = edge_index.sparse_size(1)
        the_size[1] = edge_index.sparse_size(0)
        return the_size

    raise ValueError(
        (
            "`MessagePassing.propagate` only supports `torch.LongTensor` of "
            "shape `[2, num_messages]` or `torch_sparse.SparseTensor` for "
            "argument `edge_index`."
        )
    )
================================================
FILE: graphium/nn/base_layers.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, Callable, Optional, Type, Tuple, Iterable
from copy import deepcopy
from loguru import logger
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, IntTensor
import mup.init as mupi
from mup import set_base_shapes, MuReadout
from torch.nn.functional import linear
from graphium.ipu.ipu_utils import is_running_on_ipu
# Names of the `torch.nn` activation modules that `get_activation` is allowed to
# instantiate; "None" means no activation. Lookups are case-insensitive.
SUPPORTED_ACTIVATION_MAP = {
    "ReLU",
    "Sigmoid",
    "Tanh",
    "ELU",
    "SELU",
    "GLU",
    "GELU",
    "LeakyReLU",
    "Softplus",
    "None",
}
def get_activation(activation: Union[type(None), str, Callable]) -> Optional[Callable]:
    r"""
    returns the activation function represented by the input string

    Parameters:
        activation: Callable, `None`, or string with value:
            "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "GELU", "LeakyReLU", "Softplus"
    Returns:
        Callable or None: The activation function
    """
    # A callable is passed through untouched
    if callable(activation) and activation is not None:
        return activation

    # `None`, or the string "none" (any case), disables the activation
    if activation is None or activation.lower() == "none":
        return None

    # Case-insensitive lookup of the module name in SUPPORTED_ACTIVATION_MAP
    matches = [name for name in SUPPORTED_ACTIVATION_MAP if activation.lower() == name.lower()]
    assert len(matches) == 1 and isinstance(
        matches[0], str
    ), f"Unhandled activation function {matches} of type {type(matches)}"
    # Instantiate the matching class from `torch.nn.modules.activation`
    return vars(torch.nn.modules.activation)[matches[0]]()
def get_activation_str(activation: Union[type(None), str, Callable]) -> str:
    r"""
    returns the string related to the activation function

    Parameters:
        activation: Callable, `None`, or string with value:
            "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus"
    Returns:
        The name of the activation function
    """
    # `None` maps to the literal string "None"
    if activation is None:
        return "None"
    # Strings are already names: return as-is
    if isinstance(activation, str):
        return activation
    # Module instances: use `nn.Module._get_name`, which returns the class name
    if callable(activation):
        return type(activation)._get_name(activation)
    raise ValueError(f"Unhandled activation function {activation} of type {type(activation)}")
def get_norm(normalization: Union[Type[None], str, Callable], dim: Optional[int] = None):
    r"""
    returns the normalization function represented by the input string

    Parameters:
        normalization: Callable, `None`, or string with value:
            "none", "batch_norm", "layer_norm"
        dim: Dimension where to apply the norm. Mandatory for 'batch_norm' and 'layer_norm'
    Returns:
        Callable or None: The normalization function
    """
    if (normalization is None) or (normalization in ("none", "NoneType")):
        norm_layer = None
    elif callable(normalization):
        # Caller supplied a ready-made norm layer or factory
        norm_layer = normalization
    elif normalization in ("batch_norm", "BatchNorm1d"):
        norm_layer = nn.BatchNorm1d(dim)
    elif normalization in ("layer_norm", "LayerNorm"):
        norm_layer = nn.LayerNorm(dim)
    else:
        raise ValueError(
            f"Undefined normalization `{normalization}`, must be `None`, `Callable`, 'batch_norm', 'layer_norm', 'none'"
        )
    # Deep-copy so every caller gets an independent layer (independent parameters)
    return deepcopy(norm_layer)
class MultiheadAttentionMup(nn.MultiheadAttention):
    """
    Modifying the MultiheadAttention to work with the muTransfer paradigm.
    The layers are initialized using the mup package.
    The `_scaled_dot_product_attention` normalizes the attention matrix with `1/d` instead of `1/sqrt(d)`
    The biased self-attention option is added to have 3D attention bias.
    """

    def __init__(self, biased_attention: bool, **kwargs):
        # `biased_attention` toggles whether `attn_bias` is added to the raw
        # attention logits in `forward`; all other kwargs are forwarded to
        # `nn.MultiheadAttention`.
        super().__init__(**kwargs)
        self.biased_attention = biased_attention

    def _reset_parameters(self):
        # Re-initialize all projection weights with the muP-aware initializers
        # from `mup.init` instead of the default torch ones.
        set_base_shapes(self, None, rescale_params=False)  # Set the shapes of the tensors, useful for mup
        if self._qkv_same_embed_dim:
            # Q, K, V share one fused projection weight
            mupi.xavier_uniform_(self.in_proj_weight)
        else:
            mupi.xavier_uniform_(self.q_proj_weight)
            mupi.xavier_uniform_(self.k_proj_weight)
            mupi.xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            mupi.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            mupi.xavier_normal_(self.bias_v)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        precision: Optional[str] = "32",
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Batch-first scaled-dot-product attention with muP's `1/d` scaling and an
        optional additive 3D attention bias.

        Parameters:
            query: `[batch, nodes, embed_dim]`; source and target are assumed to
                share the same sequence length (homogeneous graph attention).
            key: `[batch, nodes, embed_dim]`
            value: `[batch, nodes, embed_dim]`
            key_padding_mask: boolean tensor `[batch, nodes]`; positions set to
                `True` are filled with the masking value before the softmax.
            attn_bias: `[batch, num_heads, nodes, nodes]` additive bias; only
                applied when `self.biased_attention` is True.
            precision: "32" masks with `-inf`; any other value masks with
                -10000 — presumably to stay finite in reduced precision (TODO confirm).
        Returns:
            Tuple `(output, None)` with output of shape `[batch, nodes, embed_dim]`;
            attention weights are not returned.
        """
        # attn_bias [batch, num_heads, nodes, nodes]
        if not self.biased_attention or attn_bias is None:
            attn_bias = 0.0
        # assuming source and target have the same sequence length (homogeneous graph attention)
        batch, nodes, hidden = query.size()
        assert (
            hidden == self.embed_dim
        ), f"query hidden dimension {hidden} != embed_dim {self.embed_dim} in class"
        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling_factor = 1 / head_dim  # use head_dim instead of (head_dim**0.5) for mup
        # Split the fused in-projection into per-tensor weights/biases
        b_q, b_k, b_v = self.in_proj_bias.chunk(3)
        q_proj_weight, k_proj_weight, v_proj_weight = self.in_proj_weight.chunk(3)
        # [batch, num_heads, nodes, head_size]
        q = linear(query, q_proj_weight, b_q).view(batch, nodes, self.num_heads, -1).transpose(1, 2)
        # [batch, num_heads, nodes, head_size]
        k = linear(key, k_proj_weight, b_k).view(batch, nodes, self.num_heads, -1).transpose(1, 2)
        # [batch, num_heads, nodes, head_size]
        v = linear(value, v_proj_weight, b_v).view(batch, nodes, self.num_heads, -1).transpose(1, 2)
        q = q * scaling_factor
        # [batch, num_heads, nodes, nodes]
        attn_weights = q @ k.transpose(-1, -2)
        # [batch, num_heads, nodes, nodes]
        attn_weights += attn_bias
        key_padding_mask_value = float("-inf") if precision == "32" else -10000
        # key_padding_mask: [batch, 1, 1, nodes]
        if key_padding_mask is not None:
            masked_attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                key_padding_mask_value,
            )
        else:
            masked_attn_weights = attn_weights
        masked_attn_weights = F.softmax(masked_attn_weights, dim=-1)
        attn_probs = F.dropout(masked_attn_weights, p=self.dropout, training=self.training)
        # [batch, num_heads, nodes, nodes] * [batch, num_heads, nodes, head_size] -> [batch, num_heads, nodes, head_size]
        attn = attn_probs @ v
        # [batch, nodes, embd_dim]
        attn = attn.transpose(1, 2).contiguous().view(batch, nodes, self.embed_dim)
        # [batch, nodes, embd_dim]
        out = (self.out_proj(attn), None)
        return out
class TransformerEncoderLayerMup(nn.TransformerEncoderLayer):
    r"""
    Modified version of ``torch.nn.TransformerEncoderLayer`` that uses :math:`1/n`-scaled attention
    for compatibility with muP (as opposed to the original :math:`1/\sqrt{n}` scaling factor)

    Arguments are the same as ``torch.nn.TransformerEncoderLayer``.
    """

    def __init__(self, biased_attention, *args, **kwargs) -> None:
        # Build the stock encoder layer first, then swap its self-attention
        # module for the muP-aware `MultiheadAttentionMup`.
        super(TransformerEncoderLayerMup, self).__init__(*args, **kwargs)

        # Extract arguments passed to __init__ as a dictionary
        signature = inspect.signature(nn.TransformerEncoderLayer.__init__)

        # `self` needs to passed, which makes things tricky, but using this object seems fine for now
        bound_signature = signature.bind(self, *args, **kwargs)
        bound_signature.apply_defaults()

        # Position-wise pairing of `MultiheadAttention` constructor names with
        # the corresponding `TransformerEncoderLayer` names (e.g. embed_dim <- d_model).
        mha_names = ["embed_dim", "num_heads", "dropout", "batch_first", "device", "dtype"]
        transformer_names = ["d_model", "nhead", "dropout", "batch_first", "device", "dtype"]

        # Override self attention to use muP
        self.self_attn = MultiheadAttentionMup(
            biased_attention,
            **{
                mha_name: bound_signature.arguments[transformer_name]
                for mha_name, transformer_name in zip(mha_names, transformer_names)
            },
        )
class MuReadoutGraphium(MuReadout):
    """
    PopTorch-compatible replacement for `mup.MuReadout`

    Not quite a drop-in replacement for `mup.MuReadout` - you need to specify
    `base_width`.

    Set `base_width` to width of base model passed to `mup.set_base_shapes`
    to get same results on IPU and CPU. Should still "work" with any other
    value, but won't give the same results as CPU
    """

    def __init__(self, in_features, *args, **kwargs):
        r"""
        Parameters:
            in_features: Width of this (possibly scaled-up) model's readout input.
                Also used as the initial `base_width`.
        """
        super().__init__(in_features, *args, **kwargs)
        self._base_width = in_features

    @property
    def absolute_width(self) -> float:
        # Width of the current (possibly scaled) model's readout input
        return float(self.in_features)

    @property
    def base_width(self):
        # Width of the base model used by `mup.set_base_shapes`
        return self._base_width

    @base_width.setter
    def base_width(self, val):
        if val is None:
            return
        # Fix: the previous check was `isinstance(val, (int, torch.int, torch.long))`,
        # but `torch.int` / `torch.long` are dtype *objects*, not classes, so any
        # non-int `val` raised `TypeError: isinstance() arg 2 must be a type`
        # instead of the intended assertion error. Accept Python ints, plus
        # integer-dtype tensors (the original intent of listing torch dtypes).
        is_integral = isinstance(val, int) or (
            isinstance(val, torch.Tensor) and val.dtype in (torch.int, torch.long)
        )
        assert is_integral, f"`base_width` must be None, int or long, provided {val} of type {type(val)}"
        self._base_width = val

    def width_mult(self) -> float:
        # muP multiplier: ratio between the current width and the base width
        return self.absolute_width / self.base_width
class FCLayer(nn.Module):
    r"""
    A simple fully connected and customizable layer. This layer is centered around a `torch.nn.Linear` module.
    The order in which transformations are applied is:

    - Dense Layer
    - Activation
    - Dropout (if applicable)
    - Batch Normalization (if applicable)
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        activation: Union[str, Callable] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        bias: bool = True,
        init_fn: Optional[Callable] = None,
        is_readout_layer: bool = False,
        droppath_rate: float = 0.0,
    ):
        r"""
        A simple fully connected and customizable layer. This layer is centered around a `torch.nn.Linear` module.
        The order in which transformations are applied is:

        - Dense Layer
        - Activation
        - Dropout (if applicable)
        - Batch Normalization (if applicable)

        Parameters:
            in_dim:
                Input dimension of the layer (the `torch.nn.Linear`)
            out_dim:
                Output dimension of the layer.
            dropout:
                The ratio of units to dropout. No dropout by default.
            activation:
                Activation function to use.
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function

            bias:
                Whether to enable bias for the linear layer.
            init_fn:
                Initialization function to use for the weight of the layer. Default is
                $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$
            is_readout_layer: Whether the layer should be treated as a readout layer by replacing of `torch.nn.Linear`
                by `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            droppath_rate:
                stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382

        Attributes:
            dropout (int):
                The ratio of units to dropout.
            normalization (None or Callable):
                Normalization layer
            linear (`torch.nn.Linear`):
                The linear layer
            activation (`torch.nn.Module`):
                The activation layer
            init_fn (Callable):
                Initialization function used for the weight of the layer
            in_dim (int):
                Input dimension of the linear layer
            out_dim (int):
                Output dimension of the linear layer
        """
        super().__init__()

        # Capture the constructor arguments (minus `self`/`__class__`) so the
        # layer configuration can be inspected/rebuilt later.
        self.__params = locals()
        del self.__params["__class__"]
        del self.__params["self"]

        # Basic parameters
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.bias = bias
        self.dropout = None
        self.normalization = get_norm(normalization, dim=out_dim)

        # Dropout and activation
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        self.activation = get_activation(activation)

        # Optional stochastic depth (applied last in `forward`)
        self.drop_path = None
        if droppath_rate > 0:
            self.drop_path = DropPath(droppath_rate)

        # Linear layer, or MuReadout layer
        if not is_readout_layer:
            self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        else:
            self.linear = MuReadoutGraphium(in_dim, out_dim, bias=bias)

            # Warn user in case of weird parameters
            if self.normalization is not None:
                logger.warning(
                    f"Normalization is not `None` for the readout layer. Provided {self.normalization}"
                )
            if (self.dropout is not None) and (self.dropout.p > 0):
                logger.warning(f"Dropout is not `None` or `0` for the readout layer. Provided {self.dropout}")

        # Define the initialization function based on `muTransfer`, and reset the parameters
        self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_
        self.reset_parameters()

    def reset_parameters(self, init_fn: Optional[Callable] = None):
        """
        Reset the parameters of the linear layer using the `init_fn`.

        Parameters:
            init_fn: Weight initializer; falls back to the one chosen at construction.
        """
        # Required by muP before (re-)initializing weights
        set_base_shapes(self, None, rescale_params=False)  # Set the shapes of the tensors, useful for mup
        init_fn = init_fn or self.init_fn
        if init_fn is not None:
            init_fn(self.linear.weight)
        if self.bias:
            self.linear.bias.data.zero_()

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the FC layer on the input features.

        Parameters:
            h: `torch.Tensor[..., Din]`:
                Input feature tensor, before the FC.
                `Din` is the number of input features
        Returns:
            `torch.Tensor[..., Dout]`:
                Output feature tensor, after the FC.
                `Dout` is the number of output features
        """
        # Short-circuit empty inputs (any leading dimension is 0): return a
        # correctly-shaped all-zero tensor instead of running the sub-layers.
        if torch.prod(torch.as_tensor(h.shape[:-1])) == 0:
            h = torch.zeros(
                list(h.shape[:-1]) + [self.linear.out_features],
                device=h.device,
                dtype=h.dtype,
            )
            return h

        h = self.linear(h)

        if self.normalization is not None:
            # BatchNorm1d normalizes over dim 1; when the feature dimension is
            # not dim 1 (e.g. [batch, nodes, feat]), move it there and back.
            if h.shape[1] != self.out_dim:
                h = self.normalization(h.transpose(1, 2)).transpose(1, 2)
            else:
                h = self.normalization(h)

        if self.dropout is not None:
            h = self.dropout(h)
        if self.activation is not None:
            h = self.activation(h)
        if self.drop_path is not None:
            h = self.drop_path(h)
        return h

    @property
    def in_channels(self) -> int:
        r"""
        Get the input channel size. For compatibility with PyG.
        """
        return self.in_dim

    @property
    def out_channels(self) -> int:
        r"""
        Get the output channel size. For compatibility with PyG.
        """
        return self.out_dim

    def __repr__(self):
        # e.g. "FCLayer(16 -> 32, activation=ReLU())"
        return f"{self.__class__.__name__}({self.in_dim} -> {self.out_dim}, activation={self.activation})"
class MLP(nn.Module):
    r"""
    Simple multi-layer perceptron, built of a series of FCLayers, with an
    optional normalization applied before the first layer.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dims: Union[Iterable[int], int],
        out_dim: int,
        depth: Optional[int] = None,
        activation: Union[str, Callable] = "relu",
        last_activation: Union[str, Callable] = "none",
        dropout: float = 0.0,
        last_dropout: float = 0.0,
        normalization: Union[Type[None], str, Callable] = "none",
        last_normalization: Union[Type[None], str, Callable] = "none",
        first_normalization: Union[Type[None], str, Callable] = "none",
        last_layer_is_readout: bool = False,
        droppath_rate: float = 0.0,
        constant_droppath_rate: bool = True,
        fc_layer: FCLayer = FCLayer,
        fc_layer_kwargs: Optional[dict] = None,
    ):
        r"""
        Simple multi-layer perceptron, built of a series of FCLayers

        Parameters:
            in_dim:
                Input dimension of the MLP
            hidden_dims:
                Either an integer specifying all the hidden dimensions,
                or a list of dimensions in the hidden layers.
            out_dim:
                Output dimension of the MLP.
            depth:
                If `hidden_dims` is an integer, `depth` is 1 + the number of
                hidden layers to use.
                If `hidden_dims` is a list, then
                `depth` must be `None` or equal to `len(hidden_dims) + 1`
            activation:
                Activation function to use in all the layers except the last.
                if `layers==1`, this parameter is ignored
            last_activation:
                Activation function to use in the last layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization in the hidden layers.
                - `Callable`: Any callable function

                if `layers==1`, this parameter is ignored
            last_normalization:
                Normalization to use **after the last layer**. Same options as `normalization`.
            first_normalization:
                Normalization to use **before the first layer**. Same options as `normalization`.
            last_dropout:
                The ratio of units to dropout at the last layer.
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            droppath_rate:
                stochastic depth drop rate, between 0 and 1.
                See https://arxiv.org/abs/1603.09382
            constant_droppath_rate:
                If `True`, drop rates will remain constant across layers.
                Otherwise, drop rates will vary stochastically.
                See `DropPath.get_stochastic_drop_rate`
            fc_layer:
                The fully connected layer to use. Must inherit from `FCLayer`.
            fc_layer_kwargs:
                Keyword arguments to pass to the fully connected layer.
        """
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        # Deep-copied so later mutations by the caller cannot affect this MLP
        self.fc_layer_kwargs = deepcopy(fc_layer_kwargs) or {}

        # Parse the hidden dimensions and depth
        if isinstance(hidden_dims, int):
            # `depth` is mandatory here: (depth - 1) hidden layers of equal width
            self.hidden_dims = [hidden_dims] * (depth - 1)
        else:
            self.hidden_dims = list(hidden_dims)
            assert (depth is None) or (
                depth == len(self.hidden_dims) + 1
            ), "Mismatch between the provided network depth from `hidden_dims` and `depth`"
        self.depth = len(self.hidden_dims) + 1

        # Parse the normalization
        self.first_normalization = get_norm(first_normalization, dim=in_dim)

        # Per-layer input/output dims: in_dim -> hidden... -> out_dim
        all_dims = [in_dim] + self.hidden_dims + [out_dim]
        fully_connected = []
        if self.depth == 0:
            # No layers to build; `forward` passes the input through unchanged
            self.fully_connected = None
            return
        else:
            for ii in range(self.depth):
                if ii < (self.depth - 1):
                    # Define the parameters for all intermediate layers
                    this_activation = activation
                    this_normalization = normalization
                    this_dropout = dropout
                    is_readout_layer = False
                else:
                    # Define the parameters for the last layer
                    this_activation = last_activation
                    this_normalization = last_normalization
                    this_dropout = last_dropout
                    is_readout_layer = last_layer_is_readout

                # Droppath rate: either constant, or growing with layer index
                if constant_droppath_rate:
                    this_drop_rate = droppath_rate
                else:
                    this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, self.depth)

                # Add a fully-connected layer
                fully_connected.append(
                    fc_layer(
                        all_dims[ii],
                        all_dims[ii + 1],
                        activation=this_activation,
                        normalization=this_normalization,
                        dropout=this_dropout,
                        is_readout_layer=is_readout_layer,
                        droppath_rate=this_drop_rate,
                        **self.fc_layer_kwargs,
                    )
                )
        self.fully_connected = nn.Sequential(*fully_connected)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the MLP on the input features.

        Parameters:
            h: `torch.Tensor[..., Din]`:
                Input feature tensor, before the MLP.
                `Din` is the number of input features
        Returns:
            `torch.Tensor[..., Dout]`:
                Output feature tensor, after the MLP.
                `Dout` is the number of output features
        """
        if self.first_normalization is not None:
            h = self.first_normalization(h)
        if self.fully_connected is not None:
            h = self.fully_connected(h)
        return h

    @property
    def in_features(self) -> int:
        # Alias for compatibility with `torch.nn.Linear`-style introspection
        return self.in_dim

    def __getitem__(self, idx: int) -> nn.Module:
        # Index directly into the stack of FC layers
        return self.fully_connected[idx]

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        return self.__class__.__name__ + " (" + str(self.in_dim) + " -> " + str(self.out_dim) + ")"
class GRU(nn.Module):
    r"""
    Wrapper class for the GRU used by the GNN framework; `nn.GRU` is used for
    the Gated Recurrent Unit itself.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        r"""
        Parameters:
            in_dim:
                Input dimension of the GRU layer
            hidden_dim:
                Hidden dimension of the GRU layer.
        """
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        # Bug fix: `nn.GRU` has no `in_dim`/`hidden_dim` keyword arguments — its
        # constructor expects `input_size` and `hidden_size`. The previous call
        # raised a TypeError as soon as the layer was instantiated.
        self.gru = nn.GRU(input_size=in_dim, hidden_size=hidden_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""
        Run the GRU with `x` as the input and `y` as the initial hidden state,
        zero-padding the feature dimension of each up to `in_dim`/`hidden_dim`.

        Parameters:
            x: `torch.Tensor[B, N, Din]`
                where Din <= in_dim (difference is padded)
            y: `torch.Tensor[B, N, Dh]`
                where Dh <= hidden_dim (difference is padded)
        Returns:
            torch.Tensor: `torch.Tensor[B, N, Dh]`
        """
        assert x.shape[-1] <= self.in_dim and y.shape[-1] <= self.hidden_dim

        (B, N, _) = x.shape
        # Flatten batch and node dims into one sequence of length 1 for the GRU
        x = x.reshape(1, B * N, -1).contiguous()
        y = y.reshape(1, B * N, -1).contiguous()

        # Zero-pad the feature dimension if necessary
        if x.shape[-1] < self.in_dim:
            x = F.pad(input=x, pad=[0, self.in_dim - x.shape[-1]], mode="constant", value=0)
        if y.shape[-1] < self.hidden_dim:
            y = F.pad(
                input=y,
                pad=[0, self.hidden_dim - y.shape[-1]],
                mode="constant",
                value=0,
            )

        # `[1]` selects the final hidden state `h_n` of the GRU
        x = self.gru(x, y)[1]
        x = x.reshape(B, N, -1)
        return x
class DropPath(nn.Module):
    r"""
    DropPath class for stochastic depth, applied per-graph: all node features of
    a dropped graph are zeroed, and kept graphs are rescaled by `1/keep_prob`.

    Deep Networks with Stochastic Depth
    Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger
    https://arxiv.org/abs/1603.09382
    """

    def __init__(self, drop_rate: float):
        r"""
        DropPath class for stochastic depth
        Deep Networks with Stochastic Depth
        Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger
        https://arxiv.org/abs/1603.09382

        Parameters:
            drop_rate:
                Drop out probability
        """
        super().__init__()
        # Probability of dropping an entire graph's node features
        self.drop_rate = drop_rate

    @staticmethod
    def get_stochastic_drop_rate(
        drop_rate: float, layer_idx: Optional[int] = None, layer_depth: Optional[int] = None
    ):
        """
        Get the stochastic drop rate from the nominal drop rate, the layer index, and the layer depth.

        `return drop_rate * (layer_idx / (layer_depth - 1))`

        Parameters:
            drop_rate:
                Drop out nominal probability
            layer_idx:
                The index of the current layer
            layer_depth:
                The total depth (number of layers) associated to this specific layer
        """
        if drop_rate == 0:
            return 0
        else:
            assert (layer_idx is not None) and (
                layer_depth is not None
            ), f"layer_idx={layer_idx} and layer_depth={layer_depth} should be integers when `droppath_rate>0`"
            # NOTE(review): divides by zero when `layer_depth == 1` — confirm
            # callers always use depth >= 2 when `drop_rate > 0`.
            return drop_rate * (layer_idx / (layer_depth - 1))

    def forward(
        self,
        input: Tensor,
        batch_idx: IntTensor,
        batch_size: Optional[int] = None,
    ) -> Tensor:
        r"""
        Randomly zero-out all node features of a Bernoulli-sampled subset of graphs,
        rescaling the kept graphs by `1/keep_prob` so the expectation is unchanged.

        Parameters:
            input: `torch.Tensor[total_num_nodes, hidden]`
            batch_idx: batch attribute of the batch object, batch.batch — maps
                each node to the index of its graph
            batch_size: The batch size. Must be provided when working on IPU
        Returns:
            torch.Tensor: `torch.Tensor[total_num_nodes, hidden]`
        """
        on_ipu = is_running_on_ipu()
        if self.drop_rate > 0:
            keep_prob = 1 - self.drop_rate

            # Parse the batch size
            if batch_size is None:
                if on_ipu:
                    raise ValueError(
                        "When using the IPU the batch size must be "
                        "provided during compilation instead of determined at runtime"
                    )
                else:
                    # Infer the number of graphs from the largest graph index
                    batch_size = int(batch_idx.max()) + 1

            # mask shape: [num_graphs, 1] — one Bernoulli keep/drop draw per graph
            mask = input.new_empty(batch_size, 1).bernoulli_(keep_prob)

            # if on_ipu, the last graph is a padded fake graph
            if on_ipu:
                mask[-1] = 0

            # using gather to extend mask to [total_num_nodes, 1]
            node_mask = mask[batch_idx]
            if keep_prob == 0:
                # avoid dividing by 0
                input_scaled = input
            else:
                # Inverse-probability scaling keeps the expected magnitude unchanged
                input_scaled = input / keep_prob
            out = input_scaled * node_mask
        else:
            out = input

        return out
================================================
FILE: graphium/nn/encoders/README.md
================================================
The Graph Of Life Library.
## What is in this folder?
Code for the positional encoder networks that process positional and structural encodings.
- ✅ `base_encoder.py`: `BaseEncoder` class that should be inherited by all positional encoder networks
- ✅ `mlp_encoder.py`: `MLPEncoder` as a simple MLP encoder for any positional encoding
- `gaussian_kernel_pos_encoder.py`: `GaussianKernelPosEncoder` class for the gaussian kernel positional encoder
- `laplace_pos_encoder.py`: `LapPENodeEncoder` class for laplacian eigenvector and eigenvalue encoding
================================================
FILE: graphium/nn/encoders/__init__.py
================================================
from .base_encoder import BaseEncoder
from .laplace_pos_encoder import LapPENodeEncoder
from .mlp_encoder import MLPEncoder, CatMLPEncoder
from .signnet_pos_encoder import SignNetNodeEncoder
from .gaussian_kernel_pos_encoder import GaussianKernelPosEncoder
from .bessel_pos_encoder import BesselSphericalPosEncoder
================================================
FILE: graphium/nn/encoders/base_encoder.py
================================================
from typing import List, Dict, Any, Union, Callable
import abc
import torch
from torch_geometric.data import Batch
from graphium.nn.base_layers import get_norm
from graphium.nn.utils import MupMixin
class BaseEncoder(torch.nn.Module, MupMixin):
    r"""
    Base class for all positional and structural encoders.
    Subclasses must implement `forward`, `parse_input_keys` and `parse_output_keys`.
    """

    def __init__(
        self,
        input_keys: List[str],
        output_keys: List[str],
        in_dim: int,
        out_dim: int,
        num_layers: int,
        activation: Union[str, Callable] = "relu",
        first_normalization=None,
        use_input_keys_prefix: bool = True,  # TODO: might be redundant along with parse_input_keys_with_prefix function
    ):
        r"""
        Base class for all positional and structural encoders.
        Initialize the encoder with the following arguments:

        Parameters:
            input_keys: The keys from the graph to use as input
            output_keys: The keys to return as output encodings
            in_dim: The input dimension for the encoder. A list of dimensions
                (one per input key) is collapsed to their sum.
            out_dim: The output dimension of the encodings
            num_layers: The number of layers of the encoder
            activation: The activation function to use
            first_normalization: The normalization to use before the first layer
            use_input_keys_prefix: Whether to use the `key_prefix` argument in the `forward` method.
                This is useful when the encodings are categorized by the function `get_all_positional_encoding`
        """
        super().__init__()
        # Idiom fix: `isinstance` instead of `type(...) is list`, which also
        # handles list subclasses correctly.
        if isinstance(in_dim, list):
            in_dim = sum(in_dim)
        # Key parsing is delegated to the subclass implementations
        self.input_keys = self.parse_input_keys(input_keys)
        self.output_keys = self.parse_output_keys(output_keys)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.activation = activation
        self.use_input_keys_prefix = use_input_keys_prefix
        self.first_normalization = get_norm(first_normalization, dim=in_dim)

    # TODO: the function below seems redundant; could be removed/replaced moving forward
    def parse_input_keys_with_prefix(self, key_prefix):
        """
        Parse the `input_keys` argument, given a certain prefix.
        If the prefix is `None`, it is ignored
        """
        ### TODO: redundant
        input_keys = self.input_keys
        if (key_prefix is not None) and (self.use_input_keys_prefix):
            # NOTE(review): currently a no-op copy — the prefixed version is commented out
            input_keys = [f"{k}" for k in input_keys]
            # input_keys = [f"{key_prefix}/{k}" for k in input_keys]
        ###
        return input_keys

    @abc.abstractmethod
    def forward(self, graph: Batch, key_prefix=None) -> Dict[str, torch.Tensor]:
        r"""
        Forward pass of the encoder on a graph.
        This is a method to be implemented by the child class.

        Parameters:
            graph: The input pyg Batch
        """
        raise ValueError("This method must be implemented by the child class")

    @abc.abstractmethod
    def parse_input_keys(self, input_keys: List[str]) -> List[str]:
        r"""
        Parse the `input_keys` argument. This is a method to be implemented by the child class.

        Parameters:
            input_keys: The input keys to parse
        """
        raise ValueError("This method must be implemented by the child class")

    @abc.abstractmethod
    def parse_output_keys(self, output_keys: List[str]) -> List[str]:
        """
        Parse the `output_keys` argument. This is a method to be implemented by the child class.

        Parameters:
            output_keys: The output keys to parse
        """
        raise ValueError("This method must be implemented by the child class")

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameters:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            A dictionary with the base model arguments
        """
        base_kwargs = {
            "input_keys": self.input_keys,
            "output_keys": self.output_keys,
            "in_dim": round(self.in_dim / divide_factor) if factor_in_dim else self.in_dim,
            "out_dim": round(self.out_dim / divide_factor),
            "num_layers": self.num_layers,
            "activation": self.activation,
            # Pass the norm by name so the base model re-creates its own layer
            "first_normalization": type(self.first_normalization).__name__,
            "use_input_keys_prefix": self.use_input_keys_prefix,
        }
        return base_kwargs
================================================
FILE: graphium/nn/encoders/bessel_pos_encoder.py
================================================
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable, List, Dict, Any, Optional, Tuple
from torch_geometric.data import Batch
from torch_geometric.nn.models.dimenet import BesselBasisLayer, SphericalBasisLayer
from torch_geometric.nn import radius_graph
from graphium.nn.encoders.base_encoder import BaseEncoder
from graphium.nn.pyg_layers.utils import triplets
from graphium.nn.pyg_layers.dimenet_pyg import OutputBlock
from graphium.nn.base_layers import FCLayer
class BesselSphericalPosEncoder(BaseEncoder):
    def __init__(
        self,
        input_keys: List[str],  # The keys from the pyg graph
        output_keys: List[str],  # The keys to return
        in_dim: int,
        out_dim: int,
        num_layers: int,
        out_dim_edges: int,
        num_output_layers: int,
        num_spherical: int,
        num_radial: int,
        max_num_neighbors: int = 32,
        cutoff: float = 5.0,
        envelope_exponent: int = 5,
        activation: Union[str, Callable] = "gelu",
        first_normalization="none",
        use_input_keys_prefix: bool = True,
    ):
        r"""
        Configurable DimeNet's embedding encoder from the
        `"Directional Message Passing for Molecular Graphs"` paper.
        [!] code uses the pytorch-geometric implementation of BesselBasisLayer & SphericalBasisLayer

        Parameters:
            input_keys: The keys from the graph to use as input
            output_keys: The keys to return as output encodings
                (**should at least contain: `edge_rbf`, `triplet_sbf`, `radius_edge_index`; Optional: `node_*`, `edge_*`)
            in_dim: The input dimension for the encoder (**not used)
            out_dim: The output dimension of the node encodings
            num_layers: The number of layers of the encoder (**not used)
            out_dim_edges: The output dimension of the edge encodings
            num_output_layers: The number of layers of the OutBlock
            num_spherical (int): Number of spherical harmonics.
            num_radial (int): Number of radial basis functions.
            max_num_neighbors (int): The maximum number of neighbors to consider in radius graph. (default: 32)
            cutoff (float): Cutoff distance for interatomic interactions. (default: 5.0)
            envelope_exponent (int): Shape of the smooth cutoff. (default: 5)
            activation: The activation function to use
            first_normalization: The normalization to use before the first layer
            use_input_keys_prefix: Whether to use the `key_prefix` argument in the `forward` method.
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            in_dim=in_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            activation=activation,
            first_normalization=first_normalization,
            use_input_keys_prefix=use_input_keys_prefix,
        )
        self.num_radial = num_radial
        self.cutoff = cutoff
        self.envelope_exponent = envelope_exponent
        self.num_spherical = num_spherical
        self.max_num_neighbors = max_num_neighbors

        # static Bessel embeddings
        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)  # for edges
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff, envelope_exponent)  # for triplets

        # edge embedding (num_radial -> out_dim)
        # ***simplified version*** by removing atom type input
        self.rbf_proj = FCLayer(num_radial, out_dim_edges, activation=activation)

        # node embedding from edge embedding (out_dim_edges -> out_dim)
        self.output_block_0 = OutputBlock(
            num_radial, out_dim_edges, out_dim, num_output_layers, activation
        )  # 1 outblock right after embedding in original dimenet
def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, Any]:
r"""
forward function of the DimeNetEncoder class
Parameters:
batch: The batch of pyg graphs
key_prefix: The prefix to use for the input keys
Returns:
A dictionary of the outpaut encodings with keys specified by `output_keys`
"""
### Get the input keys ###
positions_3d_key = self.parse_input_keys_with_prefix(key_prefix)[0]
# be in shape [num_nodes, 3]
pos = batch[positions_3d_key]
# Create radius graph in encoder (not use chemical topology of molecules)
radius_edge_index = radius_graph(
pos, r=self.cutoff, batch=batch.batch, max_num_neighbors=self.max_num_neighbors
)
# Process edges and triplets.
i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(radius_edge_index, num_nodes=pos.size(0))
# Calculate distances.
dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()
# Calculate angles.
pos_i = pos[idx_i]
pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
angle = torch.atan2(b, a)
# [num_edges, num_radial]
rbf = self.rbf(dist)
# [num_triplets, num_spherical * num_radial]
sbf = self.sbf(dist, angle, idx_kj)
# initial 3D edge embedding [num_edges, out_dim] (may merge with other edge encoder's output)
edge_feature_3d = self.rbf_proj(rbf)
# initial 3D node embedding [num_nodes, out_dim] (align with original DimeNet implementation)
P = self.output_block_0(edge_feature_3d, rbf, i, num_nodes=pos.size(0))
# Crash if the key starts with 'graph_'
# Return `rbf` and `sbf` for necessary message passing in DimeNet
# Return `radius_edge_index` for necessary message passing in DimeNet
# Return `P` as node embedding (if the key starts with 'node_')
# Return `edge_feature_3d` otherwise
output = {}
for key in self.output_keys:
if isinstance(key, str) and key.startswith("graph_"):
raise ValueError("Graph encodings are not supported for this encoder")
elif key.startswith("node_"):
output[key] = P
elif key == "edge_rbf":
output[key] = rbf
elif key == "triplet_sbf":
output[key] = sbf
elif key == "radius_edge_index":
output[key] = radius_edge_index
else:
output[key] = edge_feature_3d
return output
def make_mup_base_kwargs(
    self,
    divide_factor: float = 2.0,
    factor_in_dim: bool = False,
) -> Dict[str, Any]:
    """
    Build the keyword arguments for a 'base' model used by the `mup` or
    `muTransfer` scaling of the model. The base model mirrors this one, with
    width-like dimensions divided by `divide_factor` (2 by default).

    Parameter:
        divide_factor: Factor by which to divide the width.
        factor_in_dim: Whether to factor the input dimension

    Returns:
        A dictionary of the base model kwargs
    """
    kwargs = super().make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
    # Only `num_radial` is a width-like dimension and gets scaled; the other
    # hyper-parameters are copied over unchanged.
    kwargs["num_spherical"] = self.num_spherical
    kwargs["num_radial"] = round(self.num_radial / divide_factor)
    kwargs["cutoff"] = self.cutoff
    kwargs["envelope_exponent"] = self.envelope_exponent
    return kwargs
# The two methods below are aligned with `gaussian_kernel_pos_encoder`
def parse_input_keys(
    self,
    input_keys: List[str],
) -> List[str]:
    r"""
    Validate the `input_keys`: exactly one node-level key is accepted.

    Parameters:
        input_keys: The input keys to parse

    Returns:
        The parsed input keys

    Raises:
        ValueError: If the number of keys differs from one.
        AssertionError: If a key refers to edge- or graph-level features.
    """
    if len(input_keys) != 1:
        raise ValueError(f"`{self.__class__}` only supports one key")
    # Reject any key whose prefix marks it as a non-node feature.
    for key in input_keys:
        for prefix, level in (("edge_", "edge"), ("graph_", "graph")):
            assert not key.startswith(
                prefix
            ), f"Input keys must be node features, not {level} features, for encoder {self.__class__}"
    return input_keys
def parse_output_keys(
    self,
    output_keys: List[str],
) -> List[str]:
    r"""
    Validate the `output_keys`.

    Parameters:
        output_keys: The output keys to parse

    Returns:
        The parsed output keys

    Raises:
        AssertionError: If a required key is missing or a graph-level key is present.
    """
    # All three keys are required to perform DimeNet-style message passing.
    required = {
        "edge_rbf": "Edge radial basis feature should present for this encoder",
        "triplet_sbf": "Triplet(angle) spherical radial basis feature should present for this encoder",
        "radius_edge_index": "Radius edge index (graph) should be built in forward of this encoder",
    }
    for required_key, message in required.items():
        assert required_key in output_keys, message
    # Graph-level outputs cannot be produced by this encoder.
    for key in output_keys:
        assert not key.startswith("graph_"), "Graph encodings are not supported for this encoder"
    return output_keys
================================================
FILE: graphium/nn/encoders/gaussian_kernel_pos_encoder.py
================================================
from typing import Union, Callable, List, Dict, Any, Optional
from torch_geometric.data import Batch
from graphium.nn.pyg_layers.utils import PreprocessPositions
from graphium.ipu.ipu_utils import is_running_on_ipu
from graphium.nn.encoders.base_encoder import BaseEncoder
class GaussianKernelPosEncoder(BaseEncoder):
    def __init__(
        self,
        input_keys: List[str],  # The keys from the pyg graph
        output_keys: List[str],  # The keys to return
        in_dim: int,
        out_dim: int,
        embed_dim: int,
        num_layers: int,
        max_num_nodes_per_graph: Optional[int] = None,
        activation: Union[str, Callable] = "gelu",
        first_normalization="none",
        use_input_keys_prefix: bool = True,
        num_heads: int = 1,
    ):
        r"""
        Gaussian kernel-based positional encoding node and edge encoder.
        Useful for encoding 3D conformation positions.

        Parameters:
            input_keys: The keys from the pyg graph to use as input
            output_keys: The keys to return corresponding to the output encodings
            in_dim: The input dimension for the encoder
            out_dim: The output dimension of the encodings
            embed_dim: The dimension of the embedding
            num_layers: The number of layers of the encoder
            max_num_nodes_per_graph: The maximum number of nodes per graph (used on IPU)
            activation: The activation function to use
            first_normalization: The normalization to use before the first layer
            use_input_keys_prefix: Whether to use the `key_prefix` argument in the `forward` method.
            num_heads: The number of heads to use for the multi-head attention
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            in_dim=in_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            activation=activation,
            first_normalization=first_normalization,
            use_input_keys_prefix=use_input_keys_prefix,
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_num_nodes_per_graph = max_num_nodes_per_graph

        # Sub-module that turns the raw 3D positions into attention biases
        # and node-level features.
        self.preprocess_3d_positions = PreprocessPositions(
            num_heads=self.num_heads,
            embed_dim=self.embed_dim,
            num_kernel=self.out_dim,
            num_layers=self.num_layers,
            activation=self.activation,
            first_normalization=self.first_normalization,
        )

    def parse_input_keys(
        self,
        input_keys: List[str],
    ) -> List[str]:
        r"""
        Validate the `input_keys`: exactly one node-level key is accepted.

        Parameters:
            input_keys: The input keys to parse

        Returns:
            The parsed input keys
        """
        if len(input_keys) != 1:
            raise ValueError(f"`{self.__class__}` only supports one key")
        # Keys carrying a non-node prefix are rejected.
        for key in input_keys:
            for prefix, level in (("edge_", "edge"), ("nodepair_", "nodepair"), ("graph_", "graph")):
                assert not key.startswith(
                    prefix
                ), f"Input keys must be node features, not {level} features, for encoder {self.__class__}"
        return input_keys

    def parse_output_keys(
        self,
        output_keys: List[str],
    ) -> List[str]:
        r"""
        Validate the `output_keys`: edge- and graph-level outputs are rejected.

        Parameters:
            output_keys: The output keys to parse

        Returns:
            The parsed output keys
        """
        for key in output_keys:
            assert not key.startswith("edge_"), "Edge encodings are not supported for this encoder"
            assert not key.startswith("graph_"), "Graph encodings are not supported for this encoder"
        return output_keys

    def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, Any]:
        r"""
        Forward function of the GaussianKernelPosEncoder class.

        Parameters:
            batch: The batch of pyg graphs
            key_prefix: The prefix to use for the input keys

        Returns:
            A dictionary of the output encodings with keys specified by `output_keys`
        """
        input_keys = self.parse_input_keys_with_prefix(key_prefix)

        # The IPU requires a fixed number of nodes per graph for static shapes.
        on_ipu = is_running_on_ipu()
        max_num_nodes_per_graph = self.max_num_nodes_per_graph if on_ipu else None

        attn_bias_3d, node_feature_3d = self.preprocess_3d_positions(
            batch, max_num_nodes_per_graph, on_ipu, positions_3d_key=input_keys[0]
        )

        # Route the encodings by output-key prefix:
        #   'nodepair_*' -> attention bias, 'edge_*' -> unsupported (raise),
        #   anything else -> node features.
        output = {}
        for key in self.output_keys:
            is_str = isinstance(key, str)
            if is_str and key.startswith("nodepair_"):
                output[key] = attn_bias_3d
            elif is_str and key.startswith("edge_"):
                raise ValueError("Edge encodings are not supported for this encoder")
            else:
                output[key] = node_feature_3d
        return output

    def make_mup_base_kwargs(
        self,
        divide_factor: float = 2.0,
        factor_in_dim: bool = False,
    ) -> Dict[str, Any]:
        """
        Build the keyword arguments for a 'base' model used by the `mup` or
        `muTransfer` scaling of the model. The base model mirrors this one, with
        width-like dimensions divided by `divide_factor` (2 by default).

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            A dictionary of the base model kwargs
        """
        kwargs = super().make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
        # `embed_dim` is width-like and gets scaled; the head count and the
        # per-graph node limit are structural and copied unchanged.
        kwargs["num_heads"] = self.num_heads
        kwargs["embed_dim"] = round(self.embed_dim / divide_factor)
        kwargs["max_num_nodes_per_graph"] = self.max_num_nodes_per_graph
        return kwargs
================================================
FILE: graphium/nn/encoders/laplace_pos_encoder.py
================================================
from typing import List, Dict, Any, Optional, Union, Callable
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from graphium.nn.base_layers import MLP, get_norm, FCLayer, TransformerEncoderLayerMup
from graphium.nn.encoders.base_encoder import BaseEncoder
class LapPENodeEncoder(BaseEncoder):
    def __init__(
        self,
        input_keys: List[str],
        output_keys: List[str],
        in_dim: int,  # Size of Laplace PE embedding. Only used by the MLP model
        hidden_dim: int,
        out_dim: int,
        num_layers: int,
        activation: Optional[Union[str, Callable]] = "relu",
        model_type: str = "DeepSet",  # 'Transformer' or 'DeepSet' or 'MLP'
        num_layers_post=1,  # Num. layers to apply after pooling
        dropout=0.0,
        first_normalization=None,
        use_input_keys_prefix: bool = True,
        **model_kwargs,
    ):
        r"""
        Laplace Positional Embedding node encoder.
        LapPE of size dim_pe will get appended to each node feature vector.

        Parameters:
            input_keys: List of input keys to use from the data object.
                Expects exactly two keys: the eigenvalues and the eigenvectors.
            output_keys: List of output keys to add to the data object.
            in_dim: Size of Laplace PE embedding. Only used by the MLP model
            hidden_dim: Size of hidden layer
            out_dim: Size of final node embedding
            num_layers: Number of layers in the MLP
            activation: Activation function to use.
            model_type: 'Transformer' or 'DeepSet' or 'MLP'
            num_layers_post: Number of layers to apply after pooling
            dropout: Dropout rate
            first_normalization: Normalization to apply to the first layer.
            use_input_keys_prefix: Whether to use the `key_prefix` argument in `forward`.
            model_kwargs: Extra kwargs forwarded to the underlying PE model.
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            in_dim=in_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            activation=activation,
            first_normalization=first_normalization,
            use_input_keys_prefix=use_input_keys_prefix,
        )

        self.hidden_dim = hidden_dim
        self.model_type = model_type
        if num_layers_post == 0:
            assert hidden_dim == out_dim, "Hidden dim must be equal to out dim if num_layers_post == 0"
        self.num_layers_post = num_layers_post
        self.dropout = dropout
        self.model_kwargs = model_kwargs

        if out_dim - in_dim < 1:
            raise ValueError(f"LapPE size {in_dim} is too large for " f"desired embedding size of {out_dim}.")

        # Initial projection of eigenvalue and the node's eigenvector value
        self.linear_in = FCLayer(2, hidden_dim, activation="none")

        if self.model_type == "Transformer":
            # Transformer model for LapPE
            model_kwargs.setdefault("nhead", 1)
            encoder_layer = TransformerEncoderLayerMup(
                None,
                d_model=hidden_dim,
                batch_first=True,
                dropout=dropout,
                activation=self.activation,
                **model_kwargs,
            )
            self.pe_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        elif self.model_type == "DeepSet":
            # DeepSet model for LapPE (this will be followed by a sum pooling)
            self.pe_encoder = MLP(
                in_dim=hidden_dim,
                hidden_dims=hidden_dim,
                out_dim=hidden_dim,
                depth=num_layers,
                dropout=dropout,
                **model_kwargs,
            )
        elif self.model_type == "MLP":
            # MLP that will mix all eigenvalues and eigenvectors
            self.pe_encoder = MLP(
                in_dim=self.in_dim * hidden_dim,
                hidden_dims=hidden_dim,
                out_dim=hidden_dim,
                depth=num_layers_post,
                dropout=dropout,
                activation=activation,
                last_activation="none",
                **model_kwargs,
            )
        else:
            raise ValueError(f"Unexpected PE model {self.model_type}")

        self.post_mlp = None
        if num_layers_post > 0:
            # MLP to apply post pooling
            self.post_mlp = MLP(
                in_dim=hidden_dim,
                hidden_dims=hidden_dim,
                out_dim=out_dim,
                depth=num_layers_post,
                dropout=dropout,
                activation=activation,
                last_activation="none",
            )

    def parse_input_keys(
        self,
        input_keys: List[str],
    ) -> List[str]:
        r"""
        Parse the input keys and make sure they are supported for this encoder

        Parameters:
            input_keys: List of input keys to use from the data object.

        Returns:
            List of parsed input keys
        """
        if len(input_keys) != 2:
            raise ValueError(f"`{self.__class__}` only supports 2 keys")
        for key in input_keys:
            assert not key.startswith(
                "edge_"
            ), f"Input keys must be node features, not edge features, for encoder {self.__class__}"
            # Fixed: this message previously (incorrectly) said "graph features".
            assert not key.startswith(
                "nodepair_"
            ), f"Input keys must be node features, not nodepair features, for encoder {self.__class__}"
            assert not key.startswith(
                "graph_"
            ), f"Input keys must be node features, not graph features, for encoder {self.__class__}"
        return input_keys

    def parse_output_keys(
        self,
        output_keys: List[str],
    ) -> List[str]:
        r"""
        parse the output keys

        Parameters:
            output_keys: List of output keys to add to the data object.

        Returns:
            List of parsed output keys
        """
        for key in output_keys:
            assert not key.startswith(
                "edge_"
            ), f"Edge encodings are not supported for encoder {self.__class__}"
            # Fixed: this message previously (incorrectly) said "Edge encodings".
            assert not key.startswith(
                "nodepair_"
            ), f"Nodepair encodings are not supported for encoder {self.__class__}"
            assert not key.startswith(
                "graph_"
            ), f"Graph encodings are not supported for encoder {self.__class__}"
        return output_keys

    def forward(
        self,
        batch: Batch,
        key_prefix: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Forward pass of the encoder.

        Parameters:
            batch: pyg Batches of graphs
            key_prefix: Prefix to use for the input and output keys.

        Returns:
            output dictionary with keys as specified in `output_keys` and their output embeddings.
        """
        # input_keys = self.parse_input_keys_with_prefix(key_prefix)
        eigvals, eigvecs = batch[self.input_keys[0]], batch[self.input_keys[1]]

        # Random sign flipping of the eigenvectors during training only,
        # so the encoder does not overfit to an arbitrary eigenvector sign.
        if self.training:
            sign_flip = torch.rand(eigvecs.size(1), device=eigvecs.device, dtype=eigvecs.dtype)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            eigvecs = eigvecs * sign_flip.unsqueeze(0)

        pos_enc = torch.cat(
            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
        )  # (Num nodes) x (Num Eigenvectors) x 2
        # NaNs mark padded eigenvector slots; zero them and remember the mask.
        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2
        pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 2
        if self.first_normalization:
            pos_enc = self.first_normalization(pos_enc)
        pos_enc = self.linear_in(pos_enc)  # (Num nodes) x (Num Eigenvectors) x hidden_dim

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == "Transformer":
            pos_enc = self.pe_encoder(src=pos_enc, src_key_padding_mask=empty_mask[:, :, 0])
            # (Num nodes) x (Num Eigenvectors) x hidden_dim
        elif self.model_type == "DeepSet":
            pos_enc = self.pe_encoder(pos_enc)
            # (Num nodes) x (Num Eigenvectors) x hidden_dim
        elif self.model_type == "MLP":
            pos_enc = torch.flatten(pos_enc, start_dim=-2, end_dim=-1)
            pos_enc = self.pe_encoder(pos_enc)
            # (Num nodes) x hidden_dim
        else:
            raise ValueError(f"Unexpected PE model {self.model_type}")

        if self.model_type in ["Transformer", "DeepSet"]:
            # Mask out padded nodes, then sum-pool over the eigenvector axis.
            pos_enc[empty_mask[..., 0]] = 0  # (Num nodes) x (Num Eigenvectors) x hidden_dim
            pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x hidden_dim

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x out_dim

        output = {key: pos_enc for key in self.output_keys}
        return output

    def make_mup_base_kwargs(
        self,
        divide_factor: float = 2.0,
        factor_in_dim: bool = False,
    ) -> Dict[str, Any]:
        r"""
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            Dictionary of kwargs to be used to create the base model.
        """
        base_kwargs = super().make_mup_base_kwargs(divide_factor, factor_in_dim)
        base_kwargs.update(
            dict(
                hidden_dim=round(self.hidden_dim / divide_factor),
                model_type=self.model_type,
                num_layers_post=self.num_layers_post,
                dropout=self.dropout,
                **self.model_kwargs,
            )
        )
        return base_kwargs
================================================
FILE: graphium/nn/encoders/mlp_encoder.py
================================================
import torch
import torch.nn as nn
from typing import Union, Callable, List, Dict, Any, Optional
from torch_geometric.data import Batch
from graphium.nn.base_layers import MLP, get_norm
from graphium.nn.encoders.base_encoder import BaseEncoder
class MLPEncoder(BaseEncoder):
    def __init__(
        self,
        input_keys: List[str],
        output_keys: List[str],
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int,
        activation: Union[str, Callable] = "relu",
        dropout=0.0,
        normalization="none",
        first_normalization="none",
        use_input_keys_prefix: bool = True,
    ):
        r"""
        Configurable kernel-based Positional Encoding node/edge-level encoder.
        Applies one shared MLP independently to each input key, producing one
        output key per input key.

        Parameters:
            input_keys: List of input keys to use from pyg batch graph
            output_keys: List of output keys to add to the pyg batch graph;
                must have the same length as `input_keys` and matching
                'edge_'/'nodepair_'/'graph_' prefixes (see `parse_output_keys`)
            in_dim : input dimension of the mlp encoder
            hidden_dim : hidden dimension of the mlp encoder
            out_dim : output dimension of the mlp encoder
            num_layers : number of layers of the mlp encoder
            activation : activation function to use
            dropout : dropout to use
            normalization : normalization to use
            first_normalization : normalization to use before the first layer
            use_input_keys_prefix: Whether to use the `key_prefix` argument
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            in_dim=in_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            activation=activation,
            first_normalization=first_normalization,
            use_input_keys_prefix=use_input_keys_prefix,
        )

        # Check the output_keys
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.normalization = normalization

        # Initialize the MLP shared across all input keys
        self.pe_encoder = MLP(
            in_dim=in_dim,
            hidden_dims=hidden_dim,
            out_dim=out_dim,
            depth=num_layers,
            dropout=dropout,
            activation=activation,
            last_activation=activation,
            first_normalization=self.first_normalization,
            normalization=normalization,
            last_normalization=normalization,
            last_dropout=dropout,
        )

    def parse_input_keys(
        self,
        input_keys: List[str],
    ) -> List[str]:
        r"""
        Parse the `input_keys`. Any key is accepted; prefix consistency with
        the output keys is checked in `parse_output_keys`.

        Parameters:
            input_keys: List of input keys to use from pyg batch graph

        Returns:
            parsed input_keys
        """
        return input_keys

    def parse_output_keys(
        self,
        output_keys: List[str],
    ) -> List[str]:
        r"""
        Parse the `output_keys`. Checks that input and output keys are paired
        one-to-one and that each pair agrees on its feature-level prefix
        ('edge_', 'nodepair_' or 'graph_').

        Parameters:
            output_keys: List of output keys to add to the pyg batch graph

        Returns:
            parsed output_keys
        """
        assert len(output_keys) == len(
            self.input_keys
        ), f"The number of input keys {len(self.input_keys)} and output keys {len(output_keys)} must be the same for the class {self.__class__.__name__}"
        for in_key, out_key in zip(self.input_keys, output_keys):
            if in_key.startswith("edge_") or out_key.startswith("edge_"):
                assert out_key.startswith("edge_") and in_key.startswith(
                    "edge_"
                ), f"The output key {out_key} and input key {in_key} must match the 'edge_' prefix for the class {self.__class__.__name__}"
            if in_key.startswith("nodepair_") or out_key.startswith("nodepair_"):
                assert out_key.startswith("nodepair_") and in_key.startswith(
                    "nodepair_"
                ), f"The output key {out_key} and input key {in_key} must match the 'nodepair_' prefix for the class {self.__class__.__name__}"
            if in_key.startswith("graph_") or out_key.startswith("graph_"):
                assert out_key.startswith("graph_") and in_key.startswith(
                    "graph_"
                ), f"The output key {out_key} and input key {in_key} must match the 'graph_' prefix for the class {self.__class__.__name__}"
        return output_keys

    def forward(
        self,
        batch: Batch,
        key_prefix: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        r"""
        forward function of the mlp encoder

        Parameters:
            batch: pyg batch graph
            key_prefix: Prefix to use for the input keys

        Returns:
            output: Dictionary of output embeddings with keys specified by `output_keys`
        """
        # TODO: maybe we should also use the output key here? @Dom
        # input_keys = self.parse_input_keys_with_prefix(key_prefix)

        # TODO: maybe it makes sense to combine MLPEncoder and CatMLPEncoder into one class with
        # CatMLPEncoder being executed when the list input_keys contains more than one element.
        # Currently, the input_keys list can only contain one element in MLPEncoder.

        # Run the (shared) MLP for each input key independently
        output = {}
        for input_key, output_key in zip(self.input_keys, self.output_keys):
            output[output_key] = self.pe_encoder(batch[input_key])  # (Num nodes/edges) x dim_pe
        return output

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            base_kwargs: Dictionary of kwargs to use for the base model
        """
        base_kwargs = super().make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
        base_kwargs.update(
            dict(
                hidden_dim=round(self.hidden_dim / divide_factor),
                dropout=self.dropout,
                normalization=self.normalization,
            )
        )
        return base_kwargs
class CatMLPEncoder(BaseEncoder):
    def __init__(
        self,
        input_keys: List[str],
        output_keys: List[str],
        in_dim: Union[int, List[int]],
        hidden_dim: int,
        out_dim: int,
        num_layers: int,
        activation: Union[str, Callable] = "relu",
        dropout=0.0,
        normalization="none",
        first_normalization="none",
        use_input_keys_prefix: bool = True,
    ):
        r"""
        Configurable kernel-based Positional Encoding node/edge-level encoder.
        Concatenates the list of input (node or edge) features in the feature dimension
        and passes the result through a single MLP.

        Parameters:
            input_keys: List of input keys; inputs are concatenated in feat dimension and passed through mlp
            output_keys: List of output keys to add to the pyg batch graph
            in_dim : input dimension of the mlp encoder; either the total already-summed
                dimension, or a list/tuple of per-input dimensions which will be summed
            hidden_dim : hidden dimension of the mlp encoder
            out_dim : output dimension of the mlp encoder
            num_layers : number of layers of the mlp encoder
            activation : activation function to use
            dropout : dropout to use
            normalization : normalization to use
            first_normalization : normalization to use before the first layer
            use_input_keys_prefix: Whether to use the `key_prefix` argument
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            in_dim=in_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            activation=activation,
            first_normalization=first_normalization,
            use_input_keys_prefix=use_input_keys_prefix,
        )

        # Accept a per-input list (or tuple) of dimensions and sum them to get
        # the total concatenated input dimension.
        if isinstance(in_dim, (list, tuple)):
            in_dim = sum(in_dim)

        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.normalization = normalization

        # Initialize the MLP applied to the concatenated features
        self.pe_encoder = MLP(
            in_dim=in_dim,
            hidden_dims=hidden_dim,
            out_dim=out_dim,
            depth=num_layers,
            dropout=dropout,
            activation=activation,
            last_activation=activation,
            first_normalization=self.first_normalization,
            normalization=normalization,
            last_normalization=normalization,
            last_dropout=dropout,
        )

    def parse_input_keys(
        self,
        input_keys: List[str],
    ) -> List[str]:
        r"""
        Parse the `input_keys`. Any combination of keys is accepted; they are
        concatenated along the feature dimension in `forward`.

        Parameters:
            input_keys: List of input keys to use from pyg batch graph

        Returns:
            parsed input_keys
        """
        return input_keys

    def parse_output_keys(
        self,
        output_keys: List[str],
    ) -> List[str]:
        r"""
        Parse the `output_keys`. No restriction is applied; every output key
        receives the same encoded tensor.

        Parameters:
            output_keys: List of output keys to add to the pyg batch graph

        Returns:
            parsed output_keys
        """
        return output_keys

    def forward(
        self,
        batch: Batch,
        key_prefix: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        r"""
        forward function of the mlp encoder

        Parameters:
            batch: pyg batch graph
            key_prefix: Prefix to use for the input keys

        Returns:
            output: Dictionary of output embeddings with keys specified by `output_keys`
        """
        # TODO: maybe we should also use the output key here? @Dom
        # input_keys = self.parse_input_keys_with_prefix(key_prefix)

        # Concatenate the selected positional encodings in the feature dimension
        input = torch.cat(
            [batch[input_key] for input_key in self.input_keys], dim=-1
        )  # [num_nodes/num_edges, sum(in_dims)]

        # Every output key maps to the same encoded tensor
        output = {}
        for output_key in self.output_keys:
            output[output_key] = self.pe_encoder(input)
        return output

    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension

        Returns:
            base_kwargs: Dictionary of kwargs to use for the base model
        """
        base_kwargs = super().make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
        base_kwargs.update(
            dict(
                hidden_dim=round(self.hidden_dim / divide_factor),
                dropout=self.dropout,
                normalization=self.normalization,
            )
        )
        return base_kwargs
================================================
FILE: graphium/nn/encoders/signnet_pos_encoder.py
================================================
"""
SignNet https://arxiv.org/abs/2202.13013
based on https://github.com/cptq/SignNet-BasisNet
"""
from typing import Dict, Any, Optional, List
import torch
import torch.nn as nn
from torch_geometric.nn import GINConv
from torch_scatter import scatter
from torch_geometric.data import Batch
from graphium.nn.base_layers import MLP
from graphium.nn.encoders.base_encoder import BaseEncoder
class SimpleGIN(nn.Module):
    def __init__(
        self, in_dim, hidden_dim, out_dim, num_layers, normalization="none", dropout=0.5, activation="relu"
    ):
        r"""
        A stack of `num_layers` GINConv layers, each with a 2-layer MLP update
        network. Used as the sign-invariant encoder \phi inside SignNet.

        Parameters:
            in_dim: Input feature dimension of the first GIN layer
            hidden_dim: Hidden dimension used by the intermediate GIN layers
            out_dim: Output dimension of the last GIN layer
            num_layers: Total number of GIN layers (first + (num_layers - 2) hidden + last)
            normalization: Normalization passed to every MLP update network
            dropout: Dropout rate passed to every MLP update network
            activation: Activation passed to every MLP update network
        """
        # TODO: Not sure this works? Needs updating.
        super().__init__()
        self.layers = nn.ModuleList()
        self.normalization = normalization
        # input layer
        update_net = MLP(
            in_dim,
            hidden_dim,
            hidden_dim,
            2,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
            last_normalization=normalization,
            last_dropout=dropout,
        )
        self.layers.append(GINConv(update_net))
        # hidden layers
        for i in range(num_layers - 2):
            update_net = MLP(
                hidden_dim,
                hidden_dim,
                hidden_dim,
                2,
                normalization=normalization,
                dropout=dropout,
                activation=activation,
                last_normalization=normalization,
                last_dropout=dropout,
            )
            self.layers.append(GINConv(update_net))
        # output layer
        update_net = MLP(
            hidden_dim,
            hidden_dim,
            out_dim,
            2,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
            last_normalization=normalization,
            last_dropout=dropout,
        )
        self.layers.append(GINConv(update_net))
        # NOTE(review): this Dropout module is constructed but never applied in
        # `forward` — dropout only happens inside the MLP update networks.
        # Confirm whether it should be applied between GIN layers or removed.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        # Apply the GIN layers sequentially; `ii` is unused but kept from the
        # original enumerate-based loop.
        for ii, layer in enumerate(self.layers):
            x = layer(x, edge_index)
        return x
class GINDeepSigns(nn.Module):
    """Sign invariant neural network with MLP aggregation.

    f(v1, ..., vk) = rho(enc(v1) + enc(-v1), ..., enc(vk) + enc(-vk))
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        k,
        dim_pe,
        rho_num_layers,
        normalization="none",
        dropout=0.5,
        activation="relu",
    ):
        # TODO: Not sure this works? Needs updating.
        super().__init__()
        # \phi: GIN encoder applied to each eigenvector channel (and its negation)
        self.enc = SimpleGIN(
            in_channels,
            hidden_channels,
            out_channels,
            num_layers,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
        )
        # \rho: MLP applied after concatenating the k encoded channels per node
        self.rho = MLP(
            out_channels * k,
            hidden_channels,
            dim_pe,
            rho_num_layers,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
        )

    def forward(self, x, edge_index, batch_index):
        # `batch_index` is unused here; the parameter is kept for interface
        # parity with MaskedGINDeepSigns.
        num_nodes = x.shape[0]  # Total number of nodes in the batch.
        channels_first = x.transpose(0, 1)  # N x K x In -> K x N x In
        # Sign-invariant encoding: enc(v) + enc(-v)
        encoded = self.enc(channels_first, edge_index) + self.enc(-channels_first, edge_index)
        flat = encoded.transpose(0, 1).reshape(num_nodes, -1)  # K x N x Out -> N x (K * Out)
        # N x dim_pe (Note: in the original codebase dim_pe is always K)
        return self.rho(flat)
class MaskedGINDeepSigns(nn.Module):
    """Sign invariant neural network with sum pooling and DeepSet.

    f(v1, ..., vk) = rho(enc(v1) + enc(-v1), ..., enc(vk) + enc(-vk))
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        num_layers,
        dim_pe,
        rho_num_layers,
        normalization="none",
        dropout=0.5,
        activation="relu",
    ):
        # TODO: Not sure this works? Needs updating.
        super().__init__()
        # \phi: GIN encoder applied to each eigenvector channel (and its negation)
        self.enc = SimpleGIN(
            in_channels,
            hidden_channels,
            out_channels,
            num_layers,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
        )
        # \rho: MLP applied after sum-pooling over the eigenvector channels
        self.rho = MLP(
            out_channels,
            hidden_channels,
            dim_pe,
            rho_num_layers,
            normalization=normalization,
            dropout=dropout,
            activation=activation,
        )

    def batched_n_nodes(self, batch_index):
        # Count the nodes of each graph, then expand so that every node carries
        # the node-count of the graph it belongs to.
        batch_size = batch_index.max().item() + 1
        one = batch_index.new_ones(batch_index.size(0))
        n_nodes = scatter(
            one, batch_index, dim=0, dim_size=batch_size, reduce="add"
        )  # Number of nodes in each graph.
        n_nodes = n_nodes.unsqueeze(1)
        # Repeat each graph's count once per node in that graph.
        return torch.cat([size * n_nodes.new_ones(size) for size in n_nodes])

    def forward(self, x, edge_index, batch_index):
        N = x.shape[0]  # Total number of nodes in the batch.
        K = x.shape[1]  # Max. number of eigen vectors / frequencies.
        x = x.transpose(0, 1)  # N x K x In -> K x N x In
        # Sign-invariant encoding: enc(v) + enc(-v)
        x = self.enc(x, edge_index) + self.enc(-x, edge_index)  # K x N x Out
        x = x.transpose(0, 1)  # K x N x Out -> N x K x Out
        batched_num_nodes = self.batched_n_nodes(batch_index)
        # For each node, mask out the eigenvector channels beyond the number of
        # nodes in its own graph (those channels are padding).
        mask = torch.cat([torch.arange(K).unsqueeze(0) for _ in range(N)])
        mask = (mask.to(batch_index.device) < batched_num_nodes.unsqueeze(1)).bool()
        x[~mask] = 0
        x = x.sum(dim=1)  # (sum over K) -> N x Out
        x = self.rho(x)  # N x Out -> N x dim_pe (Note: in the original codebase dim_pe is always K)
        return x
class SignNetNodeEncoder(BaseEncoder):
r"""SignNet Positional Embedding node encoder.
https://arxiv.org/abs/2202.13013
https://github.com/cptq/SignNet-BasisNet
Uses precomputated Laplacian eigen-decomposition, but instead
of eigen-vector sign flipping + DeepSet/Transformer, computes the PE as:
SignNetPE(v_1, ... , v_k) = \rho ( [\phi(v_i) + \rhi(-v_i)]^k_i=1 )
where \phi is GIN network applied to k first non-trivial eigenvectors, and
\rho is an MLP if k is a constant, but if all eigenvectors are used then
\rho is DeepSet with sum-pooling.
SignNetPE of size dim_pe will get appended to each node feature vector.
Args:
dim_emb: Size of final node embedding
"""
def __init__(
self,
input_keys: Dict,
output_keys: List[str],
in_dim, # Size of PE embedding
hidden_dim,
out_dim,
num_layers, # Num. layers in \phi GNN part
model_type, # 'MLP' or 'DeepSet'
max_freqs, # Num. eigenvectors (frequencies)
use_input_keys_prefix: bool = True,
num_layers_post=0, # Num. layers in \rho MLP/DeepSet
activation="relu",
dropout=0.0,
normalization="none",
):
# TODO: Not sure this works? Needs updating.
super().__init__(
input_keys=input_keys,
output_keys=output_keys,
in_dim=in_dim,
out_dim=out_dim,
num_layers=num_layers,
activation=activation,
use_input_keys_prefix=use_input_keys_prefix,
)
# Parse the `input_keys`.
self.hidden_dim = hidden_dim
self.model_type = model_type
self.max_freqs = max_freqs
self.num_layers_post = num_layers_post
self.dropout = dropout
self.normalization = normalization
if model_type not in ["MLP", "DeepSet"]:
raise ValueError(f"Unexpected SignNet model {model_type}")
self.model_type = model_type
if num_layers_post < 1:
raise ValueError(f"Num layers in rho model has to be positive.")
if out_dim - in_dim < 1:
raise ValueError(
f"SignNet PE size {in_dim} is too large for " f"desired embedding size of {out_dim}."
)
# Sign invariant neural network.
if self.model_type == "MLP":
self.sign_inv_net = GINDeepSigns(
in_channels=1,
hidden_channels=hidden_dim,
out_channels=hidden_dim,
num_layers=num_layers,
k=max_freqs,
dim_pe=in_dim,
rho_num_layers=num_layers_post,
normalization=normalization,
dropout=dropout,
activation=activation,
)
elif self.model_type == "DeepSet":
self.sign_inv_net = MaskedGINDeepSigns(
in_channels=1,
hidden_channels=hidden_dim,
out_channels=out_dim,
num_layers=num_layers,
dim_pe=in_dim,
rho_num_layers=num_layers_post,
normalization=normalization,
dropout=dropout,
activation=activation,
)
else:
raise ValueError(f"Unexpected model {self.model_type}")
def parse_input_keys(self, input_keys):
if len(input_keys) != 1:
raise ValueError(f"`{self.__class__}` only supports 1 key")
for key in input_keys:
assert not key.startswith(
"edge_"
), f"Input keys must be node features, not edge features, for encoder {self.__class__}"
assert not key.startswith(
"graph_"
), f"Input keys must be node features, not graph features, for encoder {self.__class__}"
return input_keys
def parse_output_keys(self, output_keys):
for key in output_keys:
assert not key.startswith(
"edge_"
), f"Edge encodings are not supported for encoder {self.__class__}"
assert not key.startswith(
"graph_"
), f"Graph encodings are not supported for encoder {self.__class__}"
return output_keys
def forward(self, batch: Batch, key_prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
input_keys = self.parse_input_keys_with_prefix(key_prefix)
eigvecs, edge_index, batch_index = batch[input_keys[0]], batch["edge_index"], batch["batch_index"]
pos_enc = eigvecs.unsqueeze(-1) # (Num nodes) x (Num Eigenvectors) x 1
empty_mask = torch.isnan(pos_enc)
pos_enc[empty_mask] = 0 # (Num nodes) x (Num Eigenvectors) x 1
# SignNet
pos_enc = self.sign_inv_net(pos_enc, edge_index, batch_index) # (Num nodes) x (pos_enc_dim)
output = {key: pos_enc for key in self.output_keys}
return output
    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool = False) -> Dict[str, Any]:
        """
        Create a 'base' model to be used by the `mup` or `muTransfer` scaling of the model.
        The base model is usually identical to the regular model, but with the
        layers width divided by a given factor (2 by default)
        # TODO: Update this. It is broken
        Parameters:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to factor the input dimension
        Returns:
            Dictionary of kwargs describing the narrowed "base" copy of this encoder.
        """
        # Start from the parent-class kwargs, then add the encoder-specific ones.
        base_kwargs = super().make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
        base_kwargs.update(
            dict(
                # Only the hidden width is divided here; the other values are
                # passed through unchanged so the base model mirrors this one.
                hidden_dim=round(self.hidden_dim / divide_factor),
                model_type=self.model_type,
                max_freqs=self.max_freqs,
                num_layers_post=self.num_layers_post,
                dropout=self.dropout,
                # The normalization module is serialized back to its class name.
                normalization=type(self.normalization).__name__,
            )
        )
        return base_kwargs
================================================
FILE: graphium/nn/ensemble_layers.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, Callable, Optional, Type, Tuple, Iterable
from copy import deepcopy
from loguru import logger
import torch
import torch.nn as nn
import mup.init as mupi
from mup import set_base_shapes
from graphium.nn.base_layers import FCLayer, MLP
class EnsembleLinear(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_ensemble: int,
        bias: bool = True,
        init_fn: Optional[Callable] = None,
    ):
        r"""
        Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`.

        Parameters:
            in_dim:
                Input dimension of the linear layers
            out_dim:
                Output dimension of the linear layers.
            num_ensemble:
                Number of linear layers in the ensemble.
            bias:
                Whether to add a learnable bias (one `[1, out_dim]` row per ensemble member).
            init_fn:
                Initialization function for the weight. Defaults to
                `mup.init.xavier_uniform_` when `None`.
        """
        super().__init__()

        # One [out_dim, in_dim] weight matrix per ensemble member.
        # `torch.empty` is the documented replacement for the legacy
        # `torch.Tensor(*sizes)` constructor; values are uninitialized here and
        # filled in by `reset_parameters` below.
        self.weight = nn.Parameter(torch.empty(num_ensemble, out_dim, in_dim))
        if bias:
            # The middle singleton dimension broadcasts over the batch dimension.
            self.bias = nn.Parameter(torch.empty(num_ensemble, 1, out_dim))
        else:
            self.register_parameter("bias", None)

        # Initialize parameters
        self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reset the parameters of the linear layer using the `init_fn`.
        """
        set_base_shapes(self, None, rescale_params=False)  # Set the shapes of the tensors, useful for mup
        # Initialize weight using the provided initialization function
        self.init_fn(self.weight)
        # Initialize bias if present
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the batched linear transformation on the input features.

        Parameters:
            h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`:
                Input feature tensor, before the batched linear transformation.
                `Din` is the number of input features, `B` is the batch size, and `L` is the number of linear layers.
        Returns:
            `torch.Tensor[..., L, B, Dout]`:
                Output feature tensor, after the batched linear transformation.
                `Dout` is the number of output features, `B` is the batch size, and `L` is the number of linear layers.
        """
        # weight is [L, Dout, Din]; matmul against h^T of shape [..., Din, B]
        # broadcasts the ensemble dimension, then the result is transposed back.
        h = torch.matmul(self.weight, h.transpose(-1, -2)).transpose(-1, -2)

        # Add bias if present
        if self.bias is not None:
            h += self.bias

        return h
class EnsembleFCLayer(FCLayer):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_ensemble: int,
        activation: Union[str, Callable] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        bias: bool = True,
        init_fn: Optional[Callable] = None,
        is_readout_layer: bool = False,
        droppath_rate: float = 0.0,
    ):
        r"""
        Multiple fully connected layers running in parallel.
        This layer is centered around a `torch.nn.Linear` module.
        The order in which transformations are applied is:
        - Dense Layer
        - Activation
        - Dropout (if applicable)
        - Batch Normalization (if applicable)
        Parameters:
            in_dim:
                Input dimension of the layer (the `torch.nn.Linear`)
            out_dim:
                Output dimension of the layer.
            num_ensemble:
                Number of linear layers in the ensemble.
            dropout:
                The ratio of units to dropout. No dropout by default.
            activation:
                Activation function to use.
            normalization:
                Normalization to use. Choices:
                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            bias:
                Whether to enable bias in for the linear layer.
            init_fn:
                Initialization function to use for the weight of the layer. Default is
                $$\mathcal{U}(-\sqrt{k}, \sqrt{k})$$ with $$k=\frac{1}{ \text{in_dim}}$$
            is_readout_layer: Whether the layer should be treated as a readout layer by replacing of `torch.nn.Linear`
                by `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            droppath_rate:
                stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
        Attributes:
            dropout (int):
                The ratio of units to dropout.
            normalization (None or Callable):
                Normalization layer
            linear (`torch.nn.Linear`):
                The linear layer
            activation (`torch.nn.Module`):
                The activation layer
            init_fn (Callable):
                Initialization function used for the weight of the layer
            in_dim (int):
                Input dimension of the linear layer
            out_dim (int):
                Output dimension of the linear layer
        """
        # The parent constructor builds a regular (non-ensemble) linear layer;
        # it is swapped out for the ensemble counterpart right below.
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            bias=bias,
            init_fn=init_fn,
            is_readout_layer=is_readout_layer,
            droppath_rate=droppath_rate,
        )

        # Linear layer, or MuReadout layer
        if not is_readout_layer:
            self.linear = EnsembleLinear(
                in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn
            )
        else:
            self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias)

        # Re-run initialization now that `self.linear` has been replaced.
        self.reset_parameters()

    def reset_parameters(self, init_fn=None):
        """
        Reset the parameters of the linear layer using the `init_fn`.
        """
        set_base_shapes(self, None, rescale_params=False)  # Set the shapes of the tensors, useful for mup
        self.linear.reset_parameters()

    def __repr__(self):
        # Append the ensemble size (first dim of the weight) to the parent repr.
        rep = super().__repr__()
        rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})"
        return rep
class EnsembleMuReadoutGraphium(EnsembleLinear):
    """
    This layer implements an ensemble version of μP with a 1/width multiplier and a
    constant variance initialization for both weights and biases.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_ensemble: int,
        bias: bool = True,
        init_fn: Optional[Callable] = None,
        readout_zero_init=False,
        output_mult=1.0,
    ):
        r"""
        Parameters:
            in_dim:
                Input dimension of the linear layers
            out_dim:
                Output dimension of the linear layers.
            num_ensemble:
                Number of linear layers in the ensemble.
            bias:
                Whether to add a learnable bias.
            init_fn:
                Initialization function for the weight; see `EnsembleLinear`.
            readout_zero_init:
                If `True`, weights (and biases) are zero-initialized instead of
                using `init_fn`.
            output_mult:
                Constant multiplier applied to the input of `forward`.
        """
        # These attributes must be set BEFORE `super().__init__()`: the parent
        # constructor calls `reset_parameters`, and the override below reads
        # `readout_zero_init`; `forward` and `width_mult` read the others.
        self.in_dim = in_dim
        self.output_mult = output_mult
        self.readout_zero_init = readout_zero_init
        self._base_width = in_dim
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            num_ensemble=num_ensemble,
            bias=bias,
            init_fn=init_fn,
        )

    def reset_parameters(self) -> None:
        # Zero-init is the μP-recommended option for readout layers; otherwise
        # fall back to the parent initialization.
        if self.readout_zero_init:
            self.weight.data[:] = 0
            if self.bias is not None:
                self.bias.data[:] = 0
        else:
            super().reset_parameters()

    def _rescale_parameters(self):
        """Rescale parameters to convert SP initialization to μP initialization.
        Warning: This method is NOT idempotent and should be called only once
        unless you know what you are doing.
        """
        if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params:
            raise RuntimeError(
                "`_rescale_parameters` has been called once before already. "
                "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
                "If you called `set_base_shapes` on a model loaded from a checkpoint, "
                "or just want to re-set the base shapes of an existing model, "
                "make sure to set the flag `rescale_params=False`.\n"
                "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call."
            )
        if self.bias is not None:
            self.bias.data *= self.width_mult() ** 0.5
        self.weight.data *= self.width_mult() ** 0.5
        self._has_rescaled_params = True

    def forward(self, x):
        # μP readout: scale the input by output_mult / width_mult before the
        # batched linear transformation of the parent class.
        return super().forward(self.output_mult * x / self.width_mult())

    @property
    def absolute_width(self):
        # Current (actual) input width of this layer.
        return float(self.in_dim)

    @property
    def base_width(self):
        return self._base_width

    @base_width.setter
    def base_width(self, val):
        if val is None:
            return
        assert isinstance(
            val, (int, torch.int, torch.long)
        ), f"`base_width` must be None, int or long, provided {val} of type {type(val)}"
        self._base_width = val

    def width_mult(self):
        # Width multiplier used by the μP scaling in `forward` and
        # `_rescale_parameters`. NOTE: the original class defined `width_mult`
        # twice; the first (infshape-based) definition was dead code silently
        # shadowed by this one, and has been removed.
        return self.absolute_width / self.base_width
class EnsembleMLP(MLP):
    def __init__(
        self,
        in_dim: int,
        hidden_dims: Union[Iterable[int], int],
        out_dim: int,
        num_ensemble: int,
        depth: Optional[int] = None,
        reduction: Optional[Union[str, Callable]] = "none",
        activation: Union[str, Callable] = "relu",
        last_activation: Union[str, Callable] = "none",
        dropout: float = 0.0,
        last_dropout: float = 0.0,
        normalization: Union[Type[None], str, Callable] = "none",
        last_normalization: Union[Type[None], str, Callable] = "none",
        first_normalization: Union[Type[None], str, Callable] = "none",
        last_layer_is_readout: bool = False,
        droppath_rate: float = 0.0,
        constant_droppath_rate: bool = True,
    ):
        r"""
        Simple multi-layer perceptron, built of a series of FCLayers
        Parameters:
            in_dim:
                Input dimension of the MLP
            hidden_dims:
                Either an integer specifying all the hidden dimensions,
                or a list of dimensions in the hidden layers.
            out_dim:
                Output dimension of the MLP.
            num_ensemble:
                Number of MLPs that run in parallel.
            depth:
                If `hidden_dims` is an integer, `depth` is 1 + the number of
                hidden layers to use.
                If `hidden_dims` is a list, then
                `depth` must be `None` or equal to `len(hidden_dims) + 1`
            reduction:
                Reduction to use at the end of the MLP. Choices:
                - "none" or `None`: No reduction
                - "mean": Mean reduction
                - "sum": Sum reduction
                - "max": Max reduction
                - "min": Min reduction
                - "median": Median reduction
                - `Callable`: Any callable function. Must take `dim` as a keyword argument.
            activation:
                Activation function to use in all the layers except the last.
                if `layers==1`, this parameter is ignored
            last_activation:
                Activation function to use in the last layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:
                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization in the hidden layers.
                - `Callable`: Any callable function
                if `layers==1`, this parameter is ignored
            last_normalization:
                Normalization to use **after the last layer**. Same options as `normalization`.
            first_normalization:
                Normalization to use **before the first layer**. Same options as `normalization`.
            last_dropout:
                The ratio of units to dropout at the last layer.
            last_layer_is_readout: Whether the last layer should be treated as a readout layer.
                Allows to use the `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
            droppath_rate:
                stochastic depth drop rate, between 0 and 1.
                See https://arxiv.org/abs/1603.09382
            constant_droppath_rate:
                If `True`, drop rates will remain constant across layers.
                Otherwise, drop rates will vary stochastically.
                See `DropPath.get_stochastic_drop_rate`
        """
        # The ensemble behavior comes entirely from swapping the FC layer class
        # used by the parent MLP for its ensemble counterpart.
        super().__init__(
            in_dim=in_dim,
            hidden_dims=hidden_dims,
            out_dim=out_dim,
            depth=depth,
            activation=activation,
            last_activation=last_activation,
            dropout=dropout,
            last_dropout=last_dropout,
            normalization=normalization,
            last_normalization=last_normalization,
            first_normalization=first_normalization,
            last_layer_is_readout=last_layer_is_readout,
            droppath_rate=droppath_rate,
            constant_droppath_rate=constant_droppath_rate,
            fc_layer=EnsembleFCLayer,
            fc_layer_kwargs={"num_ensemble": num_ensemble},
        )

        self.reduction_fn = self._parse_reduction(reduction)

    def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]:
        r"""
        Parse the reduction argument.

        Returns `None` for no reduction, or a callable taking a `dim` keyword.
        Raises `ValueError` for an unrecognized reduction.
        """
        if isinstance(reduction, str):
            reduction = reduction.lower()
        if reduction is None or reduction == "none":
            return None
        elif reduction == "mean":
            return torch.mean
        elif reduction == "sum":
            return torch.sum
        elif reduction == "max":
            # torch.max/min/median return (values, indices) namedtuples when a
            # dim is given; only the values are wanted here.
            def max_vals(x, dim):
                return torch.max(x, dim=dim).values

            return max_vals
        elif reduction == "min":

            def min_vals(x, dim):
                return torch.min(x, dim=dim).values

            return min_vals
        elif reduction == "median":

            def median_vals(x, dim):
                return torch.median(x, dim=dim).values

            return median_vals
        elif callable(reduction):
            return reduction
        else:
            raise ValueError(f"Unknown reduction {reduction}")

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the ensemble MLP on the input features, then reduce the output if specified.
        Parameters:
            h: `torch.Tensor[B, Din]` or `torch.Tensor[..., 1, B, Din]` or `torch.Tensor[..., L, B, Din]`:
                Input feature tensor, before the MLP.
                `Din` is the number of input features, `B` is the batch size, and `L` is the number of ensembles.
        Returns:
            `torch.Tensor[..., L, B, Dout]` or `torch.Tensor[..., B, Dout]`:
                Output feature tensor, after the MLP.
                `Dout` is the number of output features, `B` is the batch size, and `L` is the number of ensembles.
                `L` is removed if a reduction is specified.
        """
        h = super().forward(h)
        if self.reduction_fn is not None:
            # dim=-3 is the ensemble dimension `L` of the [..., L, B, Dout] output.
            h = self.reduction_fn(h, dim=-3)
        return h

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        rep = super().__repr__()
        rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})"
        # BUG FIX: the original method was missing this return statement, so
        # __repr__ returned None and repr() raised a TypeError.
        return rep
================================================
FILE: graphium/nn/pyg_layers/README.md
================================================
The Graph Of Life Library.
## What is in this folder?
code for different PyG GNN layers and graph transformer layers for this library
- ✅ `pna_pyg.py`: `PNAMessagePassingPyg` class implements the PNA layer, good example script for new GNN layers
- ✅ `gps_pyg.py`: `GPSLayerPyg` class is the implementation of the graphGPS layer in PyG, a graph transformer layer
- `gated_gcn_pyg.py`: `GatedGCNPyg` class for the gated GCN layer implementation.
- `gin_pyg.py`: GIN and GINE layer implementation in the `GINConvPyg` and `GINEConvPyg` class respectively
- `mpnn_pyg.py`: `MPNNPlusPyg` class implements the MPNN layer
- `pooling_pyg.py`: pooling layers in pyg
================================================
FILE: graphium/nn/pyg_layers/__init__.py
================================================
from .gcn_pyg import GCNConvPyg
from .gated_gcn_pyg import GatedGCNPyg
from .gin_pyg import GINConvPyg
from .gin_pyg import GINEConvPyg
from .pna_pyg import PNAMessagePassingPyg
from .mpnn_pyg import MPNNPlusPyg
from .gps_pyg import GPSLayerPyg
from .pooling_pyg import scatter_logsum_pool
from .pooling_pyg import scatter_std_pool
from .pooling_pyg import parse_pooling_layer_pyg
from .pooling_pyg import VirtualNodePyg
from .dimenet_pyg import DimeNetPyg
================================================
FILE: graphium/nn/pyg_layers/dimenet_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Callable, Union, Optional, Tuple
from functools import partial
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data, Batch
from torch_geometric.utils import scatter
from graphium.nn.base_graph_layer import BaseGraphModule
from graphium.utils.decorators import classproperty
from graphium.nn.pyg_layers.utils import triplets
from graphium.nn.base_layers import MLP, FCLayer
class ResidualLayer(torch.nn.Module):
    r"""Residual block of two fully-connected layers with a skip connection.

    Modified from
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/dimenet.html
    (align style with our codebase)
    """

    def __init__(self, hidden_channels: int, activation: Union[Callable, str]):
        super().__init__()
        # Attribute names are kept for state_dict compatibility.
        self.lin1 = FCLayer(hidden_channels, hidden_channels, activation=activation)
        self.lin2 = FCLayer(hidden_channels, hidden_channels, activation=activation)

    def forward(self, x: Tensor) -> Tensor:
        # Two-layer transformation, added back onto the input.
        residual = self.lin1(x)
        residual = self.lin2(residual)
        return x + residual
class OutputBlock(torch.nn.Module):
    r"""Output block pooling radial-gated edge messages onto nodes.

    Modified from
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/dimenet.html
    (align style)
    """

    def __init__(
        self,
        num_radial: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        activation: Union[Callable, str],
    ):
        super().__init__()
        # Attribute names are kept for state_dict compatibility.
        self.lin_rbf = FCLayer(num_radial, hidden_channels, bias=False, activation=None)
        self.lins = nn.ModuleList(
            [FCLayer(hidden_channels, hidden_channels, activation=activation) for _ in range(num_layers)]
        )
        self.lin = FCLayer(hidden_channels, out_channels, bias=False, activation=None)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, num_nodes: Optional[int] = None) -> Tensor:
        # Gate the edge messages by the projected radial basis, then sum-pool
        # them onto the target-node index `i`.
        x = self.lin_rbf(rbf) * x
        x = scatter(x, i, dim=0, dim_size=num_nodes, reduce="sum")
        # Refine the pooled node features, then project to the output dimension.
        for layer in self.lins:
            x = layer(x)
        return self.lin(x)
class InteractionBlock(nn.Module):
    r"""Modified from
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/dimenet.html
    (align style; add output linear layer to allow change of dimension)
    """

    def __init__(
        self,
        hidden_dim: int,
        output_dim: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial: int,
        num_before_skip: int,
        num_after_skip: int,
        activation: Union[Callable, str],
    ):
        r"""
        Parameters:
            hidden_dim: dimension of the incoming edge features
            output_dim: dimension of the outgoing edge features (via `lin_out`)
            num_bilinear: size of the middle dimension of the bilinear tensor `W`
            num_spherical: number of spherical harmonics in the triplet basis
            num_radial: number of radial basis functions
            num_before_skip: number of residual layers before the skip connection
            num_after_skip: number of residual layers after the skip connection
            activation: activation function used in the dense layers
        """
        super().__init__()
        # Basis projections (no bias, no activation).
        self.lin_rbf = FCLayer(num_radial, hidden_dim, activation=None, bias=False)
        self.lin_sbf = FCLayer(num_spherical * num_radial, num_bilinear, activation=None, bias=False)

        # Dense transformations of input messages.
        self.lin_kj = FCLayer(hidden_dim, hidden_dim, activation=activation)
        self.lin_ji = FCLayer(hidden_dim, hidden_dim, activation=activation)

        # Bilinear interaction tensor of shape [hidden_dim, num_bilinear, hidden_dim].
        self.W = nn.Parameter(torch.Tensor(hidden_dim, num_bilinear, hidden_dim))
        self.W.data.normal_(mean=0, std=2 / self.W.size(0))

        self.layers_before_skip = nn.ModuleList(
            [ResidualLayer(hidden_dim, activation) for _ in range(num_before_skip)]
        )
        self.lin = FCLayer(hidden_dim, hidden_dim, activation=activation)
        self.layers_after_skip = nn.ModuleList(
            [ResidualLayer(hidden_dim, activation) for _ in range(num_after_skip)]
        )
        # Extra projection (compared to upstream PyG) to allow a dimension change.
        self.lin_out = FCLayer(hidden_dim, output_dim, activation=activation)

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) -> Tensor:
        """
        Parameters:
            x: edge features after encodings [num_edges, hidden_dim]
            rbf: bessel rbf of edges [num_edges, num_radial]
            sbf: spherical bessel rbf of triplets [num_triplet, num_spherical * num_radial]
            idx_kj: indices in edge of triplets [num_triplet] (value range from 0 to num_edges)
            idx_ji: indices in edge of triplets [num_triplet] (value range from 0 to num_edges)
        Returns:
            updated edge features [num_edges, output_dim]
        """
        rbf = self.lin_rbf(rbf)  # [num_edges, hidden_dim]
        sbf = self.lin_sbf(sbf)  # [num_triplet, num_bilinear] (lin_sbf maps to num_bilinear, not hidden_dim)

        x_ji = self.lin_ji(x)
        x_kj = self.lin_kj(x)
        # Gate the k->j messages with the radial basis.
        x_kj = x_kj * rbf
        # Bilinear mixing: w = triplet, j = num_bilinear, l/i = hidden_dim.
        # Combines the angular basis with the gathered k->j messages through W.
        x_kj = torch.einsum("wj,wl,ijl->wi", sbf, x_kj[idx_kj], self.W)
        # Aggregate triplet contributions back onto the j->i edges.
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0), reduce="sum")

        h = x_ji + x_kj
        for layer in self.layers_before_skip:
            h = layer(h)
        # Skip connection back to the input edge features.
        h = self.lin(h) + x
        for layer in self.layers_after_skip:
            h = layer(h)
        return self.lin_out(h)  # [num_edges, output_dim]
class DimeNetPyg(BaseGraphModule):
    r"""Directional message-passing (DimeNet) layer: updates edge features with
    an interaction block over triplets, then pools them back onto the nodes."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        in_dim_edges: int,
        out_dim_edges: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial: int,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        **kwargs,
    ):
        r"""
        `"Directional Message Passing for Molecular Graphs" paper.
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            in_dim_edges:
                Input feature dimensions of the edges
            out_dim_edges:
                Output feature dimensions of the edges
            num_bilinear:
                The dimension of bilinear layer in the interaction block
            num_spherical:
                The number of spherical harmonics
            num_radial:
                The number of radial basis functions
            num_before_skip:
                The number of residual layers before skip connection (default: 1)
            num_after_skip:
                The number of residual layers after skip connection (default: 2)
            num_output_layers:
                The number of output layers for a single output block (default: 3)
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:
                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
        """
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )

        # transform old node feature
        self.node_model = MLP(
            in_dim=in_dim,
            hidden_dims=in_dim,
            out_dim=out_dim,
            depth=2,
            activation=activation,
            normalization=self.normalization,
        )
        # update edge feature
        self.interaction_block = InteractionBlock(
            in_dim_edges,
            out_dim_edges,
            num_bilinear,
            num_spherical,
            num_radial,
            num_before_skip,
            num_after_skip,
            activation,
        )
        # updated edge feature -> new node feature
        self.output_block = OutputBlock(
            num_radial=num_radial,
            hidden_channels=out_dim_edges,
            out_channels=out_dim,
            num_layers=num_output_layers,
            activation=activation,
        )

    def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]:
        r"""
        forward function of the layer
        Parameters:
            batch: pyg Batch graphs to pass through the layer; must carry
                `radius_edge_index`, `edge_rbf` and `triplet_sbf` produced by a
                3D positional encoder upstream.
        Returns:
            batch: pyg Batch graphs with `feat` and `edge_feat` updated
        """
        assert (
            "radius_edge_index" in batch
        ), "radius_edge_index not in batch, make sure to use 3D encoder firstly"
        # (j, i) = edge_index
        # Build the triplet index structure (k->j->i) from the radius graph.
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
            batch.radius_edge_index, num_nodes=batch.feat.size(0)
        )
        x, P = batch.edge_feat, batch.feat
        rbf, sbf = batch.edge_rbf, batch.triplet_sbf

        # apply MLP to node embeddings
        P = self.node_model(P)  # [num_nodes, out_dim]
        # rbf and sbf should be computed during pos encoder
        x = self.interaction_block(x, rbf, sbf, idx_kj, idx_ji)
        # Pool the updated edge messages onto target nodes `i` and add residually.
        P = P + self.output_block(x, rbf, i, num_nodes=batch.feat.size(0))  # [num_nodes, out_dim]

        batch.edge_feat = x  # updated edge features
        batch.feat = P  # updated node features
        return batch

    ############################################################################################################

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.
        Returns:
            supports_edges: bool
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        outputs edges or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1
================================================
FILE: graphium/nn/pyg_layers/gated_gcn_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, Callable
from functools import partial
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter
from torch_geometric.data import Data, Batch
from graphium.nn.base_graph_layer import BaseGraphStructure, check_intpus_allow_int
from graphium.nn.base_layers import FCLayer
from graphium.utils.decorators import classproperty
class GatedGCNPyg(MessagePassing, BaseGraphStructure):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        in_dim_edges: int,
        out_dim_edges: int = None,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        eps: float = 1e-5,
        **kwargs,
    ):
        r"""
        ResGatedGCN: Residual Gated Graph ConvNets
        An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018)
        https://arxiv.org/pdf/1711.07553v2.pdf
        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer, and for the edges
            in_dim_edges:
                Input edge-feature dimensions of the layer
            out_dim_edges:
                Output edge-feature dimensions; defaults to `in_dim_edges` when `None`
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:
                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            eps:
                Epsilon value for the normalization `sum(gate_weights * messages) / (sum(gate_weights) + eps)`,
                where `gate_weights` are the weights of the gates and follow a sigmoid function.
        """
        # Both parents are initialized explicitly: MessagePassing provides
        # propagate/message/aggregate/update, BaseGraphStructure the dims and
        # norm/activation/dropout configuration.
        MessagePassing.__init__(self, aggr="add", flow="source_to_target", node_dim=-2)
        BaseGraphStructure.__init__(
            self,
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )
        self._initialize_activation_dropout_norm()

        # Allow int32 in the edge_index
        self.__check_input__ = partial(check_intpus_allow_int, self)

        if out_dim_edges is None:
            out_dim_edges = in_dim_edges

        # Initialize the layers for the gating.
        # Roles as used below: A -> residual node term (update), B -> gated
        # message term (aggregate), C -> edge term, D -> target-node gate term,
        # E -> source-node gate term (message).
        self.A = FCLayer(in_dim, out_dim, activation=None, bias=True)
        self.B = FCLayer(in_dim, out_dim, activation=None, bias=True)
        self.C = FCLayer(in_dim_edges, out_dim, activation=None, bias=True)
        self.D = FCLayer(in_dim, out_dim, activation=None, bias=True)
        self.E = FCLayer(in_dim, out_dim, activation=None, bias=True)
        self.edge_out = FCLayer(
            in_dim=out_dim, out_dim=out_dim_edges, activation=None, dropout=dropout, bias=True
        )
        self.eps = eps

    def forward(
        self,
        batch: Union[Data, Batch],
    ) -> Union[Data, Batch]:
        r"""
        Forward pass the Gated GCN layer
        extract the following from the batch:
        x, node features with dim [n_nodes, in_dim]
        e, edge features with dim [n_edges, in_dim]
        edge_index with dim [2, n_edges]
        Parameters:
            batch: pyg Batch graph to pass through the layer
        Returns:
            batch: pyg Batch graph
        """
        x, e, edge_index = batch.feat, batch.edge_feat, batch.edge_index

        # Apply the linear layers
        Ax = self.A(x)
        Bx = self.B(x)
        Ce = self.C(e)
        Dx = self.D(x)
        Ex = self.E(x)

        # Propagate, and apply norm, activation, dropout.
        # `propagate` returns the (node, edge) pair produced by `update` below.
        x, e = self.propagate(edge_index, Bx=Bx, Dx=Dx, Ex=Ex, Ce=Ce, e=e, Ax=Ax)
        x = self.apply_norm_activation_dropout(x, batch_idx=batch.batch)
        e = self.edge_out(e)

        # Output
        batch.feat = x
        batch.edge_feat = e
        return batch

    def message(self, Dx_i: torch.Tensor, Ex_j: torch.Tensor, Ce: torch.Tensor) -> torch.Tensor:
        """
        message function
        Parameters:
            Dx_i: tensor with dimension [n_edges, out_dim]
            Ex_j: tensor with dimension [n_edges, out_dim]
            Ce: tensor with dimension [n_edges, out_dim]
        Returns:
            sigma_ij: tensor with dimension [n_edges, out_dim]
        """
        e_ij = Dx_i + Ex_j + Ce
        sigma_ij = torch.sigmoid(e_ij)
        # Stash the pre-sigmoid edge values on `self`; `update` retrieves them
        # as the new edge features after aggregation.
        self.e = e_ij
        return sigma_ij

    def aggregate(
        self,
        sigma_ij: torch.Tensor,
        index: torch.Tensor,
        Bx_j: torch.Tensor,
        Bx: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        aggregation function of the layer
        Parameters:
            sigma_ij: the output from message() function with dim [n_edges, out_dim]
            index: dim [n_edges]
            Bx_j: dim [n_edges, out_dim]
            Bx: dim [n_nodes, out_dim]
        Returns:
            out: dim [n_nodes, out_dim]
        """
        dim_size = Bx.shape[0]

        # Sum the messages, weighted by the gates. Sum the gates.
        numerator_eta_xj = scatter(sigma_ij * Bx_j, index, 0, None, dim_size, reduce="sum")
        denominator_eta_xj = scatter(sigma_ij, index, 0, None, dim_size, reduce="sum")

        # Cast to float32 if needed
        # (the division below is computed in float32 when inputs are float16,
        # then cast back — NOTE(review): presumably for fp16/AMP numerical
        # stability; confirm before changing.)
        dtype = denominator_eta_xj.dtype
        if dtype == torch.float16:
            numerator_eta_xj = numerator_eta_xj.to(dtype=torch.float32)
            denominator_eta_xj = denominator_eta_xj.to(dtype=torch.float32)

        # Normalize the messages by the sum of the gates
        out = numerator_eta_xj / (denominator_eta_xj + self.eps)

        # Cast back to float16 if needed
        if dtype == torch.float16:
            out = out.to(dtype=dtype)
        return out

    def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor):
        r"""
        update function of the layer
        Parameters:
            aggr_out: the output from aggregate() function with dim [n_nodes, out_dim]
            Ax: tensor with dim [n_nodes, out_dim]
        Returns:
            x: dim [n_nodes, out_dim]
            e_out: dim [n_edges, out_dim_edges]
        """
        # Residual combination of the node term with the aggregated messages.
        x = Ax + aggr_out
        # Retrieve the edge values stashed by `message`, then clear the
        # temporary attribute so no stale state survives this propagate call.
        e_out = self.e
        del self.e
        return x, e_out

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        outputs edges or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1
================================================
FILE: graphium/nn/pyg_layers/gcn_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Callable, Union
from functools import partial
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data, Batch
from graphium.nn.base_graph_layer import BaseGraphModule, check_intpus_allow_int
from graphium.utils.decorators import classproperty
class GCNConvPyg(BaseGraphModule):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        **kwargs,
    ):
        r"""
        GCN: Graph Convolutional Networks
        Semi-Supervised Classification with Graph Convolutional Networks
        (Thomas N. Kipf, Max Welling, ICLR 2017)
        https://arxiv.org/abs/1609.02907

        [!] code uses the pytorch-geometric implementation of GCNConv,
        with self-loops and symmetric normalization disabled.

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
        """
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )
        # Use `self.out_dim` for consistency with the other layer classes
        self.model = pyg_nn.GCNConv(
            in_channels=self.in_dim, out_channels=self.out_dim, add_self_loops=False, normalize=False
        )
        # Override the input check so that integer (e.g. int32) edge indices
        # are also accepted (see `check_intpus_allow_int`)
        self.model.__check_input__ = partial(check_intpus_allow_int, self)

    def forward(
        self,
        batch: Union[Data, Batch],
    ) -> Union[Data, Batch]:
        r"""
        forward function of the layer
        Parameters:
            batch: pyg Batch graphs to pass through the layer
        Returns:
            batch: pyg Batch graphs with updated node features ``feat``
        """
        batch.feat = self.model(batch.feat, batch.edge_index)
        batch.feat = self.apply_norm_activation_dropout(batch.feat, batch_idx=batch.batch)
        return batch

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.
        Returns:
            supports_edges: bool
                Always ``False`` for the current class
        """
        return False

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``False`` for the current class
        """
        return False

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer produces updated edge
        features as output.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``False`` for the current class
        """
        return False

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1
================================================
FILE: graphium/nn/pyg_layers/gin_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Callable, Union, Optional
from functools import partial
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data, Batch
from graphium.nn.base_graph_layer import BaseGraphModule, check_intpus_allow_int
from graphium.nn.base_layers import MLP
from graphium.utils.decorators import classproperty
class GINConvPyg(BaseGraphModule):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        **kwargs,
    ):
        r"""
        GIN: Graph Isomorphism Networks
        HOW POWERFUL ARE GRAPH NEURAL NETWORKS?
        (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019)
        https://arxiv.org/pdf/1810.00826.pdf

        [!] code uses the pytorch-geometric implementation of GINConv

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
        """
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )
        # Two-layer MLP acting as the learnable update function of GIN
        node_mlp = MLP(
            in_dim=self.in_dim,
            hidden_dims=self.in_dim,
            out_dim=self.out_dim,
            depth=2,
            activation=self.activation_layer,
            last_activation="none",
            normalization=self.normalization,
            last_normalization="none",
        )
        self.model = pyg_nn.GINConv(node_mlp)
        # Override the input check so that integer (e.g. int32) edge indices
        # are also accepted (see `check_intpus_allow_int`)
        self.model.__check_input__ = partial(check_intpus_allow_int, self)

    def forward(
        self,
        batch: Union[Data, Batch],
    ) -> Union[Data, Batch]:
        r"""
        Apply the GIN convolution followed by norm/activation/dropout.

        Parameters:
            batch: pyg Batch graphs to pass through the layer
        Returns:
            batch: pyg Batch graphs with updated node features ``feat``
        """
        feat = self.model(batch.feat, batch.edge_index)
        batch.feat = self.apply_norm_activation_dropout(feat, batch_idx=batch.batch)
        return batch

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Whether this layer type is able to handle edge features at all.

        Returns:
            supports_edges: bool
                Always ``False`` for this class
        """
        return False

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Whether this layer consumes edge features as input.

        This differs from ``layer_supports_edges``: a layer that supports
        edges may still choose not to use them.

        Returns:
            bool:
                Always ``False`` for this class
        """
        return False

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Whether this layer produces updated edge features as output.

        This differs from ``layer_supports_edges``: a layer that supports
        edges may still choose not to use them.

        Returns:
            bool:
                Always ``False`` for this class
        """
        return False

    @property
    def out_dim_factor(self) -> int:
        r"""
        Multiplier applied to ``out_dim`` when computing the input size of
        the next layer.

        Standard layers return ``1``. Layers such as ``GatLayer`` concatenate
        the outputs of each attention head, so they return the number of
        heads instead.

        Returns:
            int:
                Always ``1`` for this class
        """
        return 1
class GINEConvPyg(BaseGraphModule):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        in_dim_edges: Optional[int] = None,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        **kwargs,
    ):
        r"""
        GINE: Graph Isomorphism Networks with Edges
        Strategies for Pre-training Graph Neural Networks
        Weihua Hu, Bowen Liu, Joseph Gomes, Marinka Zitnik, Percy Liang, Vijay Pande, Jure Leskovec
        https://arxiv.org/abs/1905.12265

        [!] code uses the pytorch-geometric implementation of GINEConv

        Parameters:
            in_dim:
                Input feature dimensions of the layer
            out_dim:
                Output feature dimensions of the layer
            in_dim_edges:
                Input edge-feature dimensions of the layer. Passed as ``edge_dim``
                to ``pyg_nn.GINEConv``; when ``None``, PyG expects the edge
                features to have the same dimensionality as the node features.
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
        """
        # Bug fix: forward **kwargs to the base class (e.g. `layer_depth`,
        # `layer_idx` passed by GPSLayerPyg._parse_mpnn_layer), consistent with
        # GINConvPyg. Previously they were accepted but silently dropped.
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )
        gin_nn = MLP(
            in_dim=self.in_dim,
            hidden_dims=self.in_dim,
            out_dim=self.out_dim,
            depth=2,
            activation=self.activation_layer,
            last_activation="none",
            normalization=self.normalization,
            last_normalization="none",
        )
        self.model = pyg_nn.GINEConv(gin_nn, edge_dim=in_dim_edges)  # , node_dim=-1)
        # Override the input check so that integer (e.g. int32) edge indices
        # are also accepted (see `check_intpus_allow_int`)
        self.model.__check_input__ = partial(check_intpus_allow_int, self)

    def forward(
        self,
        batch: Union[Data, Batch],
    ) -> Union[Data, Batch]:
        r"""
        forward function of the layer
        Parameters:
            batch: pyg Batch graphs to pass through the layer
        Returns:
            batch: pyg Batch graphs with updated node features ``feat``
        """
        batch.feat = self.model(batch.feat, batch.edge_index, batch.edge_feat)
        batch.feat = self.apply_norm_activation_dropout(batch.feat, batch_idx=batch.batch)
        return batch

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.
        Returns:
            supports_edges: bool
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer produces updated edge
        features as output. GINE consumes edge features but does not
        update them.
        Returns:
            bool:
                Always ``False`` for the current class
        """
        return False

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1
================================================
FILE: graphium/nn/pyg_layers/gps_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import torch
from copy import deepcopy
from typing import Callable, Union, Optional, Dict, Any
from torch.nn import Module
from torch import Tensor
from torch_geometric.data import Batch
from graphium.nn.base_graph_layer import BaseGraphModule
from graphium.nn.base_layers import FCLayer, MultiheadAttentionMup, MLP
from graphium.nn.pyg_layers import (
GatedGCNPyg,
GINConvPyg,
GINEConvPyg,
PNAMessagePassingPyg,
MPNNPlusPyg,
)
from graphium.data.utils import get_keys
from graphium.utils.decorators import classproperty
from graphium.ipu.to_dense_batch import (
to_dense_batch,
to_sparse_batch,
to_packed_dense_batch,
to_sparse_batch_from_packed,
)
from graphium.ipu.ipu_utils import is_running_on_ipu
PYG_LAYERS_DICT = {
"pyg:gin": GINConvPyg,
"pyg:gine": GINEConvPyg,
"pyg:gated-gcn": GatedGCNPyg,
"pyg:pna-msgpass": PNAMessagePassingPyg,
"pyg:mpnnplus": MPNNPlusPyg,
}
ATTENTION_LAYERS_DICT = {
"full-attention": MultiheadAttentionMup,
"none": None,
}
class GPSLayerPyg(BaseGraphModule):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        in_dim_edges: Optional[int] = None,
        out_dim_edges: Optional[int] = None,
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        node_residual: Optional[bool] = True,
        normalization: Union[str, Callable] = "none",
        mpnn_type: str = "pyg:gine",
        mpnn_kwargs: Optional[dict] = None,
        attn_type: str = "full-attention",
        precision: str = "32",
        biased_attention_key: Optional[str] = None,
        attn_kwargs=None,
        force_consistent_in_dim: bool = True,
        droppath_rate_attn: float = 0.0,
        droppath_rate_ffn: float = 0.0,
        hidden_dim_scaling: float = 4.0,
        output_scale: float = 1.0,
        **kwargs,
    ):
        r"""
        GPS layer implementation in pyg
        adapted from https://github.com/rampasek/GraphGPS/blob/main/graphgps/layer/gps_layer.py

        GPS: Recipe for a General, Powerful, Scalable Graph Transformer
        Ladislav Rampášek, Mikhail Galkin, Vijay Prakash Dwivedi, Anh Tuan Luu, Guy Wolf, Dominique Beaini
        https://arxiv.org/abs/2205.12454

        GPS++: An Optimised Hybrid MPNN/Transformer for Molecular Property Prediction
        Dominic Masters, Josef Dean, Kerstin Klaser, Zhiyi Li, Sam Maddrell-Mander, Adam Sanders, Hatem Helal, Deniz Beker, Ladislav Rampášek, Dominique Beaini
        https://arxiv.org/abs/2212.02229

        Parameters:
            in_dim:
                Input node feature dimensions of the layer
            out_dim:
                Output node feature dimensions of the layer
            in_dim_edges:
                input edge-feature dimensions of the layer
            out_dim_edges:
                output edge-feature dimensions of the layer
            activation:
                activation function to use in the layer
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            node_residual:
                If node residual is used after on the gnn layer output
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function
            mpnn_type:
                Type of MPNN layer to use. Choices specified in PYG_LAYERS_DICT:
                "pyg:gin", "pyg:gine", "pyg:gated-gcn", "pyg:pna-msgpass" and "pyg:mpnnplus"
            mpnn_kwargs:
                Keyword arguments to pass to the MPNN layer
            attn_type:
                Type of attention layer to use. Choices specified in ATTENTION_LAYERS_DICT:
                "full-attention" and "none"
            precision:
                Numerical-precision flag (e.g. ``"32"``, ``"16"``) forwarded to the
                attention layer at call time
            biased_attention_key:
                indicates if biased attention is used by specifying a key corresponding to the
                pyg attribute in the batch (processed by the gaussian kernel encoder).
                default: None means biased attention is not used
            attn_kwargs:
                Keyword arguments to pass to the attention layer
            force_consistent_in_dim:
                whether to force the `embed_dim` to be the same as the `in_dim` for the attention and mpnn.
                The argument is only valid if `attn_type` is not None. If `embed_dim` is not provided,
                it will be set to `in_dim` by default, so this parameter won't have an effect.
            droppath_rate_attn:
                stochastic depth drop rate for attention layer https://arxiv.org/abs/1603.09382
            droppath_rate_ffn:
                stochastic depth drop rate for ffn layer https://arxiv.org/abs/1603.09382
            hidden_dim_scaling:
                Multiplier on `in_dim` used for the hidden dimension of the final MLP block
            output_scale:
                Float value that will be used to scale the activations, helps reduce growth of activations
                as the model gets deeper. Default value of 1.0 leaves the layer unchanged.
        """
        super().__init__(
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            droppath_rate=droppath_rate_attn,
            **kwargs,
        )
        # Set the other attributes
        self.in_dim_edges = in_dim_edges
        self.out_dim_edges = out_dim_edges
        # Dropout layers
        self.dropout_local = self.dropout_layer
        self.dropout_attn = self._parse_dropout(dropout=self.dropout)
        # DropPath layers (the attention-branch droppath is `self.droppath_layer`,
        # created by the base class from `droppath_rate_attn`)
        self.droppath_ffn = self._parse_droppath(droppath_rate_ffn)
        # Residual connections
        self.node_residual = node_residual
        self.precision = precision
        # MLP applied at the end of the GPS layer
        self.mlp = MLP(
            in_dim=in_dim,
            hidden_dims=int(hidden_dim_scaling * in_dim),
            out_dim=in_dim,
            depth=2,
            activation=activation,
            dropout=self.dropout,
            last_dropout=self.dropout,
        )
        self.f_out = FCLayer(in_dim, out_dim, normalization=normalization)
        # Normalization layers
        self.norm_layer_local = self._parse_norm(normalization=self.normalization, dim=in_dim)
        self.norm_layer_attn = self._parse_norm(normalization=self.normalization, dim=in_dim)
        self.norm_layer_ff = self._parse_norm(self.normalization)
        self.biased_attention_key = biased_attention_key
        # Initialize the MPNN and Attention layers
        self.mpnn = self._parse_mpnn_layer(mpnn_type, mpnn_kwargs)
        self.attn_layer = self._parse_attn_layer(
            attn_type, self.biased_attention_key, attn_kwargs, force_consistent_in_dim=force_consistent_in_dim
        )
        self.output_scale = output_scale
        self.use_edges = True if self.in_dim_edges is not None else False

    def residual_add(self, feature: Tensor, input_feature: Tensor) -> Tensor:
        r"""
        Residual addition layer. Allows information to propagate through the model
        by skipping the computational layers.
        Parameters:
            feature: The feature (typically nodes or edges) after message passing
            input_feature: The same feature from before message passing
        Returns:
            The addition of the two tensors (adds in-place on `feature`).
        """
        feature += input_feature
        return feature

    def scale_activations(self, feature: Tensor, scale_factor: Union[float, Tensor]) -> Tensor:
        """Scale Activations by a constant factor to stop growth of activation scale
        and reduce numerical stability issues at low precision

        Args:
            feature (Tensor): The feature to scale
            scale_factor (float): The scale factor (divisor)

        Returns:
            Tensor: The scaled features
        """
        scale_factor = torch.tensor(scale_factor).to(feature.device)
        feature = feature / scale_factor.to(dtype=feature.dtype)
        return feature

    def forward(self, batch: Batch) -> Batch:
        r"""
        forward function of the layer
        Parameters:
            batch: pyg Batch graphs to pass through the layer
        Returns:
            batch: pyg Batch graphs
        """
        feat = batch.feat
        feat_in = feat  # for first residual connection

        # Local MPNN with edge attributes.
        batch_out = batch.clone()
        if self.mpnn is not None:
            batch_out = self.mpnn(batch_out)
        h_local = batch_out.feat
        if self.dropout_local is not None:
            h_local = self.dropout_local(h_local)
        # Apply the residual connection for the node features and scale the
        # activations to help reduce activation growth. For the first layer the
        # residual is added before scaling, afterwards scaling comes first.
        if self.node_residual:
            if self.layer_depth < 1:
                h_local = self.residual_add(h_local, feat_in)
                h_local = self.scale_activations(h_local, self.output_scale)
            else:
                h_local = self.scale_activations(h_local, self.output_scale)
                h_local = self.residual_add(h_local, feat_in)
        if self.norm_layer_local is not None:
            h_local = self.norm_layer_local(h_local)

        # Multi-head attention.
        if self.attn_layer is not None:
            h_attn = self._self_attention_block(feat, feat_in, batch)
            # Combine local and global outputs.
            feat = h_local + h_attn
        else:
            feat = h_local

        # MLP block, with skip connection
        feat_mlp = self.mlp(feat)
        # Add the droppath to the output of the MLP
        batch_size = None if feat.device.type != "ipu" else batch.graph_is_true.shape[0]
        if self.droppath_ffn is not None:
            feat_mlp = self.droppath_ffn(feat_mlp, batch.batch, batch_size)
        feat = feat + feat_mlp
        feat = self.f_out(feat)
        batch_out.feat = feat
        return batch_out

    def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[Module]:
        """Build the MPNN sub-layer, or return ``None`` when no MPNN is requested."""
        if mpnn_type is None or mpnn_type == "none":
            return
        # Copy once so the caller's dict is not mutated by the defaults below
        # (previously the dict was redundantly deep-copied twice)
        mpnn_kwargs = deepcopy(mpnn_kwargs) if mpnn_kwargs is not None else {}
        # Set the default values
        mpnn_kwargs.setdefault("in_dim", self.in_dim)
        mpnn_kwargs.setdefault("out_dim", self.in_dim)
        mpnn_kwargs.setdefault("in_dim_edges", self.in_dim_edges)
        mpnn_kwargs.setdefault("out_dim_edges", self.out_dim_edges)
        # TODO: The rest of default values
        self.mpnn_kwargs = mpnn_kwargs
        # Initialize the MPNN layer
        mpnn_class = PYG_LAYERS_DICT[mpnn_type]
        mpnn_layer = mpnn_class(**mpnn_kwargs, layer_depth=self.layer_depth, layer_idx=self.layer_idx)
        return mpnn_layer

    def _parse_attn_layer(
        self,
        attn_type,
        biased_attention_key: str,
        attn_kwargs: Dict[str, Any],
        force_consistent_in_dim: bool = True,
    ) -> Optional[Module]:
        """
        parse the input attention layer and check if it is valid
        Parameters:
            attn_type: type of the attention layer
            biased_attention_key: key for the attention bias
            attn_kwargs: kwargs for the attention layer
            force_consistent_in_dim: whether to force the `embed_dim` to be the same as the `in_dim`
        Returns:
            attn_layer: the attention layer, or ``None`` when `attn_type` is "none"
        """
        # Set the default values for the Attention layer
        if attn_kwargs is None:
            attn_kwargs = {}
        attn_kwargs.setdefault("num_heads", 1)
        attn_kwargs.setdefault("dropout", self.dropout)
        attn_kwargs.setdefault("batch_first", True)
        self.attn_kwargs = attn_kwargs
        # Force the `embed_dim` to be the same as the `in_dim`
        attn_kwargs.setdefault("embed_dim", self.in_dim)
        if force_consistent_in_dim:
            embed_dim = attn_kwargs["embed_dim"]
            assert embed_dim == self.in_dim, f"embed_dim={embed_dim} must be equal to in_dim={self.in_dim}"
        # Initialize the Attention layer
        attn_layer, attn_class = None, None
        if attn_type is not None:
            attn_class = ATTENTION_LAYERS_DICT[attn_type]
        if attn_class is not None:
            attn_layer = attn_class(biased_attention_key, **attn_kwargs)
        return attn_layer

    def _use_packing(self, batch: Batch) -> bool:
        """
        Check if we should use packing for the batch of graphs
        (both packing attributes must be present on the batch).
        """
        batch_keys = get_keys(batch)
        return "pack_from_node_idx" in batch_keys and "pack_attn_mask" in batch_keys

    def _to_dense_batch(
        self,
        h: Tensor,
        batch: Batch,
        batch_size: Optional[int] = None,
        max_num_nodes_per_graph: Optional[int] = None,
        on_ipu: bool = False,
    ) -> Tensor:
        """
        Convert the batch of graphs to a dense batch.

        Returns a 4-tuple ``(h_dense, attn_mask, key_padding_mask, idx)``;
        `attn_mask` is only set on the packed path, `key_padding_mask` only
        on the un-packed path.
        """
        if self._use_packing(batch):
            attn_mask = batch.pack_attn_mask
            key_padding_mask = None
            idx = batch.pack_from_node_idx
            h_dense = to_packed_dense_batch(
                h,
                pack_from_node_idx=idx,
                pack_attn_mask=attn_mask,
                max_num_nodes_per_pack=100,  # TODO: This should be a parameter
            )
        else:
            attn_mask = None
            h_dense, key_padding_mask, idx = to_dense_batch(
                h,
                batch=batch.batch,  # The batch index as a vector that indicates for nodes of which graph it belongs to
                batch_size=batch_size,
                max_num_nodes_per_graph=max_num_nodes_per_graph,
                drop_nodes_last_graph=on_ipu,
            )
            # Invert so that `True` marks the padded positions
            key_padding_mask = ~key_padding_mask
        return h_dense, attn_mask, key_padding_mask, idx

    def _to_sparse_batch(self, batch: Batch, h_dense: Tensor, idx: Tensor) -> Tensor:
        """
        Convert the dense batch back to a sparse batch.
        """
        if self._use_packing(batch):
            h = to_sparse_batch_from_packed(
                h_dense,
                pack_from_node_idx=idx,
            )
        else:
            h = to_sparse_batch(
                h_dense,
                mask_idx=idx,
            )
        return h

    def _self_attention_block(self, feat: Tensor, feat_in: Tensor, batch: Batch) -> Tensor:
        """
        Applying the multi-head self-attention to the batch of graphs.
        First the batch is converted from [num_nodes, hidden_dim] to [num_graphs, max_num_nodes, hidden_dim]
        Then the self-attention is applied on each graph
        Then the batch is converted again to [num_nodes, hidden_dim]
        """
        # Multi-head attention.
        on_ipu = is_running_on_ipu()
        max_num_nodes_per_graph = None
        if on_ipu:
            max_num_nodes_per_graph = self.max_num_nodes_per_graph
        # Convert the tensor to a dense batch, then back to a sparse batch
        batch_size = None if feat.device.type != "ipu" else batch.graph_is_true.shape[0]
        # h[num_nodes, hidden_dim] -> h_dense[num_graphs, max_num_nodes, hidden_dim]
        feat_dense, attn_mask, key_padding_mask, idx = self._to_dense_batch(
            feat,
            batch=batch,  # The batch index as a vector that indicates for nodes of which graph it belongs to
            batch_size=batch_size,
            max_num_nodes_per_graph=max_num_nodes_per_graph,
            on_ipu=on_ipu,
        )
        attn_bias = None
        if self.biased_attention_key is not None and self.biased_attention_key != "none":
            attn_bias = batch[self.biased_attention_key]
        # h_dense[num_graphs, max_num_nodes, hidden_dim] -> feat_attn[num_graphs, max_num_nodes, hidden_dim]
        feat_attn = self._sa_block(
            feat_dense, attn_bias=attn_bias, attn_mask=attn_mask, key_padding_mask=key_padding_mask
        )
        # feat_attn[num_graphs, max_num_nodes, hidden_dim] -> feat_attn[num_nodes, hidden_dim]
        feat_attn = self._to_sparse_batch(batch, feat_attn, idx)
        # Dropout, residual, norm
        if self.dropout_attn is not None:
            feat_attn = self.dropout_attn(feat_attn)
        feat_attn = feat_in + feat_attn
        if self.norm_layer_attn is not None:
            feat_attn = self.norm_layer_attn(feat_attn)
        if self.droppath_layer is not None:
            # Bug fix: keep the DropPath output. Previously the returned tensor
            # was discarded, so `droppath_rate_attn` had no effect on this branch
            # (the ffn branch in `forward` correctly assigns the result).
            feat_attn = self.droppath_layer(feat_attn, batch.batch, batch_size=batch_size)
        # Combine local and global outputs.
        return feat + feat_attn

    def _sa_block(
        self, x: torch.Tensor, attn_bias: torch.Tensor, attn_mask=None, key_padding_mask=None
    ) -> torch.Tensor:
        """
        Self-attention block.
        Parameters:
            x: input tensor
            attn_bias: attention bias tensor (or None)
            attn_mask: attention mask for the packed path (or None)
            key_padding_mask: padding mask for the un-packed path (or None)
        Returns:
            x: output tensor
        """
        x = self.attn_layer(
            x,
            x,
            x,
            attn_bias=attn_bias,
            precision=self.precision,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return x

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.
        Returns:
            supports_edges: bool
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.
        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer produces updated edge
        features as output. Delegates to the wrapped MPNN sub-layer;
        ``False`` when no MPNN is used.
        Returns:
            bool:
                Whether the inner MPNN outputs edge features
        """
        if self.mpnn is None:
            return False
        return self.mpnn.layer_outputs_edges

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer.
        For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.
        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1
================================================
FILE: graphium/nn/pyg_layers/mpnn_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Callable, Optional, Union, Tuple, List
import torch
from torch import Tensor, IntTensor, LongTensor
from graphium.nn.base_graph_layer import BaseGraphModule
from graphium.nn.base_layers import MLP
from graphium.utils.decorators import classproperty
from torch_geometric.nn.aggr import MultiAggregation, Aggregation
from torch_geometric.data import Batch
class MPNNPlusPyg(BaseGraphModule):
def __init__(
    self,
    in_dim: int = 64,
    out_dim: int = 64,
    activation: Union[str, Callable] = "gelu",
    dropout: float = 0.3,
    normalization: Union[str, Callable] = "layer_norm",
    gather_from: str = "both",
    scatter_to: str = "both",
    node_combine_method: str = "concat",
    num_node_mlp: int = 2,
    mlp_expansion_ratio: int = 4,
    use_edges: bool = True,
    in_dim_edges: Optional[int] = 32,
    out_dim_edges: Optional[int] = 32,
    aggregation_method: Optional[List[Union[str, Aggregation]]] = None,
    num_edge_mlp: Optional[int] = 2,
    use_globals: bool = True,
    edge_dropout_rate: Optional[float] = 0.0035,
    **kwargs,
):
    r"""
    MPNNPlusPyg: InteractionNetwork layer with edges and global feature, GPS++ type of GNN layer
    GPS++: An Optimised Hybrid MPNN/Transformer for Molecular Property Prediction
    Dominic Masters, Josef Dean, Kerstin Klaser, Zhiyi Li, Sam Maddrell-Mander, Adam Sanders,
    Hatem Helal, Deniz Beker, Ladislav Rampášek, Dominique Beaini
    https://arxiv.org/abs/2212.02229

    Parameters:
        in_dim:
            Input feature dimensions of the nodes
        out_dim:
            Output feature dimensions of the nodes
        activation:
            Activation function to use in the edge and node model
        dropout:
            The ratio of units to dropout at the end within apply_norm_activation_dropout.
        normalization:
            Normalization to use with the edge, node models and at the end
            within apply_norm_activation_dropout. Choices:

            - "none" or `None`: No normalization
            - "batch_norm": Batch normalization
            - "layer_norm": Layer normalization
            - `Callable`: Any callable function
        gather_from:
            The method to gather features from. Could choose from:
            "senders", "receivers" and "both".
        scatter_to:
            The method to scatter features to. Could choose from:
            "senders", "receivers" and "both".
        node_combine_method:
            The method to combine the node features, Could choose from:
            "sum" and "concat".
        aggregation_method:
            Methods for aggregating (scatter) messages built from node and edge features.
            Provide a list of `Aggregation` or strings; ``None`` (the default)
            is equivalent to ``["sum"]``.
            supported strings are:

            - "sum" / "add" (Default)
            - "mean"
            - "max"
            - "min"
            - "softmax"
            - "median"
            - "std"
            - "var"
        num_node_mlp:
            Number of mlp layer used for node model
        mlp_expansion_ratio:
            Expansion ratio for node and edge mlp
        use_edges:
            If edge features are used
        in_dim_edges:
            Input feature dimensions of the edges
        out_dim_edges:
            Output feature dimensions of the edges
        num_edge_mlp:
            Number of mlp layer used for edge model
        use_globals:
            If global features are used (stored as ``self.use_globals``;
            not otherwise referenced in this layer's visible code)
        edge_dropout_rate:
            dropout rate for the edges

    Raises:
        ValueError: if ``node_combine_method`` is not ``"concat"`` or ``"sum"``.
    """
    # Avoid a mutable default argument: the previous default `["sum"]` was a
    # shared list object; `None` as sentinel preserves the same behavior.
    if aggregation_method is None:
        aggregation_method = ["sum"]
    super().__init__(
        in_dim=in_dim,
        out_dim=out_dim,
        activation=activation,
        dropout=dropout,
        normalization=normalization,
        **kwargs,
    )
    self.gather_from = gather_from
    self.scatter_to = scatter_to
    self.node_combine_method = node_combine_method
    self.num_node_mlp = num_node_mlp
    self.mlp_expansion_ratio = mlp_expansion_ratio
    self.use_edges = use_edges
    self.in_dim_edges = in_dim_edges
    self.out_dim_edges = out_dim_edges
    self.num_edge_mlp = num_edge_mlp
    self.edge_dropout_rate = edge_dropout_rate
    self.aggregator = MultiAggregation(list(aggregation_method))
    n_agg = len(aggregation_method)
    # node_model: messages carry node + edge features, so the input dim grows
    # with the number of aggregators and the combine method
    edge_dim = self.out_dim_edges if use_edges else self.in_dim_edges
    if self.node_combine_method == "concat":
        node_model_in_dim = (1 + 2 * n_agg) * self.in_dim + 2 * n_agg * edge_dim
    elif self.node_combine_method == "sum":
        node_model_in_dim = (1 + n_agg) * self.in_dim + n_agg * edge_dim
    else:
        raise ValueError(f"node_combine_method {self.node_combine_method} not recognised.")
    node_model_hidden_dim = self.mlp_expansion_ratio * self.in_dim
    self.node_model = MLP(
        in_dim=node_model_in_dim,
        hidden_dims=node_model_hidden_dim,
        out_dim=self.out_dim,
        depth=self.num_node_mlp,
        activation=self.activation_layer,
        last_dropout=self.dropout,
        normalization=self.normalization,
    )
    # edge_model: consumes the two endpoint node features plus the edge feature
    if self.node_combine_method == "concat":
        edge_model_in_dim = 2 * self.in_dim + self.in_dim_edges
    elif self.node_combine_method == "sum":
        edge_model_in_dim = self.in_dim + self.in_dim_edges
    else:
        raise ValueError(f"node_combine_method {self.node_combine_method} not recognised.")
    edge_model_hidden_dim = self.mlp_expansion_ratio * self.in_dim_edges
    self.edge_model = MLP(
        in_dim=edge_model_in_dim,
        hidden_dims=edge_model_hidden_dim,
        out_dim=self.out_dim_edges,
        depth=self.num_edge_mlp,
        activation=self.activation_layer,
        last_dropout=self.edge_dropout_rate,
        normalization=self.normalization,
    )
    self.use_globals = use_globals
def gather_features(
    self,
    input_features: Tensor,
    senders: Union[IntTensor, LongTensor],
    receivers: Union[IntTensor, LongTensor],
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Gather node features based on the senders and receivers of the edge indices.

    Parameters:
        input_features:
            Node features of the batch
        senders:
            Senders of the edge_index of the batch
        receivers:
            Receivers of the edge_index of the batch
    Output:
        A single-element list with the gathered node features (sender and
            receiver features summed or concatenated according to
            ``self.gather_from`` / ``self.node_combine_method``)
        Gathered sender features
        Gathered receiver features

    Raises:
        ValueError: if ``self.gather_from`` or ``self.node_combine_method``
            is not one of the recognised options.
    """
    out = []
    receiver_features = input_features[receivers]
    sender_features = input_features[senders]
    if self.gather_from == "receivers":
        out.append(receiver_features)
    elif self.gather_from == "senders":
        out.append(sender_features)
    elif self.gather_from == "both":
        if self.node_combine_method == "sum":
            out.append(receiver_features + sender_features)
        elif self.node_combine_method == "concat":
            out.append(torch.cat([receiver_features, sender_features], dim=-1))
        else:
            raise ValueError(f"node_combine_method {self.node_combine_method} not recognised.")
    else:
        # Bug fix: an unrecognised value previously produced an empty list
        # silently; fail fast instead, consistent with the other option checks.
        raise ValueError(f"gather_from {self.gather_from} not recognised.")
    return out, sender_features, receiver_features
def aggregate_features(
    self,
    input_features: Tensor,
    senders: Union[IntTensor, LongTensor],
    receivers: Union[IntTensor, LongTensor],
    sender_features: Tensor,
    receiver_features: Tensor,
    size: int,
) -> Tensor:
    r"""
    Aggregate (scatter) messages built from node and edge features, towards
    senders and/or receivers depending on ``self.scatter_to``.

    Parameters:
        input_features:
            Edge features of the batch
        senders:
            Senders of the edge_index of the batch
        receivers:
            Receivers of the edge_index of the batch
        sender_features:
            Senders features gathered from the gather_features function
        receiver_features:
            Receiver features gathered from the gather_features function
        size:
            size of the aggregation, equals to the total number of nodes

    Returns:
        A list holding the aggregated node features (sum or concat of both
        directions when scattering to both).
    """
    scattered = []
    # Direct-neighbour aggregation: the message is [edge_feat, opposite-node feat]
    if self.scatter_to in ["receivers", "both"]:
        msg = torch.cat([input_features, sender_features], dim=-1)
        scattered.append(self.aggregator(msg, receivers, dim_size=size))
    if self.scatter_to in ["senders", "both"]:
        msg = torch.cat([input_features, receiver_features], dim=-1)
        scattered.append(self.aggregator(msg, senders, dim_size=size))
    # Combine the two directions when both were computed
    if self.scatter_to == "both":
        if self.node_combine_method == "sum":
            return [scattered[0] + scattered[1]]
        return [torch.cat([scattered[0], scattered[1]], dim=-1)]
    return scattered
def forward(self, batch: Batch) -> Batch:
    r"""
    Forward function of the MPNN Plus layer: first updates the edge features
    from the gathered node features, then updates the node features from the
    aggregated edge messages.

    Parameters:
        batch:
            pyg Batch graph to pass through the layer. Must carry ``feat``,
            ``edge_feat``, ``edge_index`` and ``batch`` attributes.

    Returns:
        batch:
            pyg Batch graph with updated node (``feat``) and edge
            (``edge_feat``) features
    """
    senders = batch.edge_index[0]
    receivers = batch.edge_index[1]
    # ---------------EDGE step---------------
    # `edge_model_input` is a list holding the single gathered node tensor
    edge_model_input, sender_nodes, receiver_nodes = self.gather_features(batch.feat, senders, receivers)
    if self.use_edges:
        edge_model_input.append(batch.edge_feat)
        # Concatenate [gathered node feats, edge feats] along the feature dim
        edge_model_input = torch.cat([edge_model_input[0], edge_model_input[1]], dim=-1)
        # edge dropout included in the edge_model
        batch.edge_feat = self.edge_model(edge_model_input)
    else:
        # NOTE(review): `edge_model_input` is a list here, not a Tensor, and
        # `aggregate_features` below calls `torch.cat` on `batch.edge_feat` —
        # this branch looks like it assumes `use_edges=True` upstream. TODO confirm.
        batch.edge_feat = edge_model_input
    # ---------------NODE step---------------
    # message + aggregate
    # Number of nodes, taken from the second-to-last dim (supports packed batches)
    node_count_per_pack = batch.feat.shape[-2]
    node_model_input = self.aggregate_features(
        batch.edge_feat, senders, receivers, sender_nodes, receiver_nodes, node_count_per_pack
    )
    # Skip-connection: concatenate the aggregated messages with the input feats
    node_model_input.append(batch.feat)
    batch.feat = torch.cat([node_model_input[0], node_model_input[1]], dim=-1)
    batch.feat = self.node_model(batch.feat)
    # ---------------Apply norm activation and dropout---------------
    # use dropout value of the layer (default 0.3); norm/activation are
    # already applied inside the node MLP, so they are disabled here
    batch.feat = self.apply_norm_activation_dropout(
        batch.feat,
        normalization=False,
        activation=False,
        batch_idx=batch.batch,
        batch_size=batch.num_graphs,
    )
    return batch
@classproperty
def layer_supports_edges(cls) -> bool:
    r"""
    Whether this layer type is able to consume and produce edge features.

    Returns:
        supports_edges: bool
            Always ``True`` for the current class
    """
    return True
@property
def layer_inputs_edges(self) -> bool:
    r"""
    Whether this layer instance actually consumes edge features as input.
    It is different from ``layer_supports_edges`` since a layer that
    supports edges can decide not to use them.

    Returns:
        bool:
            Always ``True`` for the current class
    """
    return True
@property
def layer_outputs_edges(self) -> bool:
    r"""
    Whether this layer instance produces updated edge features as output.
    It is different from ``layer_supports_edges`` since a layer that
    supports edges can decide not to use them.

    Returns:
        bool:
            Always ``True`` for the current class
    """
    return True
@property
def out_dim_factor(self) -> int:
    r"""
    Factor by which the output dimension is multiplied for the next layer.

    Standard layers return ``1``. Multi-head layers that concatenate head
    outputs (e.g. ``GatLayer``) return the number of heads instead.

    Returns:
        int:
            Always ``1`` for the current class
    """
    return 1
================================================
FILE: graphium/nn/pyg_layers/pna_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Dict, List, Optional, Union, Callable
from functools import partial
import torch
from torch import Tensor
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptTensor
from torch_geometric.utils import degree
from torch_geometric.data import Data, Batch
from graphium.utils.decorators import classproperty
from graphium.nn.base_layers import MLP, FCLayer, get_activation
from graphium.nn.base_graph_layer import BaseGraphStructure, check_intpus_allow_int
class PNAMessagePassingPyg(MessagePassing, BaseGraphStructure):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        aggregators: List[str],
        scalers: List[str],
        activation: Union[Callable, str] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        avg_d: Dict[str, float] = {"log": 1.0, "lin": 1.0},
        last_activation: Union[Callable, str] = "none",
        posttrans_layers: int = 1,
        pretrans_layers: int = 1,
        in_dim_edges: int = 0,
        **kwargs,
    ):
        r"""
        Implementation of the message passing architecture of the PNA message passing layer,
        previously known as `PNALayerComplex`. This layer applies an MLP as
        pretransformation to the concatenation of $[h_u, h_v, e_{uv}]$ to generate
        the messages, with $h_u$ the node feature, $h_v$ the neighbour node features,
        and $e_{uv}$ the edge feature between the nodes $u$ and $v$.

        After the pre-transformation, it aggregates the messages with
        multiple aggregators and scalers, concatenates their results,
        then applies an MLP on the concatenated features.

        PNA: Principal Neighbourhood Aggregation
        Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic
        https://arxiv.org/abs/2004.05718

        [!] code adapted from pytorch-geometric implementation of PNAConv

        Parameters:

            in_dim:
                Input feature dimensions of the layer

            out_dim:
                Output feature dimensions of the layer

            aggregators:
                Set of aggregation function identifiers. Supported:
                "sum", "mean", "min", "max", "std", "var".
                The results from all aggregators will be concatenated.

            scalers:
                Set of scaling function identifiers. Supported:
                "identity", "amplification", "attenuation", "linear",
                "inverse_linear".
                The results from all scalers will be concatenated

            activation:
                activation function to use in the layer

            dropout:
                The ratio of units to dropout. Must be between 0 and 1

            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function

            avg_d:
                Average degree of nodes in the training set, used by scalers to normalize

            last_activation:
                activation function to use in the last layer of the internal MLP

            posttrans_layers:
                number of layers in the MLP transformation after the aggregation

            pretrans_layers:
                number of layers in the transformation before the aggregation

            in_dim_edges:
                size of the edge features. If 0, edges are ignored
        """
        MessagePassing.__init__(self, node_dim=0)
        BaseGraphStructure.__init__(
            self,
            in_dim=in_dim,
            out_dim=out_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            **kwargs,
        )

        # Allow int32 as edge index
        self.__check_input__ = partial(check_intpus_allow_int, self)

        self.aggregators = aggregators
        self.scalers = scalers

        # Edge dimensions; a linear encoder is only built when edges are used
        self.in_dim_edges = in_dim_edges
        self.edge_encoder = None
        if self.in_dim_edges > 0:
            self.edge_encoder = FCLayer(self.in_dim_edges, self.in_dim_edges, activation=None)

        # Initializing basic attributes
        self.avg_d = avg_d
        self.last_activation = get_activation(last_activation)

        # MLP used on each pair of nodes with their edge MLP(h_u, h_v, e_uv)
        self.pretrans = MLP(
            in_dim=2 * in_dim + in_dim_edges,
            hidden_dims=in_dim,
            out_dim=in_dim,
            depth=pretrans_layers,
            activation=self.activation,
            last_activation=self.last_activation,
            dropout=dropout,
            normalization=normalization,
            last_normalization=normalization,
        )

        # MLP used on the aggregated messages of the neighbours.
        # Input dim accounts for the concatenation of every (aggregator, scaler) pair.
        self.posttrans = MLP(
            in_dim=(len(aggregators) * len(scalers)) * self.in_dim,
            hidden_dims=self.out_dim,
            out_dim=self.out_dim,
            depth=posttrans_layers,
            activation=self.activation,
            last_activation=self.last_activation,
            dropout=dropout,
            normalization=normalization,
            last_normalization=normalization,
        )

    def forward(self, batch: Union[Data, Batch]) -> Union[Data, Batch]:
        r"""
        forward function of the layer

        Parameters:
            batch: pyg Batch graphs

        Returns:
            batch: pyg Batch graphs with updated node features ``feat``
        """
        feat, edge_index, edge_feat = batch.feat, batch.edge_index, batch.edge_feat
        # propagate() calls message() then aggregate() below
        out = self.propagate(edge_index, x=feat, edge_feat=edge_feat, size=None)
        out = self.posttrans(out)  # No more towers and concat with x
        batch.feat = out
        return batch

    def message(self, x_i: Tensor, x_j: Tensor, edge_feat: OptTensor) -> Tensor:
        r"""
        message function: builds the pre-transformed message from the pair of
        node features and (optionally) the edge feature.

        Parameters:
            x_i: node features
            x_j: neighbour node features
            edge_feat: edge features

        Returns:
            feat: the message
        """
        feat: Tensor = x_i  # Dummy.
        # Edge features are only used when both the batch provides them and
        # the layer was built with `in_dim_edges > 0`
        if (edge_feat is not None) and (self.edge_encoder is not None):
            edge_feat = self.edge_encoder(edge_feat)
            feat = torch.cat([x_i, x_j, edge_feat], dim=-1)
        else:
            feat = torch.cat([x_i, x_j], dim=-1)

        return self.pretrans(feat)  # No more towers

    def aggregate(
        self,
        inputs: Tensor,
        index: Tensor,
        edge_index: Tensor,
        dim_size: Optional[int] = None,
    ) -> Tensor:
        r"""
        aggregate function: applies every configured aggregator, concatenates
        the results, then applies every configured degree-scaler and
        concatenates again.

        Parameters:
            inputs: input features (the messages)
            index: index of the destination nodes
            edge_index: edge index
            dim_size: dimension size (number of nodes)

        Returns:
            out: aggregated features, dim = num_aggregators * num_scalers * in_dim
        """
        outs = []
        for aggregator in self.aggregators:
            if aggregator == "sum":
                out = scatter(inputs, index, 0, None, dim_size, reduce="sum")
            elif aggregator == "mean":
                out = scatter(inputs, index, 0, None, dim_size, reduce="mean")
            elif aggregator == "min":
                out = scatter(inputs, index, 0, None, dim_size, reduce="min")
            elif aggregator == "max":
                out = scatter(inputs, index, 0, None, dim_size, reduce="max")
            elif aggregator in ["var", "std"]:
                # var = E[x^2] - E[x]^2; relu + epsilon guards against
                # negative values from floating-point error before the sqrt
                mean = scatter(inputs, index, 0, None, dim_size, reduce="mean")
                mean_squares = scatter(inputs * inputs, index, 0, None, dim_size, reduce="mean")
                out = mean_squares - mean * mean
                if aggregator == "std":
                    out = torch.sqrt(torch.relu(out) + 1e-5)
            else:
                raise ValueError(f'Unknown aggregator "{aggregator}".')
            outs.append(out)
        out = torch.cat(outs, dim=-1)

        # Node in-degrees, clamped to 1 so log/division are well-defined
        deg = degree(index, dim_size, dtype=inputs.dtype)
        deg = deg.clamp_(1).view(-1, 1)

        outs = []
        for scaler in self.scalers:
            if scaler == "identity":
                out_scaler = out
            elif scaler == "amplification":
                out_scaler = out * (torch.log(deg + 1) / self.avg_d["log"])
            elif scaler == "attenuation":
                out_scaler = out * (self.avg_d["log"] / torch.log(deg + 1))
            elif scaler == "linear":
                out_scaler = out * (deg / self.avg_d["lin"])
            elif scaler == "inverse_linear":
                out_scaler = out * (self.avg_d["lin"] / deg)
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
            outs.append(out_scaler)
        return torch.cat(outs, dim=-1)

    @property
    def layer_outputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer outputs updated edge
        features or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.

        Returns:
            bool:
                Always ``False`` for the current class
        """
        return False

    @property
    def out_dim_factor(self) -> int:
        r"""
        Get the factor by which the output dimension is multiplied for
        the next layer. For standard layers, this will return ``1``.
        But for others, such as ``GatLayer``, the output is the concatenation
        of the outputs from each head, so the out_dim gets multiplied by
        the number of heads, and this function should return the number
        of heads.

        Returns:
            int:
                Always ``1`` for the current class
        """
        return 1

    @property
    def layer_inputs_edges(self) -> bool:
        r"""
        Return a boolean specifying if the layer type
        uses edges as input or not.
        It is different from ``layer_supports_edges`` since a layer that
        supports edges can decide to not use them.

        Returns:
            bool:
                Returns ``self.edge_features``
        """
        # NOTE(review): `edge_features` is not assigned in this class —
        # presumably set by `BaseGraphStructure`; verify against the base class.
        return self.edge_features

    @classproperty
    def layer_supports_edges(cls) -> bool:
        r"""
        Return a boolean specifying if the layer type supports edges or not.

        Returns:
            bool:
                Always ``True`` for the current class
        """
        return True
================================================
FILE: graphium/nn/pyg_layers/pooling_pyg.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import torch
import torch.nn as nn
from torch import Tensor, LongTensor
from typing import List, Union, Callable, Tuple, Optional, Dict
from copy import deepcopy
from torch_scatter import scatter
from torch_geometric.data import Data, Batch
from graphium.nn.base_layers import MLP, FCLayer
from graphium.utils.tensor import ModuleListConcat, ModuleWrap
from graphium.nn.base_layers import MuReadoutGraphium
EPS = 1e-6
def scatter_logsum_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_size: Optional[int] = None) -> Tensor:
    r"""
    Pool the nodes of each graph with a mean aggregation scaled by the log of
    the number of nodes. This keeps the expressive power of a sum pooling
    while remaining stable for graphs much larger than the others.

    $$r^{(i)} = \frac{\log N_i}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k$$

    Parameters:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        dim (int): dimension along which the nodes are scattered.
        dim_size (int, optional): Batch-size :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        the pooled features tensor
    """
    if dim_size is None:
        dim_size = int(batch.max().detach() + 1)
    mean_pool = scatter(x, batch, dim=dim, dim_size=dim_size, reduce="mean")
    # Per-graph node counts, computed by scattering ones
    ones = torch.ones(x.shape[:-1], dtype=x.dtype, device=x.device)
    num_nodes = scatter(ones, batch, dim=dim, dim_size=dim_size, reduce="sum")
    return mean_pool * torch.log(num_nodes).unsqueeze(-1)
def scatter_std_pool(x: Tensor, batch: LongTensor, dim: int = 0, dim_size: Optional[int] = None):
    r"""Returns batch-wise graph-level-outputs by taking the channel-wise
    standard deviation across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` each channel is the std of that channel over the
    graph's nodes (with a small epsilon for numerical stability).

    Parameters:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        dim (int): dimension along which the nodes are scattered.
        dim_size (int, optional): Batch-size :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        the pooled features tensor
    """
    if dim_size is None:
        dim_size = int(batch.max().detach() + 1)
    # var = E[x^2] - E[x]^2; relu guards against tiny negatives from
    # floating-point error before the sqrt
    mean = scatter(x, batch, dim=dim, out=None, dim_size=dim_size, reduce="mean")
    mean_of_squares = scatter(x * x, batch, dim=dim, out=None, dim_size=dim_size, reduce="mean")
    var = mean_of_squares - mean * mean
    return torch.sqrt(torch.relu(var) + 1e-5)
class PoolingWrapperPyg(ModuleWrap):
    def __init__(self, func, feat_type, *args, **kwargs) -> None:
        """
        Wrap a pooling function so it can be called on a pyg ``Batch``,
        pooling either node or edge features.
        """
        super().__init__(func, *args, **kwargs)
        self.feat_type = feat_type

    def forward(self, g: Batch, feature: Tensor, *args, **kwargs):
        """
        Apply the wrapped pooling function.

        Parameters:
            g: the pyg batch graph
            feature: the node or edge features
        Returns:
            the pooled features
        """
        dim_size = g.num_graphs
        if self.feat_type == "edge":
            # Batch index of the sender node of each edge
            index = g.batch[g.edge_index][0]
        else:
            # "node" and any other feat_type pool over the node batch index
            index = g.batch
        return self.func(feature, index, dim_size=dim_size, *args, **kwargs, **self.kwargs)
def parse_pooling_layer_pyg(in_dim: int, pooling: Union[str, List[str]], feat_type: str = "node", **kwargs):
    r"""
    Select the pooling layers from a list of strings, and put them
    in a Module that concatenates their outputs.

    Parameters:

        in_dim:
            The dimension at the input layer of the pooling

        pooling:
            The list of pooling layers to use. The accepted strings are:

            - "none": No pooling
            - "sum": Sum all the nodes for each graph
            - "mean": Mean all the nodes for each graph
            - "logsum": Mean all the nodes then multiply by log(num_nodes) for each graph
            - "max": Max all the nodes for each graph
            - "min": Min all the nodes for each graph
            - "std": Standard deviation of all the nodes for each graph

        feat_type:
            Whether to pool "node", "edge" or "global" features
    """
    if isinstance(pooling, str):
        pooling = [pooling]
    assert feat_type in ["node", "edge", "global"]

    pool_layer = ModuleListConcat()
    out_pool_dim = 0
    for this_pool in pooling:
        this_pool = None if this_pool is None else this_pool.lower()
        # NOTE: the output dim grows for every entry, including "none",
        # matching the historical behaviour
        out_pool_dim += in_dim
        if this_pool in ("sum", "mean", "max", "min"):
            # `scatter` spells the sum reduction "add"
            reduce = "add" if this_pool == "sum" else this_pool
            pool_layer.append(PoolingWrapperPyg(scatter, dim=0, reduce=reduce, feat_type=feat_type, **kwargs))
        elif this_pool == "logsum":
            pool_layer.append(PoolingWrapperPyg(scatter_logsum_pool, dim=0, feat_type=feat_type, **kwargs))
        elif this_pool == "std":
            pool_layer.append(PoolingWrapperPyg(scatter_std_pool, dim=0, feat_type=feat_type, **kwargs))
        elif (this_pool == "none") or (this_pool is None):
            pass
        else:
            raise NotImplementedError(f"Undefined pooling `{this_pool}`")

    return pool_layer, out_pool_dim
class VirtualNodePyg(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        in_dim_edges: Optional[int],
        out_dim_edges: Optional[int],
        vn_type: Union[type(None), str] = "sum",
        activation: Union[str, Callable] = "relu",
        dropout: float = 0.0,
        normalization: Union[str, Callable] = "none",
        bias: bool = True,
        residual: bool = True,
        use_edges: bool = False,
        **kwargs,
    ):
        r"""
        The VirtualNode is a layer that pools the features of the graph,
        applies a neural network layer on the pooled features,
        then adds the result back to the node features of every node.

        Parameters:
            in_dim:
                Input feature dimensions of the virtual node layer.
            out_dim:
                Output feature dimensions of the virtual node layer.
            in_dim_edges:
                Input feature dimensions of the virtual node layer for the edges.
            out_dim_edges:
                Output feature dimensions of the virtual node layer for the edges.
            vn_type:
                The type of the virtual node. Choices are:

                - "none": No pooling
                - "sum": Sum all the nodes for each graph
                - "mean": Mean all the nodes for each graph
                - "logsum": Mean all the nodes then multiply by log(num_nodes) for each graph
                - "max": Max all the nodes for each graph
                - "min": Min all the nodes for each graph
                - "std": Standard deviation of all the nodes for each graph

            activation:
                activation function to use in the neural network layer.
            dropout:
                The ratio of units to dropout. Must be between 0 and 1
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization
                - `Callable`: Any callable function

            bias:
                Whether to add a bias to the neural network
            residual:
                Whether all virtual nodes should be connected together
                via a residual connection
            use_edges:
                Boolean flag to select if edges are used in the global node
                aggregation and update of features
        """
        super().__init__()
        # A "none" virtual node becomes a no-op: `forward` returns its inputs unchanged
        if (vn_type is None) or (vn_type.lower() == "none"):
            self.vn_type = None
            self.fc_layer = None
            self.residual = None
            return

        # Layer sizes stored here
        self.in_dim_edges = in_dim_edges
        self.out_dim_edges = out_dim_edges
        self.in_dim_nodes = in_dim
        self.out_dim_nodes = out_dim

        # Use edge features in pooling
        self.use_edges = use_edges
        has_edges = (
            (in_dim_edges is not None)
            and (out_dim_edges is not None)
            and (in_dim_edges > 0)
            and (out_dim_edges > 0)
        )
        if self.use_edges and not has_edges:
            raise ValueError("The edge features are not defined but `use_edges` is True")

        self.vn_type = vn_type.lower()
        # Node pooling layer; `out_pool_dim` is the dimension of the pooled vector
        self.layer, out_pool_dim = parse_pooling_layer_pyg(
            in_dim=self.in_dim_nodes, pooling=self.vn_type, feat_type="node"
        )
        self.residual = residual
        if self.use_edges:
            # Edge pooling is concatenated to the node pooling below
            self.edge_layer, out_edge_pool_dim = parse_pooling_layer_pyg(
                in_dim=self.in_dim_edges, pooling=self.vn_type, feat_type="edge"
            )
            out_pool_dim = out_pool_dim + out_edge_pool_dim

        self.fc_layer = FCLayer(
            in_dim=out_pool_dim,
            out_dim=out_pool_dim,
            activation=activation,
            dropout=dropout,
            normalization=normalization,
            bias=bias,
        )

        # Projection layers from the pooling layer to node and edge feature sizes
        self.node_projection = MuReadoutGraphium(out_pool_dim, self.out_dim_nodes)
        self.edge_projection = None
        if self.use_edges:
            self.edge_projection = MuReadoutGraphium(out_pool_dim, self.out_dim_edges)

    def forward(
        self, g: Union[Data, Batch], feat: Tensor, vn_feat: LongTensor, edge_feat: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Apply the virtual node layer.

        Parameters:
            g:
                PyG Graphs or Batched graphs.
            feat (torch.Tensor[..., N, Din]):
                Node feature tensor, before convolution.
                `N` is the number of nodes, `Din` is the input features
            vn_feat (torch.Tensor[..., M, Din]):
                Graph feature of the previous virtual node, or `None`
                `M` is the number of graphs, `Din` is the input features.
                It is added to the result after the MLP, as a residual connection
            edge_feat (torch.Tensor[..., E, Din]):
                Edge feature tensor, before convolution.
                `E` is the number of edges, `Din` is the input features

        Returns:
            `feat = torch.Tensor[..., N, Dout]`:
                Node feature tensor, after convolution and residual.
                `N` is the number of nodes, `Dout` is the output features of the layer and residual
            `vn_feat = torch.Tensor[..., M, Dout]`:
                Graph feature tensor to be used at the next virtual node, or `None`
                `M` is the number of graphs, `Dout` is the output features
            `edge_feat = torch.Tensor[..., E, Dout]`:
                Edge feature tensor, after convolution and residual - if edges are used, otherwise returned unchanged.
                `E` is the number of edges, `Dout` is the output features of the layer and residual
        """
        # A "none" virtual node is a pass-through
        if self.vn_type is None:
            return feat, vn_feat, edge_feat

        # Pool the features of the graph
        pool = self.layer(g, feat)
        if self.use_edges:
            edge_pool = self.edge_layer(g, edge_feat)
            pool = torch.cat((pool, edge_pool), -1)

        # Assumes `vn_feat` is broadcastable against the pooled vector
        # (e.g. zeros on the first virtual-node step) — TODO confirm with callers
        vn_h_temp = self.fc_layer.forward(vn_feat + pool)
        if self.residual:
            vn_feat = vn_feat + vn_h_temp
        else:
            vn_feat = vn_h_temp

        # Add the virtual node projections to the node features
        feat = feat + self.node_projection(vn_feat[g.batch])

        if self.use_edges:
            # Add the virtual node projections to the edge features,
            # indexed by the batch of each edge's sender node
            edge_feat = edge_feat + self.edge_projection(vn_feat[g.batch[g.edge_index][0]])

        return feat, vn_feat, edge_feat
================================================
FILE: graphium/nn/pyg_layers/utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import math
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from typing import Tuple
from torch import Tensor
from torch_geometric.typing import SparseTensor
from graphium.nn.base_layers import MLP, get_norm
from graphium.ipu.to_dense_batch import to_dense_batch, to_sparse_batch
class PreprocessPositions(nn.Module):
    """
    Compute 3D attention bias and 3D node features according to the 3D position information.
    """

    def __init__(
        self,
        num_heads,
        embed_dim,
        num_kernel,
        in_dim=3,
        num_layers=2,
        activation="gelu",
        first_normalization="none",
    ):
        r"""
        Parameters:
            num_heads:
                Number of attention heads used in self-attention.
            embed_dim:
                Hidden dimension of node features.
            num_kernel:
                Number of gaussian kernels.
            in_dim:
                Dimensionality of the input positions (3 for 3D coordinates).
            num_layers: The number of layers in the MLP.
            activation: The activation function used in the MLP.
            first_normalization: The normalization function used before the gaussian kernel.
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_kernel = num_kernel
        self.embed_dim = embed_dim
        self.first_normalization = get_norm(first_normalization, dim=in_dim)
        self.gaussian = GaussianLayer(self.num_kernel, in_dim=in_dim)
        self.gaussian_proj = MLP(
            in_dim=self.num_kernel,
            hidden_dims=self.num_kernel,
            out_dim=self.num_heads,
            depth=num_layers,
            activation=activation,
            last_layer_is_readout=True,  # Since the output is not proportional to the hidden dim, but to the number of heads
        )

        # Project the summed gaussian features to the node embedding size so
        # they can be added to the original node features
        self.node_proj = nn.Linear(self.num_kernel, self.embed_dim)

    def forward(
        self, batch: Batch, max_num_nodes_per_graph: int, on_ipu: bool, positions_3d_key: str
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Inputs:
            batch:
                Batch object.
            max_num_nodes_per_graph:
                Maximum number of nodes per graph.
            on_ipu:
                If model runs on IPU.
            positions_3d_key:
                The key of the pyg graph object that contains the 3D positions.

        Returns:
            attn_bias:
                3D attention bias, shape [batch, num_heads, nodes, nodes].
            node_feature:
                3D node features, shape [total_nodes, embed_dim].
        """
        pos = batch[positions_3d_key]
        if self.first_normalization is not None:
            pos = self.first_normalization(pos)
        # `batch.feat` is only available after passing through layers, so the
        # device type is read from the positions tensor instead
        batch_size = None if pos.device.type != "ipu" else batch.graph_is_true.shape[0]

        # pos: [batch, nodes, 3]
        # mask: [batch, nodes]
        # idx: [total_nodes]
        pos, mask, idx = to_dense_batch(
            pos,
            batch=batch.batch,
            batch_size=batch_size,
            max_num_nodes_per_graph=max_num_nodes_per_graph,
            drop_nodes_last_graph=on_ipu,
        )

        # Detect NaN positions: for real nodes, a missing 3D position is NaN;
        # padding nodes are 0. If the first node of a molecule has a NaN
        # position, the whole molecule is masked out.
        # [batch]
        nan_mask = torch.isnan(pos)[:, 0, 0]
        # Zero out NaN positions so the gaussian kernels do not produce NaN gradients
        pos.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2), 0.0)
        # we need the opposite of mask output
        padding_mask = ~mask

        # Do NOT unpack `pos.shape` into a variable named `batch` here:
        # that would shadow the `batch: Batch` argument
        n_node = pos.shape[1]
        # [batch, nodes, nodes, 3]
        delta_pos = pos.unsqueeze(1) - pos.unsqueeze(2)
        # [batch, nodes, nodes]
        distance = delta_pos.norm(dim=-1).view(-1, n_node, n_node)
        # [batch, nodes, nodes, num_kernel]
        distance_feature = self.gaussian(distance)
        # [batch, nodes, nodes, num_heads]
        attn_bias = self.gaussian_proj(distance_feature)
        # [batch, num_heads, nodes, nodes]
        attn_bias = attn_bias.permute(0, 3, 1, 2).contiguous()

        # apply padding_mask on attn_bias
        # unsqueezed mask size: [batch, 1, 1, nodes] apply on tensor [batch, num_heads, nodes, nodes]
        attn_bias.masked_fill_(
            padding_mask.unsqueeze(1).unsqueeze(2),
            -1000.0,
        )

        # apply nan_mask on attn_bias
        # unsqueezed mask size: [batch, 1, 1, 1] apply on tensor [batch, num_heads, nodes, nodes]
        attn_bias.masked_fill_(
            nan_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
            0.0,
        )

        # apply padding_mask on distance_feature
        # unsqueezed mask size: [batch, 1, nodes, 1] apply on tensor [batch, nodes, nodes, num_kernel]
        distance_feature.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(-1).to(torch.bool), 0.0)

        # [batch, nodes, num_kernel]
        distance_feature_sum = distance_feature.sum(dim=-2)
        # Output of GaussianLayer is FP32, cast to dtype of self.node_proj here
        distance_feature_sum = distance_feature_sum.to(self.node_proj.weight.dtype)
        # [batch, nodes, embed_dim]
        node_feature = self.node_proj(distance_feature_sum)

        # apply nan_mask on node_feature
        # unsqueezed mask size: [batch, 1, 1] apply on tensor [batch, nodes, embed_dim]
        node_feature.masked_fill_(nan_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0.0)

        # [total_nodes, embed_dim]
        node_feature = to_sparse_batch(node_feature, idx)
        return attn_bias, node_feature
class GaussianLayer(nn.Module):
def __init__(self, num_kernels=128, in_dim=3):
r"""
Gaussian kernel function that applied on the all-to-all 3D distances.
Parameters:
num_kernels:
Number of gaussian kernel used.
"""
super().__init__()
self.num_kernels = num_kernels
self.means = nn.Embedding(1, num_kernels)
self.stds = nn.Embedding(1, num_kernels)
nn.init.uniform_(self.means.weight, 0, in_dim)
nn.init.uniform_(self.stds.weight, 0, in_dim)
def forward(self, input: Tensor) -> Tensor:
# [batch, nodes, nodes, 1]
input = input.unsqueeze(-1)
# [batch, nodes, nodes, num_kernels]
expanded_input = input.expand(-1, -1, -1, self.num_kernels)
# [num_kernels]
mean = self.means.weight.float().view(-1)
# [num_kernels]
std = self.stds.weight.float().view(-1).abs() + 0.01 # epsilon is 0.01 that matches gps++ value
pre_exp_factor = (2 * math.pi) ** 0.5
# [batch, nodes, nodes, num_kernels]
tensor_with_kernel = torch.exp(-0.5 * (((expanded_input - mean) / std) ** 2)) / (pre_exp_factor * std)
return tensor_with_kernel
def triplets(
    edge_index: Tensor,
    num_nodes: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    r"""Generates triplets from the given edge indices.

    A triplet is defined as a path of length two,
    such that if node A is connected to node B,
    and node B is connected to node C, then there is a triplet (A, B, C).

    Parameters:
        edge_index (LongTensor): The edge indices.
        num_nodes (int): The number of nodes.

    Returns:
        col: The sink node indices of edges from the edge indices.
        row: The source node indices of edges from the edge indices.
        idx_i: The sink node indices of the triplets.
        idx_j: The middle node indices of the triplets.
        idx_k: The source node indices of the triplets.
        idx_kj: The indices of edges those from the source node to the middle node.
        idx_ji: The indices of edges those from the middle node to the sink node.
    """
    row, col = edge_index  # j->i
    # Tag each edge with its position so edge ids can be recovered after masking
    value = torch.arange(row.size(0), device=row.device)
    # Sparse adjacency with rows indexed by sink; entry (col, row) stores the edge id
    adj_t = SparseTensor(row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes))
    # One sparse row per edge: the incoming edges of that edge's source node
    adj_t_row = adj_t[row]
    # Number of triplets each edge participates in
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    # Remove degenerate triplets whose source and sink coincide
    mask = idx_i != idx_k  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]
    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji
================================================
FILE: graphium/nn/residual_connections.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
"""
Different types of residual connections, including None, Simple (ResNet-like),
Concat and DenseNet
"""
import abc
import torch
import torch.nn as nn
from typing import List, Union, Callable
from graphium.nn.base_layers import FCLayer
from graphium.utils.decorators import classproperty
class ResidualConnectionBase(nn.Module):
    def __init__(self, skip_steps: int = 1):
        r"""
        Abstract class for the residual connections. Using this class,
        we implement different types of residual connections, such as
        the ResNet, weighted-ResNet, skip-concat and DenseNet.

        The following methods must be implemented in a children class

        - ``h_dim_increase_type()``
        - ``has_weights()``

        Parameters:

            skip_steps: int
                The number of steps to skip between the residual connections.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
        """
        super().__init__()
        self.skip_steps = skip_steps

    def _bool_apply_skip_step(self, step_idx: int):
        r"""
        Whether to apply the skip connection, depending on the
        ``step_idx`` and ``self.skip_steps``.

        Parameters:

            step_idx: int
                The current layer step index.

        Returns:
            bool: ``True`` when ``skip_steps`` is non-zero and ``step_idx``
            is a multiple of ``skip_steps``.
        """
        # `skip_steps == 0` disables the residual connection entirely
        return (self.skip_steps != 0) and ((step_idx % self.skip_steps) == 0)

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        return f"{self.__class__.__name__}(skip_steps={self.skip_steps})"

    @classproperty
    @abc.abstractmethod
    def h_dim_increase_type(cls):
        r"""
        How does the dimension of the output features increases after each layer?

        Returns:

            h_dim_increase_type: None or str
                - ``None``: The dimension of the output features do not change at each layer.
                  E.g. ResNet.

                - "previous": The dimension of the output features is the concatenation of
                  the previous layer with the new layer.

                - "cumulative": The dimension of the output features is the concatenation
                  of all previous layers.
        """
        ...

    def get_true_out_dims(self, out_dims: List) -> List:
        r"""
        Find the true output dimension of each layer, accounting for the
        concatenations performed by the residual connection.

        For concatenating connection types ("previous"/"cumulative"), a layer
        that receives a skip connection outputs its own features concatenated
        with the skipped features, so its effective output dimension grows.

        Parameters:
            out_dims: List
                The nominal output dimension of each layer. The first entry is
                kept as-is and the last entry is excluded from the loop.
        Returns:
            true_out_dims: List
                The effective output dimensions after concatenation.
        """
        true_out_dims = [out_dims[0]]
        # Dimension carried by the skip branch at the last step where the
        # skip connection was applied.
        out_dims_at_skip = [out_dims[0]]
        for ii in range(1, len(out_dims) - 1):
            # For the `None` type, don't change the output dims
            if self.h_dim_increase_type is None:
                true_out_dims.append(out_dims[ii])

            # For the "previous" type, add the previous layers when the skip connection applies
            elif self.h_dim_increase_type == "previous":
                if self._bool_apply_skip_step(step_idx=ii):
                    true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1])
                    # Only the current layer's nominal dim is carried forward
                    out_dims_at_skip.append(out_dims[ii])
                else:
                    true_out_dims.append(out_dims[ii])

            # For the "cumulative" type, add all previous layers when the skip connection applies
            elif self.h_dim_increase_type == "cumulative":
                if self._bool_apply_skip_step(step_idx=ii):
                    true_out_dims.append(out_dims[ii] + out_dims_at_skip[-1])
                    # The whole concatenated output (just appended at index ii)
                    # is carried forward, so the skip dimension accumulates.
                    out_dims_at_skip.append(true_out_dims[ii])
                else:
                    true_out_dims.append(out_dims[ii])
            else:
                raise ValueError(f"undefined value: {self.h_dim_increase_type}")
        return true_out_dims

    @classproperty
    @abc.abstractmethod
    def has_weights(cls):
        r"""
        Returns:

            has_weights: bool
                Whether the residual connection uses weights
        """
        ...
class ResidualConnectionNone(ResidualConnectionBase):
    r"""
    Placeholder residual connection that performs no operation.
    It exists only so that architectures can treat "no residual"
    uniformly with the other residual-connection classes.
    """

    def __init__(self, skip_steps: int = 1):
        super().__init__(skip_steps=skip_steps)

    def __repr__(self):
        r"""
        Controls how the class is printed
        """
        return f"{self.__class__.__name__}"

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            None:
                The feature dimension is unchanged from layer to layer.
        """
        return None

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            False
                This connection type is parameter-free.
        """
        return False

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        No-op: both inputs are returned untouched.

        Returns:

            h: torch.Tensor(..., m)
                Same tensor as the input ``h``.
            h_prev: torch.Tensor(..., m)
                Same tensor as the input ``h_prev``.
        """
        return h, h_prev
class ResidualConnectionSimple(ResidualConnectionBase):
    def __init__(self, skip_steps: int = 1):
        r"""
        ResNet-style residual connection: the current layer output is
        summed with the output saved at the previous skip step.

        Parameters:

            skip_steps: int
                The number of steps to skip between the residual connections.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
        """
        super().__init__(skip_steps=skip_steps)

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            None:
                The feature dimension is unchanged from layer to layer.
        """
        return None

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            False
                This connection type is parameter-free.
        """
        return False

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        Sum ``h`` with the saved skip features ``h_prev``, as in ResNet.

        Parameters:

            h: torch.Tensor(..., m)
                The current layer features
            h_prev: torch.Tensor(..., m), None
                The features from the previous layer with a skip connection.
                At ``step_idx==0``, ``h_prev`` can be set to ``None``.
            step_idx: int
                Current layer index or step index in the forward loop of the architecture.

        Returns:

            h: torch.Tensor(..., m)
                ``h`` unchanged when the skip step does not apply (or at step 0),
                otherwise ``h + h_prev``.
            h_prev: torch.Tensor(..., m)
                ``h_prev`` unchanged when the skip step does not apply,
                otherwise the (possibly summed) ``h``.
        """
        # Guard clause: outside a skip step, pass everything through untouched.
        if not self._bool_apply_skip_step(step_idx):
            return h, h_prev
        # At step 0 there is nothing to sum yet; the new skip state is `h` itself.
        if step_idx > 0:
            h = h + h_prev
        return h, h
class ResidualConnectionWeighted(ResidualConnectionBase):
    def __init__(
        self,
        out_dims,
        skip_steps: int = 1,
        dropout=0.0,
        activation: Union[str, Callable] = "none",
        normalization="none",
        bias=False,
    ):
        r"""
        Class for the simple residual connections proposed by ResNet,
        with an added layer in the residual connection itself.
        The layer output is summed to a non-linear transformation
        of a previous layer output.

        Parameters:

            out_dims: list(int)
                list of all output dimensions for the network
                that will use this residual connection.
                E.g. ``out_dims = [4, 8, 8, 8, 2]``.
            skip_steps: int
                The number of steps to skip between the residual connections.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
            dropout: float
                value between 0 and 1.0 representing the percentage of dropout
                to use in the weights
            activation: str, Callable
                The activation function to use after the skip weights
            normalization:
                Normalization to use. Choices:

                - "none" or `None`: No normalization
                - "batch_norm": Batch normalization
                - "layer_norm": Layer normalization in the hidden layers.
                - `Callable`: Any callable function

            bias: bool
                Whether to add a bias to the weights of the skip layers
        """
        super().__init__(skip_steps=skip_steps)

        self.residual_list = nn.ModuleList()
        self.skip_count = 0
        self.out_dims = out_dims

        # One FC layer per skip connection; each maps `this_dim -> this_dim`
        # since the residual sum requires matching dimensions.
        for ii in range(0, len(self.out_dims) - 1, self.skip_steps):
            this_dim = self.out_dims[ii]
            self.residual_list.append(
                FCLayer(
                    this_dim,
                    this_dim,
                    activation=activation,
                    dropout=dropout,
                    normalization=normalization,
                    # Bug fix: the `bias` argument was previously ignored and
                    # hard-coded to `False`; forward the caller's choice instead.
                    bias=bias,
                )
            )

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            None:
                The dimension of the output features do not change at each layer.
        """
        return None

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            True
                The current class uses weights
        """
        return True

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        Add ``h`` with the previous layers with skip connection ``h_prev``, after
        a feed-forward layer.

        Parameters:

            h: torch.Tensor(..., m)
                The current layer features
            h_prev: torch.Tensor(..., m), None
                The features from the previous layer with a skip connection.
                At ``step_idx==0``, ``h_prev`` can be set to ``None``.
            step_idx: int
                Current layer index or step index in the forward loop of the architecture.

        Returns:

            h: torch.Tensor(..., m)
                Either return ``h`` unchanged, or the sum with the output of a NN layer
                on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
            h_prev: torch.Tensor(..., m)
                Either return ``h_prev`` unchanged, or the same value as ``h``,
                depending on the ``step_idx`` and ``self.skip_steps``.
        """
        if self._bool_apply_skip_step(step_idx):
            if step_idx > 0:
                # Transform the skip branch with the next FC layer, then sum.
                h = h + self.residual_list[self.skip_count].forward(h_prev)
                self.skip_count += 1
            h_prev = h

        return h, h_prev

    def _bool_apply_skip_step(self, step_idx: int):
        # Also stop applying skips once every FC layer has been consumed,
        # to avoid indexing past the end of `residual_list`.
        return super()._bool_apply_skip_step(step_idx) and self.skip_count < len(self.residual_list)
class ResidualConnectionConcat(ResidualConnectionBase):
    def __init__(self, skip_steps: int = 1):
        r"""
        Residual connection where the skip features are concatenated to the
        current layer features instead of summed.

        Parameters:

            skip_steps: int
                The number of steps to skip between the residual connections.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
        """
        super().__init__(skip_steps=skip_steps)

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            "previous":
                The dimension of the output layer is the concatenation with the previous layer.
        """
        return "previous"

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            False
                This connection type is parameter-free.
        """
        return False

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        Concatenate ``h`` with the saved skip features ``h_prev``.

        Parameters:

            h: torch.Tensor(..., m)
                The current layer features
            h_prev: torch.Tensor(..., n), None
                The features from the previous layer with a skip connection.
                Usually, we have ``n`` equal to ``m``.
                At ``step_idx==0``, ``h_prev`` can be set to ``None``.
            step_idx: int
                Current layer index or step index in the forward loop of the architecture.

        Returns:

            h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
                ``h`` unchanged when the skip step does not apply (or at step 0),
                otherwise the concatenation of ``h`` with ``h_prev``.
            h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
                ``h_prev`` unchanged when the skip step does not apply,
                otherwise the pre-concatenation ``h``.
        """
        if not self._bool_apply_skip_step(step_idx):
            return h, h_prev
        # Save the current features BEFORE concatenation: only the current
        # layer's features are carried forward as the next skip state.
        current = h
        if step_idx > 0:
            h = torch.cat([h, h_prev], dim=-1)
        return h, current
class ResidualConnectionDenseNet(ResidualConnectionBase):
    def __init__(self, skip_steps: int = 1):
        r"""
        DenseNet-style residual connection: the features of ALL previous
        skip steps are concatenated to the current layer features.

        Parameters:

            skip_steps: int
                The number of steps to skip between the residual connections.
                If `1`, all the layers are connected. If `2`, half of the
                layers are connected.
        """
        super().__init__(skip_steps=skip_steps)

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            "cumulative":
                The dimension of the output layer is the concatenation of all the previous layer.
        """
        return "cumulative"

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            False
                This connection type is parameter-free.
        """
        return False

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        Concatenate ``h`` with the accumulated skip features ``h_prev``.

        Parameters:

            h: torch.Tensor(..., m)
                The current layer features
            h_prev: torch.Tensor(..., n), None
                The features from the previous layers.
                n = ((step_idx // self.skip_steps) + 1) * m
                At ``step_idx==0``, ``h_prev`` can be set to ``None``.
            step_idx: int
                Current layer index or step index in the forward loop of the architecture.

        Returns:

            h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
                ``h`` unchanged when the skip step does not apply (or at step 0),
                otherwise the concatenation of ``h`` with ``h_prev``.
            h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
                ``h_prev`` unchanged when the skip step does not apply,
                otherwise identical to the returned ``h``.
        """
        if not self._bool_apply_skip_step(step_idx):
            return h, h_prev
        if step_idx > 0:
            h = torch.cat([h, h_prev], dim=-1)
        # Unlike the "previous" type, the FULL concatenation is carried forward,
        # so the skip state keeps growing with each applied step.
        return h, h
class ResidualConnectionRandom(ResidualConnectionBase):
    def __init__(self, skip_steps=1, out_dims: List[int] = None, num_layers: int = None):
        r"""
        Class for the random residual connection, where each layer is connected
        to each following layer with a fixed random weight between 0 and 1,
        drawn once at construction time.

        Parameters:

            skip_steps:
                Parameter only there for compatibility with other classes of the
                same parent. Only `1` is supported.
            out_dims:
                The list of output dimensions. Only required to get the number
                of layers. Must be provided if `num_layers` is None.
            num_layers:
                The number of layers. Must be provided if `out_dims` is None.
        """
        if skip_steps != 1:
            raise ValueError("Only `skip_step=1` is implemented")
        super().__init__(skip_steps=skip_steps)
        if out_dims is not None:
            if num_layers is not None:
                assert num_layers == len(out_dims)
            num_layers = len(out_dims)
        if num_layers is None:
            raise ValueError("Either `out_dims` or `num_layers` must be provided")
        self.num_layers = num_layers

        # One weight per (layer, earlier-layer) pair.
        # NOTE(review): these are plain tensors in a dict, so they are NOT part
        # of the state_dict and are re-drawn at every instantiation — confirm
        # this is intended for checkpoint reproducibility.
        self.random_dict_weights = {}
        for ii in range(1, self.num_layers):
            random_weights = torch.rand(ii)
            self.random_dict_weights[ii] = random_weights

    @classproperty
    def h_dim_increase_type(cls):
        r"""
        Returns:

            None:
                The dimension of the output features do not change at each layer.
        """
        return None

    @classproperty
    def has_weights(cls):
        r"""
        Returns:

            False
                The current class does not use (learned) weights
        """
        return False

    def forward(self, h: torch.Tensor, h_prev: torch.Tensor, step_idx: int):
        r"""
        Add ``h`` with a randomly-weighted sum of ALL previous layer outputs.

        Parameters:

            h: torch.Tensor(..., m)
                The current layer features
            h_prev: list(torch.Tensor(..., m)), None
                The list of features from all previous layers.
                At ``step_idx==0``, ``h_prev`` can be set to ``None``.
            step_idx: int
                Current layer index or step index in the forward loop of the architecture.

        Returns:

            h: torch.Tensor(..., m)
                ``h`` plus the randomly-weighted sum of every tensor in ``h_prev``.
            h_prev: list(torch.Tensor(..., m))
                The history list, with the new ``h`` appended.
        """
        if self._bool_apply_skip_step(step_idx):
            # Bug fix: use out-of-place addition (`h = h + ...`) instead of the
            # previous in-place `h += ...`, which mutated the caller's tensor
            # (and fails under autograd when `h` is a leaf requiring grad).
            for i in range(0, step_idx):
                weight = self.random_dict_weights[step_idx][i].to(
                    dtype=h_prev[i].dtype, device=h_prev[i].device
                )
                h = h + weight * h_prev[i]
            if h_prev is None:
                h_prev = [h]
            else:
                h_prev.append(h)
        return h, h_prev
================================================
FILE: graphium/nn/utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import abc
import inspect
from numbers import Real
from typing import Optional
class MupMixin(abc.ABC):
    @abc.abstractmethod
    def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: Optional[bool] = None):
        """
        Build the kwargs of a 'base' model for the `mup` / muTransfer scaling.

        The base model is usually identical to the regular model, except that
        its layer widths are divided by ``divide_factor`` (2 by default).
        This is done using the ``scale_kwargs()`` method with
        ``scale_factor = 1 / divide_factor``.

        Parameter:
            divide_factor: Factor by which to divide the width.
            factor_in_dim: Whether to also factor the input dimension for the
                nodes. If None, the default for ``scale_kwargs`` is used.

        Returns:
            kwargs: Dictionary of parameters used to instantiate the base model.
        """
        ...

    def scale_kwargs(self, scale_factor: Real, scale_in_dim: bool = False):
        """
        Build the kwargs of a "scaled" version of the module, with hidden dims
        scaled as in muTransfer.

        Use ``scale_factor < 1`` to create a "base" model for shape extraction
        (as in the `mup` package), or ``scale_factor > 1`` to create a larger
        model to which tuned hyperparameters can be "muTransferred" from the
        original model.

        Parameters:
            scale_factor: Factor by which to scale the width.
            scale_in_dim: Whether to also factor the input dimension for the nodes.

        Returns:
            kwargs: Dictionary of parameters used to instantiate the scaled model.

        Raises:
            RuntimeError: if ``scale_in_dim`` is requested but the subclass's
                ``make_mup_base_kwargs`` does not accept ``factor_in_dim``.
        """
        inv_factor = 1 / scale_factor
        if not scale_in_dim:
            return self.make_mup_base_kwargs(divide_factor=inv_factor)

        # `scale_in_dim` was requested: forward it, translating the TypeError
        # raised by subclasses whose `make_mup_base_kwargs` lacks the parameter.
        try:
            return self.make_mup_base_kwargs(divide_factor=inv_factor, factor_in_dim=scale_in_dim)
        except TypeError as err:
            raise RuntimeError(
                "This error may have been caused by passing scale_in_dim to scale_kwargs "
                "for a class that does not support passing factor_in_dim to make_mup_base_kwargs, "
                "which cannot be done"
            ) from err
================================================
FILE: graphium/trainer/README.md
================================================
The Graph of Life Library.
## What is in this folder?
Code for the metrics, the training logic, and the trainer module:
- ✅ `predictor.py`: the `PredictorModule` class is the main class for the trainer
- `metrics.py`: metrics for the task heads
- `predictor_options.py`: options for the predictor, intervals, tracking min/max etc.
- `predictor_summaries.py`: summaries for the task to track loss and metric
================================================
FILE: graphium/trainer/__init__.py
================================================
from . import predictor
from . import metrics
from .predictor import PredictorModule
================================================
FILE: graphium/trainer/losses.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Optional
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _WeightedLoss
class HybridCELoss(_WeightedLoss):
    def __init__(
        self,
        n_brackets,
        regression_loss: str = "mse",
        alpha: float = 0.5,
        weight: Optional[Tensor] = None,
        reduction: str = "mean",
    ) -> None:
        """
        A hybrid between the regression loss (either MAE or MSE) and the cross entropy loss. Intended
        to be used with noisy regression datasets, for which the targets are assigned to binary brackets,
        and the task is transformed into a multi-class classification.

        Note that it assumes that the brackets are consecutive integers starting at 0 up to n_brackets,
        which has an impact on the scale of the regression loss component.

        Parameters:
            n_brackets: the number of brackets that will be used to group the regression targets.
                Expected to have the same size as the number of classes in the transformed regression task.
            regression_loss: type of regression loss, either 'mse' or 'mae'.
            alpha: weight assigned to the CE loss component. Must be a value in [0, 1] range.
            weight: a manual rescaling weight given to each class in the CE loss component.
                If given, has to be a Tensor of the same size as the number of classes.
            reduction: specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
                'none': no reduction will be applied, 'mean': the sum of the output will be divided
                by the number of elements in the output, 'sum': the output will be summed.

        Raises:
            ValueError: if ``regression_loss`` is not 'mae'/'mse', or ``alpha`` is outside [0, 1].
        """
        super().__init__(weight=weight, reduction=reduction)

        if regression_loss == "mae":
            self.regression_loss = F.l1_loss
        elif regression_loss == "mse":
            self.regression_loss = F.mse_loss
        else:
            raise ValueError(
                f"Expected regression_loss to be in {{'mae', 'mse'}}, received {regression_loss}."
            )

        if alpha < 0 or alpha > 1:
            raise ValueError(f"Expected alpha to be in the [0, 1] range, received {alpha}.")

        # Bracket indices [0, 1, ..., n_brackets - 1], used to turn class
        # probabilities into a scalar estimate (the expected bracket).
        self.brackets = Tensor(range(n_brackets))
        self.alpha = alpha
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input: Tensor, target: Tensor, nan_targets: Tensor = None) -> Tensor:
        """
        Parameters:
            input: (batch_size x n_classes) tensor of logits predicted for each bracket.
            target: (batch_size) or (batch_size, 1) tensor of target brackets in {0, 1, ..., self.n_brackets}.
            nan_targets: optional (batch_size) boolean mask flagging entries whose target is
                missing. Masked entries are excluded from both loss components.

        Returns:
            The hybrid loss ``alpha * ce_loss + (1 - alpha) * regression_loss``
            (scalar for 'mean'/'sum' reduction).
        """
        target = target.flatten()

        # set input and target with nans to 0s for regression loss
        if nan_targets is not None:
            input = torch.masked_fill(input, nan_targets.unsqueeze(1), 0)
            target = torch.masked_fill(target, nan_targets, 0)

        # regression loss needs normalized logits to probability as input to do inner product with self.brackets
        # we apply softmax on the raw logits first
        softmax_input = self.softmax(input)

        # the softmax of a tensor of 0s would not be 0s anymore, so need to apply nan_targets here again to filter out
        if nan_targets is not None:
            softmax_input = torch.masked_fill(softmax_input, nan_targets.unsqueeze(1), 0.0)

        # [batch_size, n_classes] * [n_classes] ([0, 1, 2...n_brackets-1]) -> [batch_size]
        regression_input = torch.inner(softmax_input.to(self.brackets.dtype), self.brackets.to(input.device))
        regression_loss = self.regression_loss(regression_input, target.float(), reduction=self.reduction)

        if nan_targets is not None:
            # Rescale by total/valid so the reduced loss averages over real targets only.
            # factor1/factor2 make the expression safe when every target is NaN
            # (loss becomes 0 instead of dividing by zero).
            num_real_targets = (~nan_targets).sum()
            factor1 = torch.where(num_real_targets > 0, 1, 0)
            factor2 = torch.where(num_real_targets > 0, 0, 1)
            regression_loss = factor1 * regression_loss * nan_targets.numel() / (num_real_targets + factor2)

            # set input and target with nans to -1000s for ce loss.
            # Guarded by the `nan_targets is not None` check so that calls
            # without a mask do not dereference `None`.
            input = torch.masked_fill(input, nan_targets.unsqueeze(1), -1000)
            target = torch.masked_fill(target, nan_targets, -1000)

        # cross_entropy loss needs raw logits as input
        # ce_loss does not need scaling as it already ignores -1000 masked nan values
        ce_loss = F.cross_entropy(
            input, target.long(), weight=self.weight, ignore_index=-1000, reduction=self.reduction
        )

        return self.alpha * ce_loss + (1 - self.alpha) * regression_loss
================================================
FILE: graphium/trainer/metrics.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union, Callable, Optional, Dict, Any
import sys
import torch
from torch import Tensor
import operator as op
from torchmetrics.utilities.distributed import reduce
import torchmetrics.functional.regression.mae
from graphium.utils.tensor import nan_mean
# NOTE(hadim): the below is a fix to be able to import previously saved Graphium model that are incompatible
# with the current version of torchmetrics.
# In the future, we should NOT save any torchmetrics objects during serialization.
# See https://github.com/valence-discovery/graphium/issues/106
sys.modules["torchmetrics.functional.regression.mean_absolute_error"] = torchmetrics.functional.regression.mae
EPS = 1e-5
class Thresholder:
    def __init__(
        self,
        threshold: float,
        operator: Union[str, Callable] = "greater",
        th_on_preds: bool = True,
        th_on_target: bool = False,
    ):
        """
        Applies a threshold to predictions and/or targets, turning continuous
        values into booleans (e.g. so classification metrics can be used on
        regression outputs).

        Parameters:
            threshold: The threshold value to compare against.
            operator: The comparison to apply. Either a string
                ("greater"/"gt" or "lower"/"lt") or a callable taking
                ``(tensor, threshold)``.
            th_on_preds: Whether to threshold the predictions.
            th_on_target: Whether to threshold the targets.
        """
        # Basic params
        self.threshold = threshold
        self.th_on_target = th_on_target
        self.th_on_preds = th_on_preds
        self.operator, self.op_str = self._get_operator(operator)

    def compute(self, preds: Tensor, target: Tensor):
        """Apply the configured comparison and return the ``(preds, target)`` pair."""
        # Apply the threshold on the predictions
        if self.th_on_preds:
            preds = self.operator(preds, self.threshold)

        # Apply the threshold on the targets
        if self.th_on_target:
            target = self.operator(target, self.threshold)

        return preds, target

    def __call__(self, preds: Tensor, target: Tensor):
        return self.compute(preds, target)

    def __repr__(self):
        r"""
        Control how the class is printed
        """
        return f"{self.op_str}{self.threshold}"

    @staticmethod
    def _get_operator(operator):
        """Operator can either be a string, a callable, or None.

        Returns:
            (operator, op_str): the resolved comparison (callable or None)
            and its display string.
        """
        if isinstance(operator, str):
            op_name = operator.lower()
            if op_name in ["greater", "gt"]:
                op_str = ">"
                operator = op.gt
            elif op_name in ["lower", "lt"]:
                op_str = "<"
                operator = op.lt
            else:
                raise ValueError(f"operator `{op_name}` not supported")
        elif callable(operator):
            op_str = operator.__name__
            if op_str == "lt":
                op_str = "<"
            elif op_str == "gt":
                op_str = ">"
        elif operator is None:
            # Bug fix: `op_str` was previously left undefined on this branch,
            # raising UnboundLocalError at the return below.
            op_str = "None"
        else:
            raise TypeError(f"operator must be either `str` or `callable`, provided: `{type(operator)}`")

        return operator, op_str

    def __getstate__(self):
        """Serialize the class for pickling."""
        state = {}
        state["threshold"] = self.threshold
        state["th_on_target"] = self.th_on_target
        state["th_on_preds"] = self.th_on_preds

        # Set the operator state.
        # If it's a callable, it's up to the user to ensure it unserializes well
        operator = self.operator
        if operator == op.lt:
            operator = "lower"
        elif operator == op.gt:
            operator = "greater"
        state["operator"] = operator

        return state

    def __setstate__(self, state: dict):
        """Reload the class from pickling.

        Bug fix: the previous implementation injected the derived ``op_str``
        key into ``state`` and then called ``self.__init__(**state)``, which
        raised ``TypeError`` since ``__init__`` has no ``op_str`` parameter.
        ``__init__`` recomputes ``op_str`` via ``_get_operator``, so it is
        enough to re-initialize from the constructor arguments only.
        """
        state = dict(state)
        state.pop("op_str", None)  # tolerate states saved by older versions
        self.__init__(**state)

    def __eq__(self, obj) -> bool:
        """Equality on the full configuration (threshold, flags, operator)."""
        # Robustness: comparing against a non-Thresholder no longer raises
        # AttributeError; Python falls back to identity-based inequality.
        if not isinstance(obj, Thresholder):
            return NotImplemented
        is_eq = [
            self.threshold == obj.threshold,
            self.th_on_target == obj.th_on_target,
            self.th_on_preds == obj.th_on_preds,
            self.operator == obj.operator,
            self.op_str == obj.op_str,
        ]
        return all(is_eq)
class MetricWrapper:
r"""
Allows to initialize a metric from a name or Callable, and initialize the
`Thresholder` in case the metric requires a threshold.
"""
def __init__(
self,
metric: Union[str, Callable],
threshold_kwargs: Optional[Dict[str, Any]] = None,
target_nan_mask: Optional[Union[str, int]] = None,
multitask_handling: Optional[str] = None,
squeeze_targets: bool = False,
target_to_int: bool = False,
**kwargs,
):
r"""
Parameters
metric:
The metric to use. See `METRICS_DICT`
threshold_kwargs:
If `None`, no threshold is applied.
Otherwise, we use the class `Thresholder` is initialized with the
provided argument, and called before the `compute`
target_nan_mask:
- None: Do not change behaviour if there are NaNs
- int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
all NaNs will be replaced by zeros
- 'ignore': The NaN values will be removed from the tensor before computing the metrics.
Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.
multitask_handling:
- None: Do not process the tensor before passing it to the metric.
Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
Use either 'flatten' or 'mean-per-label'.
- 'flatten': Flatten the tensor to produce the equivalent of a single task
- 'mean-per-label': Loop all the labels columns, process them as a single task,
and average the results over each task
*This option might slow down the computation if there are too many labels*
squeeze_targets:
If true, targets will be squeezed prior to computing the metric.
Required in classifigression task.
target_to_int:
If true, targets will be converted to integers prior to computing the metric.
kwargs:
Other arguments to call with the metric
"""
self.metric, self.metric_name = self._get_metric(metric)
self.thresholder = None
if threshold_kwargs is not None:
self.thresholder = Thresholder(**threshold_kwargs)
self.target_nan_mask = self._parse_target_nan_mask(target_nan_mask)
self.multitask_handling = self._parse_multitask_handling(multitask_handling, self.target_nan_mask)
self.squeeze_targets = squeeze_targets
self.target_to_int = target_to_int
self.kwargs = kwargs
@staticmethod
def _parse_target_nan_mask(target_nan_mask):
"""
Parse the `target_nan_mask` parameter
"""
if (target_nan_mask is None) or isinstance(target_nan_mask, (int, float)):
# None, int, float, are accepted
pass
elif isinstance(target_nan_mask, Tensor) and (target_nan_mask.numel() == 1):
# Only single element tensors are accepted
target_nan_mask = target_nan_mask.flatten()[0]
elif isinstance(target_nan_mask, str):
# Only a few str options are accepted
target_nan_mask = target_nan_mask.lower()
accepted_str = ["ignore", "none"]
assert (
target_nan_mask in accepted_str
), f"Provided {target_nan_mask} not in accepted_str={accepted_str}"
if target_nan_mask == "none":
target_nan_mask = None
else:
raise ValueError(f"Unrecognized option `target_nan_mask={target_nan_mask}`")
return target_nan_mask
@staticmethod
def _parse_multitask_handling(multitask_handling, target_nan_mask):
"""
Parse the `multitask_handling` parameter
"""
if multitask_handling is None:
# None is accepted
pass
elif isinstance(multitask_handling, str):
# Only a few str options are accepted
multitask_handling = multitask_handling.lower()
accepted_str = ["flatten", "mean-per-label", "none"]
assert (
multitask_handling in accepted_str
), f"Provided {multitask_handling} not in accepted_str={accepted_str}"
if multitask_handling == "none":
multitask_handling = None
else:
raise ValueError(f"Unrecognized option `multitask_handling={multitask_handling}`")
if (target_nan_mask == "ignore") and (multitask_handling is None):
raise ValueError(
"Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'"
)
return multitask_handling
@staticmethod
def _get_metric(metric):
from graphium.utils.spaces import METRICS_DICT
if isinstance(metric, str):
metric_name = metric
metric = METRICS_DICT[metric]
else:
metric_name = None
metric = metric
return metric, metric_name
    def compute(self, preds: Tensor, target: Tensor) -> Tensor:
        r"""
        Compute the metric, apply the thresholder if provided, and manage the NaNs

        Parameters:
            preds: predictions tensor; 1D inputs are promoted to 2D (n, 1).
            target: targets tensor; 1D inputs are promoted to 2D (n, 1).

        Returns:
            The metric value; for ``multitask_handling == "mean-per-label"``,
            the NaN-mean of the per-label metric values.
        """
        if preds.ndim == 1:
            preds = preds.unsqueeze(-1)

        if target.ndim == 1:
            target = target.unsqueeze(-1)

        # Threshold the prediction
        if self.thresholder is not None:
            preds, target = self.thresholder(preds, target)

        target_nans = torch.isnan(target)

        # for the classifigression task, cast predictions from
        # (batch_size, n_targets * n_brackets) to (batch_size, n_targets, n_brackets)
        # TODO: make this more flexible to the target shape in the future
        if preds.shape[1] != target.shape[1]:
            preds = preds.view(target.shape[0], target.shape[1], -1)
            classifigression = True
        else:
            classifigression = False

        if self.multitask_handling is None:
            # In case of no multi-task handling, apply the nan filtering, then compute the metrics
            assert (
                self.target_nan_mask != "ignore"
            ), f"Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`. Use either 'flatten' or 'mean-per-label'"
            preds, target = self._filter_nans(preds, target)
            if self.squeeze_targets:
                target = target.squeeze()
            if self.target_to_int:
                target = target.to(int)
            metric_val = self.metric(preds, target, **self.kwargs)
        elif self.multitask_handling == "flatten":
            # Flatten the tensors, apply the nan filtering, then compute the metrics
            if classifigression:
                # Keep the trailing bracket dim: only the (batch, task) dims are flattened
                preds = preds.view(-1, preds.shape[-1])
                target = target.flatten()
            else:
                preds, target = preds.flatten(), target.flatten()
            preds, target = self._filter_nans(preds, target)
            if self.squeeze_targets:
                target = target.squeeze()
            if self.target_to_int:
                target = target.to(int)
            metric_val = self.metric(preds, target, **self.kwargs)
        elif self.multitask_handling == "mean-per-label":
            # Loop the columns (last dim) of the tensors, apply the nan filtering, compute the metrics per column, then average the metrics
            target_list = [target[..., ii][~target_nans[..., ii]] for ii in range(target.shape[-1])]
            # TODO: make this more flexible to the target shape in the future
            if classifigression:
                preds_list = [preds[..., i, :][~target_nans[..., i]] for i in range(preds.shape[1])]
            else:
                preds_list = [preds[..., ii][~target_nans[..., ii]] for ii in range(preds.shape[-1])]
            metric_val = []
            for ii in range(len(target_list)):
                try:
                    this_preds, this_target = self._filter_nans(preds_list[ii], target_list[ii])
                    if self.squeeze_targets:
                        this_target = this_target.squeeze()
                    if self.target_to_int:
                        this_target = this_target.to(int)
                    metric_val.append(self.metric(this_preds, this_target, **self.kwargs))
                except:
                    # Deliberate best-effort: a label column may have no valid
                    # targets in this batch, or the metric may be undefined for
                    # it; skip that column rather than failing the whole batch.
                    pass

            # Average the metric
            metric_val = nan_mean(torch.stack(metric_val))
        else:
            # Wrong option
            raise ValueError(f"Invalid option `self.multitask_handling={self.multitask_handling}`")

        return metric_val
def _filter_nans(self, preds: Tensor, target: Tensor):
"""Handle the NaNs according to the chosen options"""
target_nans = torch.isnan(target)
if self.target_nan_mask is None:
pass
elif isinstance(self.target_nan_mask, (int, float)):
target = target.clone()
target[torch.isnan(target)] = self.target_nan_mask
elif self.target_nan_mask == "ignore":
target = target[~target_nans]
preds = preds[~target_nans]
else:
raise ValueError(f"Invalid option `{self.target_nan_mask}`")
return preds, target
def __call__(self, preds: Tensor, target: Tensor) -> Tensor:
r"""
Compute the metric with the method `self.compute`
"""
return self.compute(preds, target)
def __repr__(self):
r"""
Control how the class is printed
"""
full_str = f"{self.metric.__name__}"
if self.thresholder is not None:
full_str += f"({self.thresholder})"
return full_str
def __eq__(self, obj) -> bool:
is_eq = [
self.metric == obj.metric,
self.metric_name == obj.metric_name,
self.thresholder == obj.thresholder,
self.target_nan_mask == obj.target_nan_mask,
self.multitask_handling == obj.multitask_handling,
self.squeeze_targets == obj.squeeze_targets,
self.target_to_int == obj.target_to_int,
self.kwargs == obj.kwargs,
]
return all(is_eq)
def __getstate__(self):
"""Serialize the class for pickling."""
state = {}
if self.metric_name is None:
state["metric"] = self.metric
else:
state["metric"] = self.metric_name
state["target_nan_mask"] = self.target_nan_mask
state["multitask_handling"] = self.multitask_handling
state["squeeze_targets"] = self.squeeze_targets
state["target_to_int"] = self.target_to_int
state["kwargs"] = self.kwargs
state["threshold_kwargs"] = None
if self.thresholder is not None:
state["threshold_kwargs"] = self.thresholder.__getstate__()
return state
def __setstate__(self, state: dict):
"""Reload the class from pickling."""
state["metric"], state["metric_name"] = self._get_metric(state["metric"])
thresholder = state.pop("threshold_kwargs", None)
if thresholder is not None:
thresholder = Thresholder(**thresholder)
state["thresholder"] = thresholder
self.__dict__.update(state)
================================================
FILE: graphium/trainer/predictor.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import time
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import lightning
import numpy as np
import torch
from loguru import logger
from mup.optim import MuAdam
from torch import Tensor, nn
from torch_geometric.data import Batch, Data
from graphium.config.config_convert import recursive_config_reformating
from graphium.data.datamodule import BaseDataModule
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor_options import (
EvalOptions,
FlagOptions,
ModelOptions,
OptimOptions,
)
from graphium.trainer.predictor_summaries import TaskSummaries
from graphium.utils import fs
from graphium.utils.moving_average_tracker import MovingAverageTracker
from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT
from graphium.utils.tensor import dict_tensor_fp16_to_fp32
class PredictorModule(lightning.LightningModule):
    def __init__(
        self,
        model_class: Type[nn.Module],
        model_kwargs: Dict[str, Any],
        loss_fun: Dict[str, Union[str, Callable]],
        task_levels: Dict[str, str],
        random_seed: int = 42,
        featurization: Dict[str, str] = None,
        optim_kwargs: Optional[Dict[str, Any]] = None,
        torch_scheduler_kwargs: Optional[Dict[str, Any]] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        target_nan_mask: Optional[Union[str, int]] = None,
        multitask_handling: Optional[str] = None,
        metrics: Dict[str, Callable] = None,
        metrics_on_progress_bar: Dict[str, List[str]] = [],
        metrics_on_training_set: Optional[Dict[str, List[str]]] = None,
        flag_kwargs: Dict[str, Any] = None,
        task_norms: Optional[Dict[Callable, Any]] = None,
        metrics_every_n_train_steps: Optional[int] = None,
        replicas: int = 1,
        gradient_acc: int = 1,
        global_bs: Optional[int] = 1,
    ):
        """
        The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc.
        It works in a multi-task setting, with different losses and metrics per class
        Parameters:
            model_class: The torch Module containing the main forward function
            model_kwargs: The arguments to initialize the model from `model_class`
            loss_fun: A `dict[str, fun]`, where `str` is the task name and `fun` the loss function
            task_levels: A `dict[str, str]` mapping each task name to its level (e.g. "graph");
                used to prefix the task keys via `self._get_task_key`
            random_seed: The seed for random initialization
            featurization: Featurization arguments; stored on the module as-is
            optim_kwargs: The optimization arguments. See class `OptimOptions`
            torch_scheduler_kwargs: The torch scheduler arguments. See class `OptimOptions`
            scheduler_kwargs: The lightning scheduler arguments. See class `OptimOptions`
            target_nan_mask: How to handle the NaNs. See `MetricsWrapper` for options
            multitask_handling: How to aggregate metrics across tasks. See `MetricWrapper` for options
            metrics: A `dict[str, fun]`, where `str` is the task name and `fun` the metric function
            metrics_on_progress_bar: A `dict[str, list[str2]`, where `str` is the task name and `str2` the metrics to include on the progress bar
            metrics_on_training_set: A `dict[str, list[str2]`, where `str` is the task name and `str2` the metrics to include on the training set
            flag_kwargs: Arguments related to using the FLAG adversarial augmentation
            task_norms: the normalization for each task
            metrics_every_n_train_steps: Compute and log metrics every n training steps.
                Set to `None` to never log the training metrics and statistics (the default). Set to `1` to log at every step.
            replicas: Number of replicas (saved in the hyper-parameters; not read elsewhere in this module)
            gradient_acc: Gradient accumulation factor (saved in the hyper-parameters; not read elsewhere in this module)
            global_bs: Global batch size, used to increment `samples_seen` after each training batch
        """
        # NOTE(review): `save_hyperparameters()` is called before `super().__init__()`;
        # this relies on Lightning's hyper-parameter mixin tolerating that order — confirm.
        self.save_hyperparameters()
        self.random_seed = random_seed
        # Seed both torch and numpy for reproducible initialization
        torch.random.manual_seed(self.random_seed)
        np.random.seed(self.random_seed)
        self.target_nan_mask = target_nan_mask
        self.multitask_handling = multitask_handling
        self.task_levels = task_levels
        self.featurization = featurization
        self.task_norms = task_norms
        super().__init__()
        # Setting the model options
        self.model_kwargs = model_kwargs
        self._model_options = ModelOptions(model_class=model_class, model_kwargs=model_kwargs)
        # Setting the optimizer options
        self.optim_options = OptimOptions(
            optim_kwargs=optim_kwargs,
            torch_scheduler_kwargs=torch_scheduler_kwargs,
            scheduler_kwargs=scheduler_kwargs,
        )
        # Setting the evaluation options, one `EvalOptions` per task
        eval_options = {}
        for task in loss_fun:
            eval_options[task] = EvalOptions(
                loss_fun=loss_fun[task],
                metrics=metrics[task],
                metrics_on_progress_bar=metrics_on_progress_bar[task],
                metrics_on_training_set=(
                    metrics_on_training_set[task] if metrics_on_training_set is not None else None
                ),
            )
            eval_options[task].check_metrics_validity()
        self._eval_options_dict: Dict[str, EvalOptions] = eval_options
        # Re-key the eval options with the level-prefixed task names
        self._eval_options_dict = {
            self._get_task_key(task_level=task_levels[key], task=key): value
            for key, value in self._eval_options_dict.items()
        }
        # Setting the flag options
        self._flag_options = FlagOptions(flag_kwargs=flag_kwargs)
        self.model = self._model_options.model_class(**self._model_options.model_kwargs)
        # Re-key the loss functions with the level-prefixed task names
        loss_fun = {
            self._get_task_key(task_level=task_levels[key], task=key): value
            for key, value in loss_fun.items()
        }
        self.tasks = list(loss_fun.keys())
        # Task-specific evaluation attributes
        self.loss_fun = {}
        self.metrics = {}
        self.metrics_on_progress_bar = {}
        self.metrics_on_training_set = {}
        for task in self.tasks:
            self.loss_fun[task] = EvalOptions.parse_loss_fun(loss_fun[task])
            self.metrics[task] = (
                self._eval_options_dict[task].metrics
                if self._eval_options_dict[task].metrics is not None
                else {}
            )
            self.metrics_on_progress_bar[task] = self._eval_options_dict[task].metrics_on_progress_bar
            # Default to logging all metrics on the training set when none were specified
            self.metrics_on_training_set[task] = (
                list(self.metrics[task].keys())
                if self._eval_options_dict[task].metrics_on_training_set is None
                else self._eval_options_dict[task].metrics_on_training_set
            )
        # Count of trainable parameters, logged as a hyper-parameter in `on_train_start`
        self.n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Set the parameters and default values for the FLAG adversarial augmentation, and check values
        self._flag_options.set_kwargs()
        self.flag_kwargs = self._flag_options.flag_kwargs
        # Set the parameters for optimizer options
        self.optim_options.set_kwargs()
        # Initialize the epoch summary
        monitor = self.optim_options.scheduler_kwargs["monitor"].split("/")[0]
        mode = self.optim_options.scheduler_kwargs["mode"]
        self.task_epoch_summary = TaskSummaries(
            task_loss_fun=self.loss_fun,
            task_metrics=self.metrics,
            task_metrics_on_training_set=self.metrics_on_training_set,
            task_metrics_on_progress_bar=self.metrics_on_progress_bar,
            monitor=monitor,
            mode=mode,
        )
        # This helps avoid a bug when saving hparams to yaml with different dict or str formats
        self._set_hparams(recursive_config_reformating(self.hparams))
        # throughput estimation
        self.mean_val_time_tracker = MovingAverageTracker()
        self.mean_val_tput_tracker = MovingAverageTracker()
        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.epoch_start_time = None
        # Decide whether to log every step or once at the end
        # of the epoch.
        self.metrics_every_n_train_steps = metrics_every_n_train_steps
        # Whether to save preds and targets for each training step.
        self.samples_seen = 0
        self.global_bs = global_bs
    def forward(
        self, inputs: Dict
    ) -> Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]:
        r"""
        Returns the result of `self.model.forward(*inputs)` on the inputs.
        If the output of `out = self.model.forward` is a dictionary with a `"preds"` key,
        it is returned directly. Otherwise, a new dictionary is created and
        returns `{"preds": out}`.
        Returns:
            A dict with a key `"preds"` representing the prediction of the network.
        """
        # Convert to the right dtype and run the model
        feats = self._convert_features_dtype(inputs["features"])
        # *check for nan in model output
        out = self.model.forward(feats)
        if isinstance(out, dict) and ("preds" in out.keys()):
            out_dict = out
        else:
            out_dict = {"preds": out}
        return out_dict
    def _convert_features_dtype(self, feats):
        """Cast floating-point feature tensors to this module's dtype (`self.dtype`)."""
        # Convert features to dtype
        if isinstance(feats, torch.Tensor):
            feats = feats.to(self.dtype)
        elif isinstance(feats, (Data, Batch, dict)):
            # NOTE(review): tensors are replaced within the container in-place
            for key, val in feats.items():
                if isinstance(val, torch.Tensor) and (val.is_floating_point()):
                    feats[key] = val.to(dtype=self.dtype)
        return feats
    def _get_task_key(self, task_level: str, task: str):
        """Prefix `task` with `task_level` (e.g. "graph_") if not already prefixed."""
        task_prefix = f"{task_level}_"
        if not task.startswith(task_prefix):
            task = task_prefix + task
        return task
    def configure_optimizers(self, impl=None):
        """Build the MuAdam optimizer and its learning-rate scheduler for Lightning.

        Parameters:
            impl: The underlying optimizer implementation passed to `MuAdam`
                (defaults to `torch.optim.Adam`).
        """
        if impl is None:
            impl = torch.optim.Adam
        # Define the optimizer and schedulers
        optimiser = MuAdam(self.parameters(), **self.optim_options.optim_kwargs, impl=impl)
        # "module_type" selects the scheduler class and is not a scheduler argument,
        # so it must be removed before forwarding the kwargs
        self.optim_options.torch_scheduler_kwargs.pop("module_type")
        torch_scheduler = self.optim_options.scheduler_class(
            optimizer=optimiser, **self.optim_options.torch_scheduler_kwargs
        )
        scheduler = {
            "scheduler": torch_scheduler,
            **self.optim_options.scheduler_kwargs,
        }
        return [optimiser], [scheduler]
    @staticmethod
    def compute_loss(
        preds: Dict[str, Tensor],
        targets: Dict[str, Tensor],
        weights: Optional[Tensor],
        loss_fun: Dict[str, Callable],
        target_nan_mask: Optional[Union[str, int]] = None,
        multitask_handling: Optional[str] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        r"""
        Compute the loss using the specified loss function, and dealing with
        the nans in the `targets`.
        Parameters:
            preds:
                Predicted values
            targets:
                Target values
            weights:
                No longer supported, will raise an error.
            target_nan_mask:
                - None: Do not change behaviour if there are NaNs
                - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
                  all NaNs will be replaced by zeros
                - 'ignore': The NaN values will be removed from the tensor before computing the metrics.
                  Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.
            multitask_handling:
                - None: Do not process the tensor before passing it to the metric.
                  Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
                  Use either 'flatten' or 'mean-per-label'.
                - 'flatten': Flatten the tensor to produce the equivalent of a single task
                - 'mean-per-label': Loop all the labels columns, process them as a single task,
                  and average the results over each task
                  *This option might slowdown the computation if there are too many labels*
            loss_fun:
                Loss function to use
        Returns:
            Tensor:
                weighted_loss: Resulting weighted loss
                all_task_losses: Loss per task
        """
        # Wrap each loss function so it shares the NaN filtering and
        # multi-task handling logic of the metrics
        wrapped_loss_fun_dict = {
            task: MetricWrapper(
                metric=loss,
                threshold_kwargs=None,
                target_nan_mask=target_nan_mask,
                multitask_handling=multitask_handling,
            )
            for task, loss in loss_fun.items()
        }
        if weights is not None:
            raise NotImplementedError("Weights are no longer supported in the loss")
        all_task_losses = {
            task: wrapped(preds=preds[task], target=targets[task])
            for task, wrapped in wrapped_loss_fun_dict.items()
        }
        # Uniform average of the per-task losses
        total_loss = torch.sum(torch.stack(list(all_task_losses.values())), dim=0)
        num_tasks = len(all_task_losses.keys())
        weighted_loss = total_loss / num_tasks
        return weighted_loss, all_task_losses
    def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) -> Dict[str, Any]:
        r"""Common code for training_step, validation_step and testing_step"""
        preds = self.forward(batch)  # The dictionary of predictions
        # * check for nan in model output
        targets_dict = batch.get("labels")
        # Different type of preds can be return by the forward
        if isinstance(preds, dict) and ("preds" in preds.keys()):
            preds = preds["preds"]
        elif isinstance(preds, Tensor):
            preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
        preds = {
            self._get_task_key(task_level=self.task_levels[key], task=key): value
            for key, value in preds.items()
        }
        # preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
        for task, pred in preds.items():
            task_specific_norm = self.task_norms[task] if self.task_norms is not None else None
            if hasattr(task_specific_norm, "normalize_val_test"):
                normalize_val_test = task_specific_norm.normalize_val_test
            else:
                normalize_val_test = False
            if step_name != "train" and not normalize_val_test:
                # apply denormalization for val and test predictions for correct loss and metrics evaluation
                # if normalize_val_test is not true, only train loss will stay as the normalized version
                # if normalize_val_test is true, no denormalization is applied, all losses and metrics are normalized version
                preds[task] = task_specific_norm.denormalize(pred)
            targets_dict[task] = targets_dict[task].to(dtype=pred.dtype)
        weights = batch.get("weights", None)
        loss, task_losses = self.compute_loss(
            preds=preds,
            targets=targets_dict,
            weights=weights,
            loss_fun=self.loss_fun,
            target_nan_mask=self.target_nan_mask,
            multitask_handling=self.multitask_handling,
        )
        device = "cpu" if to_cpu else None
        for task in preds:
            task_specific_norm = self.task_norms[task] if self.task_norms is not None else None
            if hasattr(task_specific_norm, "normalize_val_test"):
                normalize_val_test = task_specific_norm.normalize_val_test
            else:
                normalize_val_test = False
            if step_name == "train" and not normalize_val_test:
                # apply denormalization for targets and predictions for the evaluation of training metrics (excluding loss)
                # if normalize_val_test is not true, train loss will stay as the normalized version
                # if normalize_val_test is true, no denormalization is applied, all losses and metrics are normalized version
                preds[task] = task_specific_norm.denormalize(preds[task])
                targets_dict[task] = task_specific_norm.denormalize(targets_dict[task])
            # Detach (and optionally move to CPU) so the returned tensors do not hold the graph
            preds[task] = preds[task].detach().to(device=device)
            targets_dict[task] = targets_dict[task].detach().to(device=device)
        if weights is not None:
            weights = weights.detach().to(device=device)
        step_dict = {"preds": preds, "targets": targets_dict, "weights": weights}
        # step_dict[f"{self.loss_fun._get_name()}/{step_name}"] = loss.detach().cpu() original
        # step_dict[f"weighted_loss/{step_name}"] = loss.detach().cpu()
        # step_dict[f"loss/{step_name}"] = loss.detach().cpu()
        for task in self.tasks:
            step_dict[
                self.task_epoch_summary.metric_log_name(task, self.loss_fun[task]._get_name(), step_name)
            ] = loss.detach()
        step_dict["loss"] = loss
        # print("loss ", self.global_step, self.current_epoch, loss)
        step_dict["task_losses"] = task_losses
        step_dict["gradient_norm"] = self.get_gradient_norm()
        return step_dict
    def flag_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool) -> Dict[str, Any]:
        r"""
        Perform adversarial data augmentation during one training step using FLAG.
        Paper: https://arxiv.org/abs/2010.09891
        Github: https://github.com/devnkong/FLAG
        """
        alpha, n_steps = self.flag_kwargs["alpha"], self.flag_kwargs["n_steps"]
        X = self._convert_features_dtype(batch["features"])
        X_shape = X["feat"].shape
        # Initial uniform perturbation in [-alpha, alpha], trained adversarially below
        pert = torch.FloatTensor(X_shape).uniform_(-alpha, alpha).to(device=X["feat"].device)
        pert.requires_grad = True
        # Perturb the features
        pert_batch = deepcopy(batch)
        features = pert_batch["features"]
        features["feat"] = features["feat"] + pert
        preds = self.forward(pert_batch)["preds"]
        targets = batch.pop("labels")
        for key in targets.keys():
            targets[key] = targets[key].to(dtype=preds[key].dtype)
        weights = batch.pop("weights", None)
        loss, task_losses = self.compute_loss(
            preds=preds,
            targets=targets,
            weights=weights,
            target_nan_mask=self.target_nan_mask,
            multitask_handling=self.multitask_handling,
            loss_fun=self.loss_fun,
        )
        loss = loss / n_steps
        # Iteratively augment data by applying perturbations
        # Accumulate the gradients to be applied to the weights of the network later on
        for _ in range(n_steps - 1):
            loss.backward()
            # Ascend the perturbation along the sign of its gradient (FGSM-style step)
            pert_data = pert.detach() + alpha * torch.sign(pert.grad.detach())
            pert.data = pert_data.data
            pert.grad[:] = 0
            features["feat"] = features["feat"] + pert
            pert_batch["features"] = features
            preds = self.forward(pert_batch)["preds"]
            loss, _ = self.compute_loss(
                preds=preds,
                targets=targets,
                weights=weights,
                target_nan_mask=self.target_nan_mask,
                multitask_handling=self.multitask_handling,
                loss_fun=self.loss_fun,
            )
            loss = loss / n_steps
        device = "cpu" if to_cpu else None
        for key in preds.keys():
            preds[key] = preds[key].detach().to(device=device)
            targets[key] = targets[key].detach().to(device=device)
        if weights is not None:
            weights = weights.detach().to(device=device)
        step_dict = {"preds": preds, "targets": targets, "weights": weights}
        step_dict[f"loss/{step_name}"] = loss.detach().cpu()
        step_dict["loss"] = loss
        step_dict["task_losses"] = task_losses
        return step_dict
    def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
        """Record the batch start time and decide whether metrics are logged for this step."""
        self.train_batch_start_time = time.time()
        self.skip_log_train_metrics = (self.metrics_every_n_train_steps is None) or (
            (batch_idx % self.metrics_every_n_train_steps) != 0
        )
        return super().on_train_batch_start(batch, batch_idx)
    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
        """Log per-step training metrics (loss, throughput, task losses) after each batch."""
        train_batch_time = time.time() - self.train_batch_start_time  # To be used for throughput calculation
        # Get the metrics that are logged at every step (loss, grad_norm, batch_time, batch_tput)
        concatenated_metrics_logs = {}
        concatenated_metrics_logs["train/loss"] = outputs["loss"]
        concatenated_metrics_logs["epoch_count"] = self.current_epoch
        # Increment by the batch size
        self.samples_seen += self.global_bs
        concatenated_metrics_logs["samples_seen"] = self.samples_seen
        # report the training loss for each individual tasks
        for task in self.tasks:
            concatenated_metrics_logs[f"train/loss/{task}"] = outputs["task_losses"][task]
        # get the mean loss value for individual tasks as they are a tensor of size --> gradient accumulation * replication * device_iter
        # filter zeros out for the individual losses
        for key in concatenated_metrics_logs:
            if isinstance(concatenated_metrics_logs[key], torch.Tensor):
                if concatenated_metrics_logs[key].numel() > 1:
                    concatenated_metrics_logs[key] = concatenated_metrics_logs[key][
                        concatenated_metrics_logs[key] != 0
                    ].mean()
        # If logging is skipped for this step, then log the important metrics anyway and return
        if self.skip_log_train_metrics:
            if self.logger is not None:
                self.logger.log_metrics(
                    concatenated_metrics_logs, step=self.global_step
                )  # This is a pytorch lightning function call
            return
        ### The code below is not executed if the logging is skipped for this step ###
        # Get the throughput of the batch
        num_graphs = self.get_num_graphs(batch["features"])
        tput = num_graphs / train_batch_time
        concatenated_metrics_logs["train/batch_time"] = train_batch_time
        concatenated_metrics_logs["train/batch_tput"] = tput
        # Compute all the metrics for the training set
        self.task_epoch_summary.update_predictor_state(
            step_name="train",
            targets=outputs["targets"],
            preds=outputs["preds"],
            loss=outputs["loss"],  # This is the weighted loss for now, but change to task-specific loss
            task_losses=outputs["task_losses"],
            n_epochs=self.current_epoch,
        )
        metrics_logs = self.task_epoch_summary.get_metrics_logs()  # Dict[task, metric_logs]
        metrics_logs["_global"]["grad_norm"] = self.get_gradient_norm()
        concatenated_metrics_logs.update(metrics_logs)
        # Log the metrics
        if self.logger is not None:
            self.logger.log_metrics(
                concatenated_metrics_logs, step=self.global_step
            )  # This is a pytorch lightning function call
    def training_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]:
        """One training step; uses FLAG augmentation when `flag_kwargs["n_steps"] > 0`."""
        step_dict = None
        # Train using FLAG
        if self.flag_kwargs["n_steps"] > 0:
            step_dict = self.flag_step(batch=batch, step_name="train", to_cpu=to_cpu)
        # Train normally, without using FLAG
        elif self.flag_kwargs["n_steps"] == 0:
            # step_dict = self._general_step(batch=batch, step_name="train", to_cpu=True)
            step_dict = self._general_step(batch=batch, step_name="train", to_cpu=to_cpu)
        # Remove the preds and targets if no logging is required
        if self.skip_log_train_metrics:
            step_dict.pop("preds")
            step_dict.pop("targets")
        return step_dict  # Returning the metrics_logs with the loss
    def get_gradient_norm(self):
        """Return the global L2 norm of all parameter gradients (as a CPU tensor)."""
        # compute the norm
        total_norm = torch.tensor(0.0)
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.detach().cpu() ** 2
        total_norm = total_norm**0.5
        return total_norm
    def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]:
        """One validation step; delegates to `_general_step`."""
        return self._general_step(batch=batch, step_name="val", to_cpu=to_cpu)
    def test_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]:
        """One test step; delegates to `_general_step`."""
        return self._general_step(batch=batch, step_name="test", to_cpu=to_cpu)
    def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, device: str) -> None:
        r"""Common code for training_epoch_end, validation_epoch_end and testing_epoch_end"""
        # Transform the list of dict of dict, into a dict of list of dict
        preds = {}
        targets = {}
        for task in self.tasks:
            preds[task] = torch.cat([out["preds"][task].to(device) for out in outputs], dim=0)
            targets[task] = torch.cat([out["targets"][task].to(device) for out in outputs], dim=0)
        if ("weights" in outputs[0].keys()) and (outputs[0]["weights"] is not None):
            weights = torch.cat([out["weights"].to(device) for out in outputs], dim=0)
        else:
            weights = None
        # NOTE: Computing the loss over the entire split may cause
        # overflow issues when using fp16
        loss, task_losses = self.compute_loss(
            preds=dict_tensor_fp16_to_fp32(preds),
            targets=dict_tensor_fp16_to_fp32(targets),
            weights=weights,
            target_nan_mask=self.target_nan_mask,
            multitask_handling=self.multitask_handling,
            loss_fun=self.loss_fun,
        )
        self.task_epoch_summary.update_predictor_state(
            step_name=step_name,
            preds=preds,
            targets=targets,
            loss=loss,
            task_losses=task_losses,
            n_epochs=self.current_epoch,
        )
        metrics_logs = self.task_epoch_summary.get_metrics_logs()
        self.task_epoch_summary.set_results(task_metrics=metrics_logs)
        return metrics_logs  # Consider returning concatenated dict for logging
    def on_train_epoch_start(self) -> None:
        """Start the epoch timer used by `on_train_epoch_end`."""
        self.epoch_start_time = time.time()
    def on_train_epoch_end(self) -> None:
        """Log the wall-clock duration of the finished training epoch."""
        if self.epoch_start_time is None:
            logger.warning("epoch timer not initialized")
        else:
            epoch_time = time.time() - self.epoch_start_time
            self.epoch_start_time = None
            self.log("epoch_time", torch.tensor(epoch_time), sync_dist=True)
    def on_validation_epoch_start(self) -> None:
        """Reset the validation time/throughput trackers."""
        self.mean_val_time_tracker.reset()
        self.mean_val_tput_tracker.reset()
        return super().on_validation_epoch_start()
    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Record the validation batch start time for throughput estimation."""
        self.validation_batch_start_time = time.time()
        return super().on_validation_batch_start(batch, batch_idx, dataloader_idx)
    def on_validation_batch_end(
        self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Collect the step outputs and update the running time/throughput averages."""
        val_batch_time = time.time() - self.validation_batch_start_time
        self.validation_step_outputs.append(outputs)
        self.mean_val_time_tracker.update(val_batch_time)
        num_graphs = self.get_num_graphs(batch["features"])
        self.mean_val_tput_tracker.update(num_graphs / val_batch_time)
        return super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
    def on_validation_epoch_end(self) -> None:
        """Aggregate the collected validation outputs, compute and log the epoch metrics."""
        metrics_logs = self._general_epoch_end(
            outputs=self.validation_step_outputs, step_name="val", device="cpu"
        )
        self.validation_step_outputs.clear()
        concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
        concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value)
        concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value
        self.log_dict(concatenated_metrics_logs, sync_dist=True)
        # Save yaml file with the per-task metrics summaries
        # NOTE(review): `full_dict` is built but never written anywhere — dead code or
        # a missing save step; confirm intent.
        full_dict = {}
        full_dict.update(self.task_epoch_summary.get_dict_summary())
    def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Collect the test step outputs for aggregation at epoch end."""
        self.test_step_outputs.append(outputs)
    def on_test_epoch_end(self) -> None:
        """Aggregate the collected test outputs, compute and log the epoch metrics."""
        metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test", device="cpu")
        self.test_step_outputs.clear()
        concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
        self.log_dict(concatenated_metrics_logs, sync_dist=True)
        # Save yaml file with the per-task metrics summaries
        # NOTE(review): `full_dict` is built but never written anywhere — dead code or
        # a missing save step; confirm intent.
        full_dict = {}
        full_dict.update(self.task_epoch_summary.get_dict_summary())
    def on_train_start(self):
        """Log the hyper-parameters (plus the trainable parameter count) at training start."""
        hparams_log = deepcopy(self.hparams)
        hparams_log["n_params"] = self.n_params
        if self.logger is not None:
            self.logger.log_hyperparams(hparams_log)
    def get_progress_bar_dict(self) -> Dict[str, float]:
        """Build the dict of values shown on the progress bar (weighted loss + per-task val losses)."""
        prog_dict = {}
        prog_dict["loss"] = self.task_epoch_summary.weighted_loss.detach().cpu()
        results_on_progress_bar = self.task_epoch_summary.get_results_on_progress_bar("val")
        for task in self.tasks:
            prog_dict[self.task_epoch_summary.metric_log_name(task, "loss", "val")] = (
                self.task_epoch_summary.task_summaries[task].summaries["val"].loss
            )
        prog_dict.update(results_on_progress_bar)
        return prog_dict
    def __repr__(self) -> str:
        r"""
        Controls how the class is printed
        """
        model_str = self.model.__repr__()
        summary_str = self.summarize().__repr__()
        return model_str + "\n\n" + summary_str
    @staticmethod
    def list_pretrained_models():
        """List available pretrained models."""
        return GRAPHIUM_PRETRAINED_MODELS_DICT
    @staticmethod
    def load_pretrained_model(name_or_path: str, device: str = None):
        """Load a pretrained model from its name.
        Args:
            name_or_path: Name of the model to load or a valid checkpoint path. List available
                from `graphium.trainer.PredictorModule.list_pretrained_models()`.
            device: The `map_location` passed to `load_from_checkpoint`.
        """
        name = GRAPHIUM_PRETRAINED_MODELS_DICT.get(name_or_path)
        if name is not None:
            return PredictorModule.load_from_checkpoint(
                GRAPHIUM_PRETRAINED_MODELS_DICT[name_or_path], map_location=device
            )
        # Not a known model name: only accept an existing `.ckpt` file path
        if name is None and not (fs.exists(name_or_path) and fs.get_extension(name_or_path) == "ckpt"):
            raise ValueError(
                f"The model '{name_or_path}' is not available. Choose from {set(GRAPHIUM_PRETRAINED_MODELS_DICT.keys())} "
                "or pass a valid checkpoint (.ckpt) path."
            )
        return PredictorModule.load_from_checkpoint(name_or_path, map_location=device)
    def set_max_nodes_edges_per_graph(self, datamodule: BaseDataModule, stages: Optional[List[str]] = None):
        """Query the datamodule's max node/edge counts and configure the model with them."""
        datamodule.setup()
        max_nodes = datamodule.get_max_num_nodes_datamodule(stages)
        max_edges = datamodule.get_max_num_edges_datamodule(stages)
        self.model.set_max_num_nodes_edges_per_graph(max_nodes, max_edges)
    def get_num_graphs(self, data: Batch):
        """
        Method to compute number of graphs in a Batch.
        Essential to estimate throughput in graphs/s.
        """
        # `data.batch` assigns each node its graph index, so max index + 1 is the graph count
        return torch.max(data.batch) + 1
================================================
FILE: graphium/trainer/predictor_options.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
r"""Data classes to group together related arguments for the creation of a Predictor Module."""
"""
Replace the usage of **kwargs by adding checks to make sure that everything is type-safe.
Stricter typing is important because:
- It makes finding bugs easier since incorrect types will cause an obvious error.
- Static analysis becomes easier, and the IDE can give hints about errors in the code before runtime.
Add the post-init function to do the checks immediately.
"""
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type, Union
from inspect import signature, isclass
from torch import nn
from graphium.utils.spaces import LOSS_DICT
from graphium.utils.spaces import SCHEDULER_DICT
@dataclass
class ModelOptions:
    r"""
    This data class stores the arguments necessary to instantiate a model for the Predictor.
    Parameters:
        model_class:
            pytorch module used to create a model
        model_kwargs:
            Key-word arguments used to initialize the model from `model_class`.
    """
    # The `nn.Module` subclass to instantiate
    model_class: Type[nn.Module]
    # Keyword arguments forwarded to `model_class(**model_kwargs)`
    model_kwargs: Dict[str, Any]
@dataclass
class OptimOptions:
    r"""
    This data class stores the arguments necessary to configure the optimizer for the Predictor.
    Parameters:
        optim_kwargs:
            Dictionary used to initialize the optimizer, with possible keys below.
            - lr `float`: Learning rate (Default=`1e-3`)
            - weight_decay `float`: Weight decay used to regularize the optimizer (Default=`0.`)
        torch_scheduler_kwargs:
            Dictionary for the scheduling of learning rate, with possible keys below.
            - type `str`: Type of the learning rate to use from pytorch. Examples are
              `'ReduceLROnPlateau'` (default), `'CosineAnnealingWarmRestarts'`, `'StepLR'`, etc.
            - **kwargs: Any other argument for the learning rate scheduler
        scheduler_kwargs:
            Dictionary for the scheduling of the learning rate modification used by pytorch-lightning
            - monitor `str`: metric to track (Default=`"loss/val"`)
            - interval `str`: Whether to look at iterations or epochs (Default=`"epoch"`)
            - strict `bool`: if set to True will enforce that value specified in monitor is available
              while trying to call scheduler.step(), and stop training if not found. If False will
              only give a warning and continue training (without calling the scheduler). (Default=`True`)
            - frequency `int`: **TODO: NOT REALLY SURE HOW IT WORKS!** (Default=`1`)
        scheduler_class: The class to use for the scheduler, or the str representing the scheduler.
    """
    optim_kwargs: Optional[Dict[str, Any]] = None
    torch_scheduler_kwargs: Optional[Dict[str, Any]] = None
    scheduler_kwargs: Optional[Dict[str, Any]] = None
    scheduler_class: Optional[Union[str, Type]] = None
    # Instead of passing a dictionary to be processed by the predictor,
    # this class will process the dictionary in advance and return the optimizer
    def set_kwargs(self):
        """Fill in default values for the optimizer/scheduler kwargs and resolve `scheduler_class`."""
        # Work on a local deep-copy so the pops/defaults below do not mutate
        # `self.torch_scheduler_kwargs`
        torch_scheduler_kwargs = deepcopy(self.torch_scheduler_kwargs)
        # Set the parameters and default value for the optimizer, and check values
        if self.optim_kwargs is None:
            self.optim_kwargs = {}
        self.optim_kwargs.setdefault("lr", 1e-3)
        self.optim_kwargs.setdefault("weight_decay", 0.0)
        assert self.optim_kwargs["lr"] > 0
        assert self.optim_kwargs["weight_decay"] >= 0
        # Set the lightning scheduler
        if self.scheduler_kwargs is None:
            self.scheduler_kwargs = {}
        self.scheduler_kwargs.setdefault("interval", "epoch")
        self.scheduler_kwargs.setdefault("monitor", "loss/val")
        self.scheduler_kwargs.setdefault("mode", "min")
        self.scheduler_kwargs.setdefault("frequency", 1)
        self.scheduler_kwargs.setdefault("strict", True)
        # Set the pytorch scheduler arguments
        if torch_scheduler_kwargs is None:
            torch_scheduler_kwargs = {}
        torch_scheduler_kwargs.setdefault("module_type", "ReduceLROnPlateau")
        # Get the class for the scheduler
        scheduler_class = torch_scheduler_kwargs.pop("module_type")
        if self.scheduler_class is None:
            if isinstance(scheduler_class, str):
                # Resolve the scheduler name through the registry
                self.scheduler_class = SCHEDULER_DICT[scheduler_class]
            elif isclass(scheduler_class):
                self.scheduler_class = scheduler_class
            else:
                raise TypeError("`scheduler_class` should be a str or a class")
        # Add the `monitor` and `mode` variables, but only if the scheduler accepts them
        sig = signature(self.scheduler_class.__init__)
        key_args = [p.name for p in sig.parameters.values()]
        if "monitor" in key_args:
            torch_scheduler_kwargs.setdefault("monitor", self.scheduler_kwargs["monitor"])
        if "mode" in key_args:
            torch_scheduler_kwargs.setdefault("mode", self.scheduler_kwargs["mode"])
        # NOTE(review): the local `torch_scheduler_kwargs` (with "module_type" popped and
        # "monitor"/"mode" defaults added) is never written back to
        # `self.torch_scheduler_kwargs`, so those additions are discarded; the predictor's
        # `configure_optimizers` later pops "module_type" from the *original* dict itself.
        # Confirm whether discarding the defaults here is intentional.
@dataclass
class EvalOptions:
    r"""
    This data class stores the arguments necessary to evaluate a model with the Predictor.
    Parameters:
        loss_fun:
            Loss function used during training.
            Acceptable strings are graphium.utils.spaces.LOSS_DICT.keys().
            If a dict, must contain a 'name' key with one of the acceptable loss function strings
            as a value. The rest of the dict will be used as the arguments passed to the loss object.
            Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`.
        metrics:
            A dictionary of metrics to compute on the prediction, other than the loss function.
            These metrics will be logged into WandB or other.
        metrics_on_progress_bar:
            The metrics names from `metrics` to display also on the progress bar of the training
        metrics_on_training_set:
            The metrics names from `metrics` to be computed on the training set for each iteration.
            If `None`, all the metrics are computed. Using less metrics can significantly improve
            performance, depending on the number of readouts.
    """

    loss_fun: Union[str, Dict, Callable]
    # Bug fix: annotated as Optional, since the default is None
    metrics: Optional[Dict[str, Callable]] = None
    # Bug fix: `default_factory` must be a zero-argument callable. `List[str]` only
    # worked through deprecated typing internals — `list` is the correct factory.
    metrics_on_progress_bar: List[str] = field(default_factory=list)
    metrics_on_training_set: Optional[List[str]] = None

    def check_metrics_validity(self):
        """
        Check that the metrics for the progress_bar and training_set are valid,
        i.e. that every selected metric name is a key of `self.metrics`.

        Raises:
            AssertionError: if a selected metric name is not in `self.metrics`.
        """
        if self.metrics_on_progress_bar is not None:
            choices = set(self.metrics.keys())
            selected = set(self.metrics_on_progress_bar)
            assert selected.issubset(
                choices
            ), f"Metrics {selected - choices} not in `metrics` with choices {choices}"
        if self.metrics_on_training_set is not None:
            choices = set(self.metrics.keys())
            selected = set(self.metrics_on_training_set)
            assert selected.issubset(
                choices
            ), f"Metrics {selected - choices} not in `metrics` with choices {choices}"

    @staticmethod
    def parse_loss_fun(loss_fun: Union[str, Dict, Callable]) -> Callable:
        r"""
        Parse the loss function from a string, a dict, or a callable.
        Parameters:
            loss_fun:
                A callable corresponding to the loss function, a string specifying the loss
                function from `LOSS_DICT`, or a dict containing a key 'name' specifying the
                loss function, and the rest of the dict used as the arguments for the loss.
                Accepted strings are: graphium.utils.spaces.LOSS_DICT.keys().
        Returns:
            Callable:
                Function or callable to compute the loss, takes `preds` and `targets` as inputs.
        Raises:
            ValueError: if the string/dict does not resolve to a known loss, or the
                input is neither a str, dict nor callable.
        """
        if isinstance(loss_fun, str):
            if loss_fun not in LOSS_DICT.keys():
                raise ValueError(
                    f"`loss_fun` expected to be one of the strings in {LOSS_DICT.keys()}. "
                    f"Provided: {loss_fun}."
                )
            loss_fun = LOSS_DICT[loss_fun]()
        elif isinstance(loss_fun, dict):
            if loss_fun.get("name") is None:
                raise ValueError(f"`loss_fun` expected to have a key 'name'.")
            if loss_fun["name"] not in LOSS_DICT.keys():
                raise ValueError(
                    f"`loss_fun['name']` expected to be one of the strings in {LOSS_DICT.keys()}. "
                    f"Provided: {loss_fun}."
                )
            # Copy before popping 'name' so the caller's dict is not mutated
            loss_fun = deepcopy(loss_fun)
            loss_name = loss_fun.pop("name")
            loss_fun = LOSS_DICT[loss_name](**loss_fun)
        elif not callable(loss_fun):
            raise ValueError(f"`loss_fun` must be `str`, `dict` or `callable`. Provided: {type(loss_fun)}")

        return loss_fun
@dataclass
class FlagOptions:
    r"""
    This data class stores the arguments necessary to configure FLAG for the Predictor.
    Parameters:
        flag_kwargs:
            Keyword arguments used for FLAG, an adversarial data augmentation for graph networks.
            See: https://arxiv.org/abs/2010.09891

            - n_steps: An integer that specifies the number of ascent steps when running FLAG during training.
              Default value of 0 trains GNNs without FLAG, and any value greater than 0 will use FLAG with that
              many iterations.
            - alpha: A float that specifies the ascent step size when running FLAG. Default=0.01
    """

    # Bug fix: annotated as Optional, since the default is None
    flag_kwargs: Optional[Dict[str, Any]] = None

    # Set the parameters and default values for the FLAG adversarial augmentation, and check values
    def set_kwargs(self):
        """
        Fill in the default FLAG values (`alpha=0.01`, `n_steps=0`) and validate them.

        Raises:
            AssertionError: if `n_steps` is not a non-negative int, or `alpha` is negative.
        """
        if self.flag_kwargs is None:
            self.flag_kwargs = {}
        self.flag_kwargs.setdefault("alpha", 0.01)
        self.flag_kwargs.setdefault("n_steps", 0)
        assert isinstance(self.flag_kwargs["n_steps"], int) and (self.flag_kwargs["n_steps"] >= 0)
        assert self.flag_kwargs["alpha"] >= 0
================================================
FILE: graphium/trainer/predictor_summaries.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
r"""Classes to store information about resulting evaluation metrics when using a Predictor Module."""
from typing import Any, Callable, Dict, List, Optional, Union
from loguru import logger
import numpy as np
import torch
from torch import Tensor
from graphium.utils.tensor import nan_mean, nan_std, nan_median, tensor_fp16_to_fp32
class SummaryInterface(object):
    r"""
    An interface to define the functions implemented by summary classes that implement SummaryInterface.
    """

    def set_results(self, **kwargs):
        # Store the computed metric results. Must be overridden by subclasses.
        raise NotImplementedError()

    def get_dict_summary(self):
        # Return the full summary as a dictionary. Must be overridden by subclasses.
        raise NotImplementedError()

    def update_predictor_state(self, **kwargs):
        # Refresh the internal state (targets, preds, loss, ...) from the predictor.
        # Must be overridden by subclasses.
        raise NotImplementedError()

    def get_metrics_logs(self, **kwargs):
        # Return the metrics to be logged (e.g. to WandB). Must be overridden by subclasses.
        raise NotImplementedError()
class Summary(SummaryInterface):
    r"""
    A container used by the Predictor Module to store, for a single task, the metric
    results computed on the predictions and targets at each step (e.g. "train", "val"),
    and to keep the summaries of the best epoch seen so far.
    """

    def __init__(
        self,
        loss_fun: Union[str, Callable],
        metrics: Dict[str, Callable],
        metrics_on_training_set: Optional[List[str]] = None,
        metrics_on_progress_bar: Optional[List[str]] = None,
        monitor: str = "loss",
        mode: str = "min",
        task_name: Optional[str] = None,
    ):
        r"""
        Parameters:
            loss_fun:
                Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'.
                Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`.
            metrics:
                A dictionary of metrics to compute on the prediction, other than the loss function.
                These metrics will be logged into WandB or similar.
            metrics_on_training_set:
                The metrics names from `metrics` to be computed on the training set for each iteration.
                Defaults to an empty list (no extra metric is computed on the training set).
            metrics_on_progress_bar:
                The metrics names from `metrics` to display also on the progress bar of the training.
                Defaults to an empty list.
            monitor:
                `str` metric to track (Default=`"loss/val"`)
            mode:
                whether the monitored metric is minimized (`"min"`) or maximized (`"max"`)
            task_name:
                name of the task (Default=`None`)
        """
        self.loss_fun = loss_fun
        self.metrics = metrics
        # Bug fix: avoid mutable default arguments (`[]`); `None` behaves as an empty list.
        # NOTE(review): older docs claimed `None` meant "all the metrics" for the training
        # set, but the implementation never supported that — confirm the intended contract.
        self.metrics_on_training_set = [] if metrics_on_training_set is None else metrics_on_training_set
        self.metrics_on_progress_bar = [] if metrics_on_progress_bar is None else metrics_on_progress_bar
        self.monitor = monitor
        self.mode = mode

        # Per-step results (`self.summaries[step_name]`) and best-epoch results
        self.summaries = {}
        self.best_summaries = {}

        # Current predictor state, refreshed by `update_predictor_state()`
        self.step_name: Optional[str] = None
        self.targets: Optional[Tensor] = None
        self.preds: Optional[Tensor] = None
        self.loss: Optional[Tensor] = None
        self.n_epochs: Optional[int] = None
        self.task_name = task_name
        self.logged_metrics_exceptions = []  # Track which metric exceptions have been logged

    def update_predictor_state(
        self, step_name: str, targets: Tensor, preds: Tensor, loss: Tensor, n_epochs: int
    ):
        r"""
        Update the state of the predictor.
        Parameters:
            step_name: which stage you are in, e.g. "train"
            targets: the targets tensor
            preds: the predictions tensor
            loss: the loss tensor
            n_epochs: the number of epochs
        """
        self.step_name = step_name
        self.targets = targets
        self.preds = preds
        self.loss = loss
        self.n_epochs = n_epochs

    def set_results(
        self,
        metrics: Dict[str, Tensor],
    ):
        r"""
        Store the results from the metrics for the current step, and refresh the
        best-epoch summary when the current epoch is the best one.
        [!] This function requires that self.update_predictor_state() be called before it.
        Parameters:
            metrics: a dictionary of metrics. Note: mutated in place to include the loss.
        """
        # Include the task_name in the loss for logging, and similarly for other metrics
        metrics[self.metric_log_name(self.task_name, "loss", self.step_name)] = self.loss
        self.summaries[self.step_name] = Summary.Results(
            targets=self.targets,
            preds=self.preds,
            loss=self.loss,
            metrics=metrics,  # Should include task name from get_metrics_logs()
            monitored_metric=f"{self.monitor}/{self.step_name}",  # Include task name?
            n_epochs=self.n_epochs,
        )
        if self.is_best_epoch(self.step_name, self.loss, metrics):
            self.best_summaries[self.step_name] = self.summaries[self.step_name]

    def is_best_epoch(self, step_name: str, loss: Tensor, metrics: Dict[str, Tensor]) -> bool:
        r"""
        Check whether the current epoch is the best epoch, based on the `self.mode` criteria.
        Parameters:
            step_name: which stage you are in, e.g. "train"
            loss: the loss tensor
            metrics: a dictionary of metrics. Note: mutated in place to include the loss.
        Raises:
            ValueError: if `self.mode` is neither `'min'` nor `'max'`.
        """
        # The first epoch seen for a given step is always the best one so far
        if not (step_name in self.best_summaries.keys()):
            return True

        # Include the task_name in the loss for logging, and similarly for other metrics
        metrics[self.metric_log_name(self.task_name, "loss", self.step_name)] = loss
        monitor_name = f"{self.monitor}/{step_name}"  # Include task_name?
        # NOTE(review): this compares the monitored-metric *name* against the keys of
        # `best_summaries`, which otherwise hold step names — looks like a bug carried
        # over from the original implementation; confirm the intended check.
        if not monitor_name in self.best_summaries.keys():
            return True

        if self.mode == "max":
            return metrics[monitor_name] > self.best_summaries[step_name].monitored
        elif self.mode == "min":
            return metrics[monitor_name] < self.best_summaries[step_name].monitored
        else:
            # Bug fix: the ValueError was previously created but never raised,
            # silently returning `None` for an invalid mode.
            raise ValueError(f"Mode must be 'min' or 'max', provided `{self.mode}`")

    def get_results(
        self,
        step_name: str,
    ):
        r"""
        Retrieve the results for a given step.
        Parameters:
            step_name: which stage you are in, e.g. "train"
        Returns:
            the results (`Summary.Results`) for the given step
        """
        return self.summaries[step_name]

    def get_best_results(
        self,
        step_name: str,
    ):
        r"""
        Retrieve the best results for a given step.
        Parameters:
            step_name: which stage you are in, e.g. "train"
        Returns:
            the best results (`Summary.Results`) for the given step
        """
        return self.best_summaries[step_name]

    def get_results_on_progress_bar(
        self,
        step_name: str,
    ) -> Dict[str, Tensor]:
        r"""
        Retrieve the results to be displayed on the progress bar for a given step.
        Parameters:
            step_name: which stage you are in, e.g. "train"
        Returns:
            the results to be displayed on the progress bar for the given step
        """
        results = self.summaries[step_name]
        results_prog = {
            self.metric_log_name(self.task_name, kk, step_name): results.metrics[
                self.metric_log_name(self.task_name, kk, step_name)
            ]
            for kk in self.metrics_on_progress_bar
        }
        return results_prog

    def get_dict_summary(self) -> Dict[str, Any]:
        r"""
        Retrieve the full summary in a dictionary.
        Returns:
            the full summary in a dictionary, with keys "metric_summaries"
            and "best_epoch_metric_summaries"
        """
        full_dict = {}
        # Get metric summaries
        full_dict["metric_summaries"] = {}
        for key, val in self.summaries.items():
            full_dict["metric_summaries"][key] = {k: v for k, v in val.metrics.items()}
            full_dict["metric_summaries"][key]["n_epochs"] = val.n_epochs

        # Get metric summaries at best epoch
        full_dict["best_epoch_metric_summaries"] = {}
        for key, val in self.best_summaries.items():
            full_dict["best_epoch_metric_summaries"][key] = val.metrics
            full_dict["best_epoch_metric_summaries"][key]["n_epochs"] = val.n_epochs
        return full_dict

    def get_metrics_logs(self) -> Dict[str, Any]:
        r"""
        Get the data about metrics to log.
        Note: This function requires that self.update_predictor_state() be called before it.
        Returns:
            A dictionary of metrics to log, all moved to CPU.
        """
        targets = tensor_fp16_to_fp32(self.targets)
        preds = tensor_fp16_to_fp32(self.preds)
        targets = targets.to(dtype=preds.dtype, device=preds.device)

        # Compute the metrics always used in regression tasks
        metric_logs = {}
        metric_logs[self.metric_log_name(self.task_name, "mean_pred", self.step_name)] = nan_mean(preds)
        metric_logs[self.metric_log_name(self.task_name, "std_pred", self.step_name)] = nan_std(preds)
        metric_logs[self.metric_log_name(self.task_name, "median_pred", self.step_name)] = nan_median(preds)
        metric_logs[self.metric_log_name(self.task_name, "mean_target", self.step_name)] = nan_mean(targets)
        metric_logs[self.metric_log_name(self.task_name, "std_target", self.step_name)] = nan_std(targets)
        metric_logs[self.metric_log_name(self.task_name, "median_target", self.step_name)] = nan_median(
            targets
        )

        # Specify which metrics to use: on the training set, only the selected subset
        metrics_to_use = self.metrics
        if self.step_name == "train":
            metrics_to_use = {
                key: metric for key, metric in metrics_to_use.items() if key in self.metrics_on_training_set
            }

        # Compute the additional metrics; a failing metric yields NaN instead of crashing
        for key, metric in metrics_to_use.items():
            metric_name = self.metric_log_name(self.task_name, key, self.step_name)
            try:
                metric_logs[metric_name] = metric(preds, targets)
            except Exception as e:
                metric_logs[metric_name] = torch.as_tensor(float("nan"))
                # Warn only if it's the first warning for that metric
                if metric_name not in self.logged_metrics_exceptions:
                    self.logged_metrics_exceptions.append(metric_name)
                    logger.warning(f"Error for metric {metric_name}. NaN is returned. Exception: {e}")

        # Log the loss under the name of the loss function
        metric_logs[self.metric_log_name(self.task_name, self.loss_fun._get_name(), self.step_name)] = (
            self.loss.detach().cpu()
        )

        # Convert all metrics to CPU
        metric_logs = {key: metric.detach().cpu() for key, metric in metric_logs.items()}
        return metric_logs

    def metric_log_name(self, task_name, metric_name, step_name):
        # Build the logging key: "<task>/<metric>/<step>", or "<metric>/<step>" without a task
        if task_name is None:
            return f"{metric_name}/{step_name}"
        else:
            return f"{task_name}/{metric_name}/{step_name}"

    class Results:
        def __init__(
            self,
            targets: Tensor = None,
            preds: Tensor = None,
            loss: float = None,  # NOTE(review): callers pass a Tensor; converted to float below
            metrics: dict = None,
            monitored_metric: str = None,
            n_epochs: int = None,
        ):
            r"""
            This inner class is used as a container for storing the results of the summary.
            All tensors are detached and moved to CPU so the container is safe to keep around.
            Parameters:
                targets: the targets tensor
                preds: the prediction tensor
                loss: the loss, float or tensor
                metrics: the metrics
                monitored_metric: the name of the monitored metric
                n_epochs: the number of epochs
            """
            self.targets = targets.detach().cpu()
            self.preds = preds.detach().cpu()
            self.loss = loss.item() if isinstance(loss, Tensor) else loss
            self.monitored_metric = monitored_metric
            # NOTE(review): `self.monitored` is only set when the monitored metric is
            # present — `is_best_epoch` reads it unconditionally; confirm this invariant.
            if monitored_metric in metrics.keys():
                self.monitored = metrics[monitored_metric].detach().cpu()
            self.metrics = {
                key: value.tolist() if isinstance(value, (Tensor, np.ndarray)) else value
                for key, value in metrics.items()
            }
            self.n_epochs = n_epochs
class TaskSummaries(SummaryInterface):
    r"""
    Aggregates one `Summary` per task for multi-task training, plus the global
    weighted loss, and exposes the same interface as a single `Summary`.
    """

    def __init__(
        self,
        task_loss_fun: Dict[str, Callable],
        task_metrics: Dict[str, Dict[str, Callable]],
        task_metrics_on_training_set: Dict[str, List[str]],
        task_metrics_on_progress_bar: Dict[str, List[str]],
        monitor: str = "loss",
        mode: str = "min",
    ):
        r"""
        Class to store the summaries of the tasks.
        Parameters:
            task_loss_fun: mapping from task name to the loss function for that task
            task_metrics: mapping from task name to the metrics dict for that task
            task_metrics_on_training_set: mapping from task name to the metric names to use on the training set
            task_metrics_on_progress_bar: mapping from task name to the metric names to use on the progress bar
            monitor: the metric to monitor
            mode: the mode of the metric to monitor ("min" or "max")
        """
        self.task_loss_fun = task_loss_fun
        self.task_metrics = task_metrics
        self.task_metrics_on_progress_bar = task_metrics_on_progress_bar
        self.task_metrics_on_training_set = task_metrics_on_training_set
        self.monitor = monitor
        self.mode = mode

        # One `Summary` per task, keyed by task name
        self.task_summaries: Dict[str, Summary] = {}
        self.task_best_summaries: Dict[str, Summary] = {}
        self.tasks = list(task_loss_fun.keys())

        for task in self.tasks:
            self.task_summaries[task] = Summary(
                self.task_loss_fun[task],
                self.task_metrics[task],
                self.task_metrics_on_training_set[task],
                self.task_metrics_on_progress_bar[task],
                self.monitor,
                self.mode,
                task_name=task,
            )

        # Current predictor state, refreshed by `update_predictor_state()`
        self.weighted_loss = None  # the global loss combining all tasks
        self.step_name = None

    def update_predictor_state(
        self,
        step_name: str,
        targets: Dict[str, Tensor],
        preds: Dict[str, Tensor],
        loss: Tensor,
        task_losses: Dict[str, Tensor],
        n_epochs: int,
    ):
        r"""
        Update the state for all predictors.
        Parameters:
            step_name: the name of the step, e.g. "train"
            targets: the target tensors, keyed by task
            preds: the prediction tensors, keyed by task
            loss: the global weighted loss tensor
            task_losses: the per-task losses
            n_epochs: the number of epochs
        """
        self.weighted_loss = loss
        self.step_name = step_name
        for task in self.tasks:
            # preds and losses are detached so no graph is kept alive per task
            self.task_summaries[task].update_predictor_state(
                step_name,
                targets[task],
                preds[task].detach(),
                task_losses[task].detach(),
                n_epochs,
            )

    def set_results(self, task_metrics: Dict[str, Dict[str, Tensor]]):
        """
        Set the results for all tasks.
        Parameters:
            task_metrics: the metrics for each task
        """
        for task in self.tasks:
            self.task_summaries[task].set_results(task_metrics[task])
            step_name = self.task_summaries[task].step_name
            loss = self.task_summaries[task].loss
            # NOTE(review): `Summary.set_results` already updates `best_summaries`
            # through the same `is_best_epoch` check — this repeats that update.
            if self.task_summaries[task].is_best_epoch(step_name, loss, task_metrics[task]):
                self.task_summaries[task].best_summaries[step_name] = self.task_summaries[task].summaries[
                    step_name
                ]

    def get_results(
        self,
        step_name: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Retrieve the results, keyed by task.
        Parameters:
            step_name: the name of the step, i.e. "train"
        Returns:
            the results
        """
        results = {}
        for task in self.tasks:
            results[task] = self.task_summaries[task].get_results(step_name)
        return results

    def get_best_results(
        self,
        step_name: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Retrieve the best results, keyed by task.
        Parameters:
            step_name: the name of the step, i.e. "train"
        Returns:
            the best results
        """
        results = {}
        for task in self.tasks:
            results[task] = self.task_summaries[task].get_best_results(step_name)
        return results

    def get_results_on_progress_bar(
        self,
        step_name: str,
    ) -> Dict[str, Dict[str, Any]]:
        r"""
        Return all results from all tasks for the progress bar.
        Combine the dictionaries. Instead of having keys as task names, we merge all the
        task-specific dictionaries (the per-task metric names already embed the task name).
        Parameters:
            step_name: the name of the step
        Returns:
            the results for the progress bar
        """
        task_results_prog = {}
        for task in self.tasks:
            task_results_prog.update(self.task_summaries[task].get_results_on_progress_bar(step_name))
        return task_results_prog

    def get_dict_summary(
        self,
    ) -> Dict[str, Dict[str, Any]]:
        r"""
        Get the task summaries in a dictionary, keyed by task.
        Returns:
            the task summaries
        """
        task_full_dict = {}
        for task in self.tasks:
            task_full_dict[task] = self.task_summaries[task].get_dict_summary()
        return task_full_dict

    def get_metrics_logs(
        self,
    ) -> Dict[str, Dict[str, Tensor]]:
        r"""
        Get the logs for the metrics, keyed by task, plus a "_global" entry
        holding the weighted loss.
        Returns:
            the task logs for the metrics
        """
        task_metrics_logs = {}
        for task in self.tasks:
            task_metrics_logs[task] = self.task_summaries[task].get_metrics_logs()
            # Multi-element metric tensors are reduced to the mean of their
            # non-zero entries. NOTE(review): presumably the zeros are padding
            # introduced upstream — confirm before relying on this reduction.
            for key in task_metrics_logs[task]:
                if isinstance(task_metrics_logs[task][key], torch.Tensor):
                    if task_metrics_logs[task][key].numel() > 1:
                        task_metrics_logs[task][key] = task_metrics_logs[task][key][
                            task_metrics_logs[task][key] != 0
                        ].mean()

        # Include global (weighted loss)
        task_metrics_logs["_global"] = {}
        task_metrics_logs["_global"][f"loss/{self.step_name}"] = self.weighted_loss.detach().cpu()
        return task_metrics_logs

    # TODO (Gabriela): This works to fix the logging on TB, but make it more efficient
    def concatenate_metrics_logs(
        self,
        metrics_logs: Dict[str, Dict[str, Tensor]],
    ) -> Dict[str, Tensor]:
        r"""
        Flatten the per-task metrics logs into a single dictionary.
        Parameters:
            metrics_logs: the metrics logs, keyed by task (plus "_global")
        Returns:
            the concatenated metrics logs
        """
        concatenated_metrics_logs = {}
        for task in list(self.tasks) + ["_global"]:
            concatenated_metrics_logs.update(metrics_logs[task])
        # The global weighted loss is (re-)written under "loss/<step>"
        concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().cpu()
        return concatenated_metrics_logs

    def metric_log_name(
        self,
        task_name: str,
        metric_name: str,
        step_name: str,
    ) -> str:
        r"""
        Build the logging key for a metric: "<task>/<metric>/<step>",
        or "<metric>/<step>" when no task name is given.
        Returns:
            the metric name, task name and step name
        """
        if task_name is None:
            return f"{metric_name}/{step_name}"
        else:
            return f"{task_name}/{metric_name}/{step_name}"
================================================
FILE: graphium/utils/README.md
================================================
The Graph Of Life Library.
## What is in this folder?
This folder contains utility and helper functions used across the library.
- ✅ `spaces.py`: the file to track the space of all names and dictionaries
- `mup.py`: helper function for mup operations
================================================
FILE: graphium/utils/__init__.py
================================================
from . import fs
from . import tensor
================================================
FILE: graphium/utils/arg_checker.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
""" Argument checker module """
import collections
import numpy as np
# Global variable of accepted string types: maps the lowercase string name of a
# type (as accepted by `check_arg_iterator` and friends) onto the type itself.
KNOWN_TYPES = {
    "none": None,
    "str": str,
    "list": list,
    "tuple": tuple,
    "dict": dict,
    "int": int,
    "float": float,
    "complex": complex,
    "bool": bool,
    # NOTE(review): `callable` is a builtin function, not a type — using it as an
    # `enforce_subtype` would make `isinstance(...)` raise; confirm intended usage.
    "callable": callable,
}
def _parse_type(type_to_validate, accepted_types):
# Check if the provided type is accepted
if (type_to_validate is not None) and (not isinstance(type_to_validate, accepted_types)):
raise TypeError(
"type_to_validate should be None, type or str. {} provided".format(type(type_to_validate))
)
if isinstance(type_to_validate, str):
type_to_validate = type_to_validate.lower()
if type_to_validate in KNOWN_TYPES.keys():
type_to_validate = KNOWN_TYPES[type_to_validate]
else:
raise TypeError(
"type_to_validate is not a known type. Known types are :"
" \n{}\n Provided : \n{}".format(KNOWN_TYPES.keys(), type_to_validate)
)
return type_to_validate
def _enforce_iter_type(arg, enforce_type):
# Cast the arg to be either a list or a tuple
if enforce_type is not None:
if (enforce_type == list) and (not isinstance(arg, list)):
arg = list(arg)
elif (enforce_type == tuple) and (not isinstance(arg, tuple)):
arg = tuple(arg)
elif enforce_type not in (list, tuple):
raise TypeError('enforce_type should be None, "list" or "tuple", but is {}'.format(enforce_type))
return arg
def check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cast_subtype: bool = True):
    r"""
    Verify if the type is an iterator. If it is `None`, convert to an empty list/tuple. If it is
    not a list/tuple/str, try to convert to an iterator. If it is a str or cannot be converted to
    an iterator, then put the `arg` inside an iterator.
    Possibly enforce the iterator type to `list` or `tuple`, if `enforce_type` is not None.
    Possibly enforce the subtype to any given type if `enforce_subtype` is not None,
    and decide whether to cast the subtype or to throw an error.
    Parameters:
        arg (any type):
            The input to verify/convert to an iterator (list or tuple). If None, an empty iterator
            is returned.
        enforce_type (str or type):
            The type to enforce the iterator. The valid choices are :
            `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`.
            If `None`, then the iterator type is not enforced, except that a tuple
            input yields a tuple output.
        enforce_subtype (type, np.dtype or str representing basic type):
            Verify if all the elements inside the iterator are the desired type.
            If `None`, then the sub-type is not enforced.
            Accepted strings are ['none', 'str', 'list', 'tuple', 'dict', 'int',
            'float', 'complex', 'bool', 'callable']
        cast_subtype:
            If True, then the type specified by `enforce_subtype` is used to cast the
            elements inside the iterator. If False, then an error is thrown if the
            types do not match.
    Returns:
        output (iterator):
            An iterator based on the input of the desired type (list or tuple) and
            the desired subtypes.
    """
    # If not list or tuple, put into a list
    if arg is None:
        arg = []
    elif isinstance(arg, str):
        # Strings are iterable but must be treated as a single element
        arg = [arg]
    elif isinstance(arg, tuple):
        # Remember that the input was a tuple, so the output is a tuple again,
        # but work on a list so elements can be re-assigned during casting.
        if enforce_type is None:
            enforce_type = tuple
        arg = list(arg)
    elif not isinstance(arg, (tuple, list)):
        try:
            arg = list(arg)
        except Exception:
            arg = [arg]
    output = arg

    # Make sure that enforce_type and enforce_subtype are good inputs
    enforce_type = _parse_type(enforce_type, (type, str))
    enforce_subtype = _parse_type(enforce_subtype, (type, str, np.dtype))

    # Cast all the subtypes of the list/tuple into the desired subtype
    if enforce_subtype is not None:
        if enforce_type is None:
            arg2 = output
        elif not isinstance(output, enforce_type):
            arg2 = list(output)
        else:
            arg2 = output
        try:
            for idx, a in enumerate(output):
                if not isinstance(a, enforce_subtype):
                    if cast_subtype:
                        arg2[idx] = enforce_subtype(a)
                    else:
                        raise TypeError(
                            "iter subtype is {}, desired subtype is {}, "
                            "but cast_subtype is set to False".format(type(arg2[idx]), enforce_subtype)
                        )
        except Exception as e:
            raise TypeError(
                "iterator subtype is {} and cannot be casted to {}\n{}".format(type(a), enforce_subtype, e)
            )
        # Bug fix: `_enforce_iter_type` used to be called here AND again below,
        # converting the iterator twice; keep the single unconditional call below.
        output = arg2

    output = _enforce_iter_type(output, enforce_type)
    return output
def check_list1_in_list2(list1, list2, throw_error=True):
    r"""
    Verify that every element of ``list1`` (iterator) is contained in ``list2`` (iterator).
    Parameters:
        list1, list2: list, tuple or object
            A list or tuple containing the elements to verify the inclusion.
            If an object other than a list or tuple is provided,
            it is considered as a list of a single element.
        throw_error: bool
            Whether to throw an error if list1 is not included in list2.
    Returns:
        list1_in_list2: bool
            A boolean representing the inclusion of list1 in list2. It is returned if
            throw_error is set to false.
    """
    list1 = check_arg_iterator(list1)
    list2 = check_arg_iterator(list2)

    # Collect the elements of list1 that are missing from list2
    missing = [elem for elem in list1 if elem not in list2]
    list1_in_list2 = not missing
    if throw_error and not list1_in_list2:
        raise ValueError(
            ("Elements in list1 should be contained in list2." + "\n\nlist1 = {} \n\n list2 = {}").format(
                list1, list2
            )
        )
    return list1_in_list2
def check_columns_choice(dataframe, columns_choice, extra_accepted_cols=None, enforce_type="list"):
    r"""
    Verify that the chosen columns `columns_choice` are present in the dataframe or
    in `extra_accepted_cols`. Otherwise, errors are thrown by the sub-functions.
    Parameters:
        dataframe: (pd.DataFrame)
            The dataframe on which to verify if the column choice is valid.
        columns_choice: str, iterator(str)
            The columns chosen from the dataframe
        extra_accepted_cols: str, iterator(str)
            Additional column names accepted even if absent from the dataframe.
        enforce_type: str or type
            The type to enforce the iterator. The valid choices are :
            `None`, `list`, `tuple`, `'none'`, `'list'`, `'tuple'`.
            If `None`, then the iterator type is not enforced.
    Returns:
        output: iterator
            A str iterator based on the input of the desired type (list or tuple)
    """
    if extra_accepted_cols is None:
        extra_accepted_cols = []

    # Normalize every input to the enforced iterator type, without sub-type casting
    kwargs_iterator = dict(enforce_type=enforce_type, enforce_subtype=None, cast_subtype=False)
    columns_choice = check_arg_iterator(columns_choice, **kwargs_iterator)
    extra_accepted_cols = check_arg_iterator(extra_accepted_cols, **kwargs_iterator)
    valid_columns = check_arg_iterator(list(dataframe.columns), **kwargs_iterator)

    # Raises a ValueError if some chosen column is not available
    check_list1_in_list2(columns_choice, valid_columns + extra_accepted_cols)
    return columns_choice
================================================
FILE: graphium/utils/command_line_utils.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import re
from collections import defaultdict
from typing import List, Dict
def get_anchors_and_aliases(filepath):
    """Utility function to extract anchors and aliases from a YAML config file.

    In a YAML file we can specify an anchor as `some_name: &anchor_name value`.
    This is then picked up with aliases in the rest of the config as
    `other_name: *anchor_name`.
    Using this format, all anchors and aliases in the file are extracted by
    scanning it line by line with regexes — the YAML is never actually parsed.

    Args:
        filepath (str): path to the config file
    Returns:
        anchors (Dict[str, List[str]]): A dictionary containing the dotted YAML paths of the
            anchors as keys, and a list of dotted YAML paths of aliases as values.
    """
    anchors = defaultdict(list)
    current_level = {}  # maps indentation width -> key name, for the current branch
    anchor_to_path = {}  # maps anchor name -> dotted path where it was declared
    with open(filepath, "r") as file:
        for line in file:
            # Indentation (in spaces) determines the nesting level of this line
            indent = len(line) - len(line.lstrip(" "))
            key_match = re.search(r"(\w+):", line)
            anchor_match = re.search(r"&(\w+)", line)
            alias_match = re.search(r"\*(\w+)", line)
            if key_match:
                key = key_match.group(1)
                # Compute the full path of the current key.
                full_path = ".".join(
                    [current_level[i] for i in sorted(current_level.keys()) if i < indent] + [key]
                )
                current_level[indent] = key
                # Remove any keys that are indented more than the current line.
                keys_to_remove = [i for i in current_level if i > indent]
                for i in keys_to_remove:
                    del current_level[i]
            else:
                # No key on this line: the path is that of the enclosing keys
                full_path = ".".join([current_level[i] for i in sorted(current_level.keys())])
            if anchor_match:
                anchor = anchor_match.group(1)
                anchor_to_path[anchor] = full_path
            if alias_match:
                # NOTE(review): an alias is resolved only if its anchor was declared on
                # an earlier line — forward references are silently ignored.
                alias = alias_match.group(1)
                if alias in anchor_to_path:
                    anchors[anchor_to_path[alias]].append(full_path)
    return anchors
def update_config(cfg: Dict, unknown: List, anchors: Dict) -> Dict:
    """
    Update the configuration dictionary with command line arguments.

    Args:
        cfg: The configuration dictionary, updated in place.
        unknown: Unparsed command-line arguments of the form `["--a.b.c=value", ...]`.
            Arguments not starting with `--` are ignored.
        anchors: Mapping of YAML anchor paths to the list of alias paths that
            must be updated together (as returned by `get_anchors_and_aliases`).

    Returns:
        The updated configuration dictionary (same object as `cfg`).
    """
    for arg in unknown:
        if not arg.startswith("--"):
            continue
        # Split on the first "=" only, so values may themselves contain "=".
        key, value = arg[2:].split("=", 1)
        # Copy the alias list so we never mutate the caller's `anchors` dict.
        all_refs = list(anchors[key]) if key in anchors else []
        all_refs.append(key)
        # Update the key itself plus every alias pointing at the same anchor.
        for ref in all_refs:
            keys = ref.split(".")
            temp_cfg = cfg
            for k in keys[:-1]:
                temp_cfg = temp_cfg[k]
            # Cast the string to the existing value's type when the key exists.
            temp_cfg[keys[-1]] = type(temp_cfg[keys[-1]])(value) if keys[-1] in temp_cfg else value
    return cfg
================================================
FILE: graphium/utils/custom_lr.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import warnings
from torch.optim.lr_scheduler import _LRScheduler
class WarmUpLinearLR(_LRScheduler):
    """Custom linear learning rate with warmup.

    During the first `warmup_epochs` epochs, the learning rate ramps up
    linearly from 0 towards the base learning rate. Afterwards, it decays
    linearly from the base learning rate towards `min_lr` at `max_num_epochs`.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_num_epochs (int): Maximum number of epochs.
        warmup_epochs (int): The number of epochs for learning rate warmup. Default: 0.
        min_lr (float): the minimum learning rate value. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, max_num_epochs, warmup_epochs=0, min_lr=0, last_epoch=-1, verbose=False):
        self.max_num_epochs = max_num_epochs
        self.warmup_epochs = warmup_epochs
        self.min_lr = min_lr
        self.last_epoch = last_epoch
        super(WarmUpLinearLR, self).__init__(optimizer, last_epoch, verbose)

    def _in_warmup(self):
        # Whether the next epoch index still falls inside the warmup ramp.
        return self.warmup_epochs > 0 and (self.last_epoch + 1) < self.warmup_epochs

    def _decay_factor(self):
        # Fraction of the post-warmup schedule elapsed, in [0, 1].
        # Guard against division by zero when max_num_epochs <= warmup_epochs.
        epoch_diff = self.max_num_epochs - self.warmup_epochs
        if epoch_diff <= 0:
            return 0
        return ((self.last_epoch + 1) - self.warmup_epochs) / epoch_diff

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
                UserWarning,
            )
        if self._in_warmup():
            # Linear ramp from 0 to base_lr over `warmup_epochs` epochs.
            return [(self.last_epoch + 1) * base_lr / self.warmup_epochs for base_lr in self.base_lrs]
        else:
            # Linear interpolation from base_lr down to min_lr.
            factor = self._decay_factor()
            return [factor * self.min_lr + (1 - factor) * base_lr for base_lr in self.base_lrs]

    def _get_closed_form_lr(self):
        # Same schedule as `get_lr`, expressed in closed form.
        # The division-by-zero guard in `_decay_factor` now applies here too,
        # consistently with `get_lr` (the previous version could raise
        # ZeroDivisionError when max_num_epochs == warmup_epochs).
        lr_list = []
        for base_lr in self.base_lrs:
            if self._in_warmup():
                lr = (self.last_epoch + 1) * base_lr / self.warmup_epochs
            else:
                factor = self._decay_factor()
                lr = factor * self.min_lr + (1 - factor) * base_lr
            lr_list.append(lr)
        return lr_list
================================================
FILE: graphium/utils/decorators.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
class classproperty(property):
    r"""
    Decorator used to declare a class property, defined for the class
    without needing to instantiate an object.

    !!! Example
        ``` python linenums="1"
            @classproperty
            def my_class_property(cls):
                return 5
        ```
    """

    def __get__(self, cls, owner):
        # Wrap the stored getter in a `classmethod`, bind it to the owner
        # class, and call it immediately, so that attribute access on the
        # class (or an instance) yields the computed value instead of a
        # descriptor object.
        return classmethod(self.fget).__get__(None, owner)()
================================================
FILE: graphium/utils/fs.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Union
from typing import Optional
import os
import io
import platformdirs
import pathlib
from tqdm.auto import tqdm
import fsspec
def get_cache_dir(suffix: Optional[str] = None, create: bool = True) -> pathlib.Path:
    """Get a local cache directory. You can append a suffix folder to it and optionally
    create the folder if it doesn't exist.

    Args:
        suffix: an optional subfolder appended to the cache directory.
        create: whether to create the directory (and its parents) if missing.

    Returns:
        The path of the cache directory.
    """
    # Per-user cache location resolved by `platformdirs` for the "graphium" app.
    cache_dir = pathlib.Path(platformdirs.user_cache_dir(appname="graphium"))
    if suffix is not None:
        cache_dir /= suffix
    if create:
        cache_dir.mkdir(exist_ok=True, parents=True)
    return cache_dir
def get_mapper(path: Union[str, os.PathLike]):
    """Return the `fsspec` mapper associated with a path.

    Args:
        path: a path supported by `fsspec` such as local, s3, gcs, etc.
    """
    path_str = str(path)
    return fsspec.get_mapper(path_str)
def get_basename(path: Union[str, os.PathLike]):
    """Get the basename of a file or a folder.

    Args:
        path: a path supported by `fsspec` such as local, s3, gcs, etc.
    """
    path_str = str(path)
    # Use the separator of the path's filesystem, not the local OS separator.
    sep = get_mapper(path_str).fs.sep
    trimmed = path_str.rstrip(sep)
    return trimmed.split(sep)[-1]
def get_extension(path: Union[str, os.PathLike]):
    """Get the extension of a file.

    Args:
        path: a path supported by `fsspec` such as local, s3, gcs, etc.
    """
    # The extension is whatever follows the last "." of the basename.
    return get_basename(path).split(".")[-1]
def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]):
    """Check whether a file exists.

    Args:
        path: a path supported by `fsspec` such as local, s3, gcs, etc.
    """
    if isinstance(path, fsspec.core.OpenFile):
        # An OpenFile carries its own filesystem; query it directly.
        return path.fs.exists(path.path)
    if isinstance(path, (str, pathlib.Path)):
        fs = get_mapper(str(path)).fs
        return fs.exists(path)
    # NOTE(hadim): file-like objects always exist right?
    return True
def exists_and_not_empty(path: Union[str, os.PathLike]):
    """Check whether a directory exists and is not empty."""
    if not exists(path):
        return False
    filesystem = get_mapper(path).fs
    # Non-empty means the listing contains at least one entry.
    return len(filesystem.ls(path)) > 0
def mkdir(path: Union[str, os.PathLike], exist_ok: bool = True):
    """Create directory including potential parents."""
    filesystem = get_mapper(path).fs
    filesystem.mkdirs(path, exist_ok=exist_ok)
def rm(path: Union[str, os.PathLike], recursive=False, maxdepth=None):
    """Delete a file or a directory with all nested files."""
    filesystem = get_mapper(path).fs
    filesystem.rm(path, recursive=recursive, maxdepth=maxdepth)
def join(*paths):
    """Join paths together. The first element determines the
    filesystem to use (and so the separator).

    Args:
        paths: a list of paths supported by `fsspec` such as local, s3, gcs, etc.
    """
    as_strings = [str(p) for p in paths]
    # The separator comes from the filesystem of the first path.
    separator = get_mapper(as_strings[0]).fs.sep
    return separator.join(as_strings)
def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]) -> Optional[int]:
    """Get the size of a file given its path. Return None if the
    size can't be retrieved.
    """
    if isinstance(file, io.IOBase) and hasattr(file, "name"):
        # Named file-like object: query the local filesystem by its name.
        return fsspec.filesystem("file").size(getattr(file, "name"))
    if isinstance(file, (str, pathlib.Path)):
        return get_mapper(str(file)).fs.size(str(file))
    if isinstance(file, fsspec.core.OpenFile):
        return file.fs.size(file.path)
    # Unknown object: the size cannot be retrieved.
    return None
def copy(
    source: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile],
    destination: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile],
    chunk_size: Optional[int] = None,
    force: bool = False,
    progress: bool = False,
    leave_progress: bool = True,
):
    """Copy one file to another location across different filesystem (local, S3, GCS, etc).

    Args:
        source: path or file-like object to copy from.
        destination: path or file-like object to copy to.
        chunk_size: the chunk size to use. If progress is enabled and the chunk
            size is `None`, it is set to 2048.
        force: whether to overwrite the destination file if it exists.
        progress: whether to display a progress bar.
        leave_progress: whether to keep the progress bar displayed once the copy is done.

    Raises:
        ValueError: if the source does not exist, or if the destination already
            exists and `force` is False.
    """
    if progress and chunk_size is None:
        chunk_size = 2048

    if isinstance(source, (str, os.PathLike)):
        source_file = fsspec.open(str(source), "rb")
    else:
        source_file = source

    if isinstance(destination, (str, os.PathLike)):
        # adapt the file mode of the destination depending on the source file.
        destination_mode = "wb"
        if hasattr(source_file, "mode"):
            destination_mode = "wb" if "b" in getattr(source_file, "mode") else "w"
        elif isinstance(source_file, io.BytesIO):
            destination_mode = "wb"
        elif isinstance(source_file, io.StringIO):
            destination_mode = "w"
        destination_file = fsspec.open(str(destination), destination_mode)
    else:
        destination_file = destination

    if not exists(source_file):
        raise ValueError(f"The file being copied does not exist: {source}")
    if not force and exists(destination_file):
        raise ValueError(f"The destination file to copy already exists: {destination}")

    with source_file as source_stream:
        with destination_file as destination_stream:
            if chunk_size is None:
                # copy without chunks
                destination_stream.write(source_stream.read())
            else:
                # copy with chunks
                # determine the size of the source file
                source_size = None
                if progress:
                    source_size = get_size(source)

                # init progress bar
                pbar = tqdm(
                    total=source_size,
                    leave=leave_progress,
                    disable=not progress,
                    unit="B",
                    unit_divisor=1024,
                    unit_scale=True,
                )

                # start the loop
                while True:
                    data = source_stream.read(chunk_size)
                    if not data:
                        break
                    destination_stream.write(data)
                    # Advance by the number of bytes actually read, so the
                    # progress bar does not overshoot on the final partial
                    # chunk (the previous code always advanced by chunk_size).
                    pbar.update(len(data))

                pbar.close()
================================================
FILE: graphium/utils/hashing.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import Any
import hashlib
import yaml
def get_md5_hash(object: Any) -> str:
    """
    MD5 hash of any object.
    The object is converted to a YAML string before being hashed.
    This allows for nested dictionaries/lists and for hashing of classes and their attributes.
    """
    # Serialize with sorted keys so equal objects always hash identically.
    serialized = yaml.dump(object, sort_keys=True).encode()
    return hashlib.md5(serialized).hexdigest()
================================================
FILE: graphium/utils/moving_average_tracker.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from dataclasses import dataclass
@dataclass
class MovingAverageTracker:
    # Number of values folded into the running mean so far.
    num_samples: int = 0
    # Current running mean of all observed values.
    mean_value: float = 0.0

    def update(self, value: float):
        """Fold a new value into the running mean."""
        new_count = self.num_samples + 1
        # Weighted combination of the previous mean and the new value.
        self.mean_value = self.mean_value * (self.num_samples / new_count) + value / new_count
        self.num_samples = new_count

    def reset(self):
        """Restart the tracker."""
        # no need to update mean_value, it will be reset when update is called
        self.num_samples = 0
================================================
FILE: graphium/utils/mup.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
##### Code adapted from the `mup` package from Microsoft https://github.com/microsoft/mup
from torch.nn import Linear
from torch.nn.modules.conv import _ConvNd
from mup import get_shapes, assert_hidden_size_inf, MuReadout, rescale_linear_bias, save_base_shapes
from mup.shape import _zip_infshape_dict, _extract_shapes
from graphium.nn.base_layers import MuReadoutGraphium
def apply_infshapes(model, infshapes):
    """
    Modified from the regular `mup.apply_infshapes` by explicitly adding `base_dim` to the `MuReadoutGraphium`.
    This allows the code to work on IPUs.
    """
    # Attach the matching infshape to every parameter of the model.
    for name, param in model.named_parameters():
        param.infshape = infshapes[name]
    # Cache the base width on each readout layer so it is available at runtime.
    for _, module in model.named_modules():
        if isinstance(module, MuReadoutGraphium):
            module.base_width = module.weight.infshape[-1].base_dim
def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True):
    """Sets the `p.infshape` attribute for each parameter `p` of `model`.

    Code taken from the `mup` package from Microsoft https://github.com/microsoft/mup.
    No change except in the `apply_inf_shapes`, using the one from Graphium instead of `mup`

    Inputs:
        model: nn.Module instance
        base: The base model.
            Can be nn.Module, a dict of shapes, a str, or None.
            If None, then defaults to `model`
            If str, then treated as filename for yaml encoding of a dict of base shapes.
        rescale_params:
            assuming the model is initialized using the default pytorch init (or
            He initialization etc that scale the same way with fanin): If True
            (default), rescales parameters to have the correct (μP) variances.
        delta: optional model (or shapes) combined with `base` via
            `_zip_infshape_dict` before matching against `model`'s shapes.
        savefile: if given, the computed infshapes are saved to this file
            via `save_base_shapes`.
        do_assert: if True, run `assert_hidden_size_inf` on the model after
            applying the infshapes.
    Output:
        same object as `model`, after setting the `infshape` attribute of each parameter.
    """
    if base is None:
        base = model
    base_shapes = _extract_shapes(base)
    if delta is not None:
        delta_shapes = _extract_shapes(delta)
        # Combine base and delta shapes before matching against the model.
        base_shapes = _zip_infshape_dict(base_shapes, delta_shapes)
    shapes = get_shapes(model)
    infshapes = _zip_infshape_dict(base_shapes, shapes)
    if savefile is not None:
        save_base_shapes(infshapes, savefile)
    # Use Graphium's `apply_infshapes` (not mup's) so that
    # `MuReadoutGraphium.base_width` gets set, which is required on IPUs.
    apply_infshapes(model, infshapes)
    if do_assert:
        assert_hidden_size_inf(model)
    if rescale_params:
        for name, module in model.named_modules():
            if isinstance(module, MuReadout):
                module._rescale_parameters()
            elif isinstance(module, (Linear, _ConvNd)):
                rescale_linear_bias(module)
    return model
================================================
FILE: graphium/utils/packing.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from typing import List, Tuple, Iterable, Optional
import numpy as np
import torch
class MolPack:
    """
    Class that keeps track of the number of atoms and indices that are added
    to each pack. Useful when doing packing, or other forms of smart batching.
    A pack is a batch, but with optimized memory consumption.
    """

    def __init__(self):
        # Total number of atoms accumulated in this pack.
        self.num_nodes = 0
        # Number of molecules in this pack.
        self.num_graphs = 0
        # Mean number of atoms per molecule in this pack.
        self.average_atom = 0
        # Dataset indices of the molecules in this pack.
        self.indices = []

    def add_mol(self, num_nodes: int, idx: int) -> "MolPack":
        """
        Add a molecule and its index to the pack.

        Parameters:
            num_nodes: Number of atoms of the new molecule
            idx: Index associated to the molecule

        Returns:
            self, so calls can be chained.
        """
        self.num_nodes += num_nodes
        self.num_graphs += 1
        self.average_atom = self.num_nodes / self.num_graphs
        self.indices.append(idx)
        return self

    def expected_atoms(self, remaining_mean_num_nodes: float, batch_size: int) -> float:
        """
        Given a desired batch size, and given the remaining mean number of
        atoms, find the expected number of atoms of the current batch when it is full

        Parameters:
            remaining_mean_num_nodes: Average number of atoms per molecule
                left to be sampled and distributed across tasks.
            batch_size: Desired batch size

        Returns:
            expected_atoms: The expected number of atoms in this batch if we
                sample randomly the remaining molecules.
        """
        remaining_slots = batch_size - self.num_graphs
        return self.num_nodes + remaining_slots * remaining_mean_num_nodes

    def __repr__(self) -> str:
        """
        Print the main attributes of the current class
        """
        return f"{self.__class__.__name__}(m: {self.num_graphs},\ta: {self.num_nodes},\tav: {self.average_atom:.1f})"
def smart_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]:
    """
    Simple and fast algorithm for packing graphs such that each batch has roughly the
    same number of atoms.
    Has for-loop scalability issues `O(num_graphs * ipu_batch_size)` = `O(num_graphs^2 / batch_size)`

    Parameters:
        num_nodes: List of the number of atoms per molecule for the entire global batch.
            Must be of length `batch_size * ipu_batch_size`.
        batch_size: The batch size per iteration, considering a single device and single
            forward pass.
            The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation`

    Returns:
        packed_indices: A list of packs, each containing a list of indices, such that
            if we collect `num_nodes` from the indices, then each pack has roughly the
            same total number of atoms.
    """
    # Sort the molecule sizes
    num_nodes = np.asarray(num_nodes)
    sort_order = np.argsort(num_nodes)
    sorted_sizes = num_nodes[sort_order]
    ipu_batch_size = int(len(num_nodes) / batch_size)

    # Split off the largest `ipu_batch_size` graphs; each of them seeds a pack.
    sorted_sizes, seed_sizes = sorted_sizes[:-ipu_batch_size], sorted_sizes[-ipu_batch_size:]
    # Approximate remaining mass from position ii onward.
    # NOTE(review): adds `sorted_sizes[-1]` rather than `sorted_sizes[ii]`;
    # kept as-is to preserve the original estimate.
    reverse_cumsum = np.sum(sorted_sizes) - np.cumsum(sorted_sizes) + sorted_sizes[-1]

    # Start with the largest elements, one per pack
    packs = [MolPack().add_mol(seed_sizes[-ii - 1], sort_order[-ii - 1]) for ii in range(ipu_batch_size)]

    # Loop from smallest to largest molecule; each molecule goes to the non-full
    # pack whose *expected* final total is currently the largest, which evens out
    # the final totals since the molecule being added is below the remaining mean.
    for ii, num_atom in enumerate(sorted_sizes):
        remaining_mean = reverse_cumsum[ii] / (len(sorted_sizes) - ii)
        best_expected, best_pack = 0, 0
        for jj, pack in enumerate(packs):
            if pack.num_graphs >= batch_size:
                continue
            # Inlined `pack.expected_atoms(...)` for speed.
            expected = pack.num_nodes + (batch_size - pack.num_graphs) * remaining_mean
            if expected > best_expected:
                best_expected, best_pack = expected, jj
        packs[best_pack].add_mol(num_atom, sort_order[ii])

    return [pack.indices for pack in packs]
def fast_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]:
"""
Super fast algorithm for packing graphs such that each batch has roughly the
same number of atoms. Not as good as `smart_packing` but
faster and more scalable for-loop complexity of `O(batch_size)`.
Parameters:
num_nodes: List of the number of atoms per molecule for the entire global batch.
Must be of length `batch_size * ipu_batch_size`.
batch_size: The batch size per iteration, considering a single device and single
forward pass.
The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation`
Returns:
packed_indices: A list of packs, each containing a list of indices, such that
if we collect `num_nodes` from the indices, then each pack has roughly the
same total number of atoms.
"""
num_nodes = np.asarray(num_nodes)
argsort_num_nodes = np.argsort(num_nodes)
ipu_batch_size = int(len(num_nodes) / batch_size)
packed_indices = np.stack(
[
np.random.permutation(argsort_num_nodes[ii * ipu_batch_size : (ii + 1) * ipu_batch_size])
for ii in range(batch_size)
],
axis=0,
).T.tolist()
return packed_indices
def hybrid_packing(num_nodes: List[int], batch_size: int) -> List[List[int]]:
    """
    Uses a combination of the `smart_packing` `O(n^2)` on the most important data points,
    and the `fast_packing` `O(n)` on the average-sized data points.
    Depending on the expected complexity, falls back entirely to one or the other.

    Parameters:
        num_nodes: List of the number of atoms per molecule for the entire global batch.
            Must be of length `batch_size * ipu_batch_size`.
        batch_size: The batch size per iteration, considering a single device and single
            forward pass.
            The global batch size is `batch_size * device_iterations * replication_factor * gradient_accumulation`

    Returns:
        packed_indices: A list of packs, each containing a list of indices, such that
            if we collect `num_nodes` from the indices, then each pack has roughly the
            same total number of atoms.
    """
    # Determine the parameters based on the complexity of the smart-packing.
    # The bigger the complexity, the more the `fast_packing` algorithm becomes
    # statistically powerful, and the more speed benefits it provides.
    smart_packing_complexity = len(num_nodes) ** 2 / batch_size
    if smart_packing_complexity < 1e4:
        # Cheap enough: smart-pack everything.
        return smart_packing(num_nodes=num_nodes, batch_size=batch_size)
    elif smart_packing_complexity < 1e5:
        # `big` / `small`: how many of the biggest / smallest graphs per pack
        # are handled by `smart_packing`; the rest go through `fast_packing`.
        big, small = 3, 6
    else:
        # Too expensive: fall back entirely to the linear-time algorithm.
        return fast_packing(num_nodes=num_nodes, batch_size=batch_size)
    # Small datasets benefit from smart-packing, without compute burden
    ipu_batch_size = int(len(num_nodes) / batch_size)
    if len(num_nodes) < (big + small) * ipu_batch_size:
        return smart_packing(num_nodes=num_nodes, batch_size=batch_size)
    # Sort the list
    num_nodes = np.asarray(num_nodes)
    argsort_num_nodes = np.argsort(num_nodes)
    # Smallest and biggest graphs are often outliers and will benefit from the `smart_packing`
    biggest_graphs = argsort_num_nodes[-big * ipu_batch_size :]
    smallest_graphs = argsort_num_nodes[: small * ipu_batch_size]
    big_n_small_graphs = np.concatenate([biggest_graphs, smallest_graphs])
    big_n_small_packs = smart_packing(num_nodes[big_n_small_graphs], batch_size=big + small)
    big_n_small_indices = [big_n_small_graphs[pack] for pack in big_n_small_packs]
    big_n_small_nodes = [num_nodes[pack] for pack in big_n_small_indices]
    # Medium graphs will be packed faster
    medium_graphs = argsort_num_nodes[small * ipu_batch_size : -big * ipu_batch_size]
    medium_packs = fast_packing(num_nodes[medium_graphs], batch_size=batch_size - big - small)
    medium_indices = [medium_graphs[pack] for pack in medium_packs]
    medium_nodes = [num_nodes[pack] for pack in medium_indices]
    # Pack the big/small with the medium in a smart way: sort both sets of
    # packs by total node count, then merge them.
    big_n_small_sort = np.argsort(np.sum(np.stack(big_n_small_nodes, axis=1), axis=0))
    medium_sort = np.argsort(np.sum(np.stack(medium_nodes, axis=1), axis=0))
    # NOTE(review): `big_n_small_sort[-ii]` yields index 0 when ii == 0, so the
    # lightest medium pack is merged with the *lightest* big/small pack, while
    # every other ii pairs light with heavy; `[-ii - 1]` may have been intended.
    # The indexing still forms a valid permutation, so behavior is kept as-is.
    packed_indices = [
        np.concatenate([medium_indices[medium_sort[ii]], big_n_small_indices[big_n_small_sort[-ii]]])
        for ii in range(len(medium_indices))
    ]
    return packed_indices
def get_pack_sizes(packed_indices, num_nodes):
    """
    Get the total number of atoms of each pack
    """
    # Sum the node counts of every graph index within each pack.
    return [sum(num_nodes[idx] for idx in pack) for pack in packed_indices]
def estimate_max_pack_node_size(num_nodes: Iterable[int], batch_size: int, combined_batch_size: int):
    """
    Estimate the value of `max_num_nodes`, which represents the maximum number of nodes
    needed in a batch to fit the data.

    Parameters:
        num_nodes: Number of nodes for all the graphs in the dataset
        batch_size: The regular batch size per IPU
        combined_batch_size: batch_size * device_iterations
            * replication_factor * gradient_accumulation

    Returns:
        max_pack_size: Estimated maximum number of nodes per pack.
        max_pack_size_per_graph: `max_pack_size` divided by `batch_size`.
    """
    # Accept any iterable (e.g. a plain list), since the fancy indexing below
    # requires a numpy array. Previously a plain list would raise a TypeError.
    num_nodes = np.asarray(num_nodes)

    # Shuffle the dataset and simulate the packing of random global batches.
    rand_indices = np.arange(len(num_nodes))
    np.random.shuffle(rand_indices)

    max_pack_size = 0
    for ii in range(0, len(num_nodes), combined_batch_size):
        this_indices = rand_indices[ii : ii + combined_batch_size]
        choice = num_nodes[this_indices]
        # Only full global batches are packed; the final partial batch is skipped.
        if len(choice) == combined_batch_size:
            packed_indices = hybrid_packing(choice, batch_size)
            max_pack_size = max(max_pack_size, max(get_pack_sizes(packed_indices, num_nodes[this_indices])))
    max_pack_size_per_graph = max_pack_size / batch_size
    return max_pack_size, max_pack_size_per_graph
def node_to_pack_indices_mask(
    packed_indices: Iterable[Iterable[int]], all_num_nodes: Iterable[int], max_pack_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a list of packed indices, and the number of nodes in each graph,
    return a tensor of shape (sum(all_num_nodes), 2) where the first column
    is the pack index, and the second column is the node index within the pack.

    Can be used to generate a dense packing of the nodes as follows:
    ```
    # node_features: A tensor of shape (num_nodes, num_node_features)
    # num_packs: The number of packs desired
    # max_nodes_per_pack: The maximum number of nodes per pack
    # dense_pack: A tensor of shape (num_packs, max_nodes_per_pack, num_node_features)

    dense_pack = torch.zeros([num_packs, max_nodes_per_pack, num_node_features])
    dense_pack[pack_from_node_idx[:, 0], pack_from_node_idx[:, 1]] = node_features
    ```

    This is useful when using a Transformer, to avoid wasteful padding when
    the longest sequence is much longer than the average sequence length.

    Parameters:
        packed_indices: A list of lists of graph indices, where each sub-list
            represents a pack of graphs
        all_num_nodes: The number of nodes in each graph
        max_pack_size: The maximum number of nodes per pack. If None, will be
            inferred from the provided packs.
            Useful to determine the shape of the `pack_attn_mask`.

    Returns:
        pack_from_node_idx: A tensor of shape (num_nodes, 2) where the first column
            is the pack index, and the second column is the node index within the pack.
        pack_attn_mask: A tensor of shape (num_packs, max_pack_size, max_pack_size),
            that represents the attention masking for each pack,
            such that the graphs in the pack are masked out from each other.
    """
    all_num_nodes = torch.as_tensor(all_num_nodes, dtype=torch.long)
    # cumsum gives, for each graph, one past the index of its last node in the
    # concatenated node ordering (graphs laid out consecutively).
    cumsum_num_nodes = torch.cumsum(all_num_nodes, dim=0)

    if max_pack_size is None:
        pack_sizes = get_pack_sizes(packed_indices, all_num_nodes)
        max_pack_size = max(pack_sizes)

    # Get the node indices associated to the packs, with 0 padding
    pack_from_node_idx = torch.zeros(sum(all_num_nodes), 2, dtype=torch.long)
    pack_attn_mask = []  # masks for the attention
    for ii, pack in enumerate(packed_indices):
        jj = 0  # Counter for the number of nodes in the pack
        # Start fully masked (True everywhere); unmask diagonal blocks below.
        this_pack_attn_mask = torch.ones((max_pack_size, max_pack_size), dtype=torch.bool)
        for graph_idx in pack:
            num_nodes = all_num_nodes[graph_idx]
            # Global indices of this graph's nodes in the concatenated layout.
            node_idx = torch.arange(cumsum_num_nodes[graph_idx] - num_nodes, cumsum_num_nodes[graph_idx])
            # Nodes of the same graph may attend to each other (False = unmasked).
            this_pack_attn_mask[jj : jj + num_nodes, jj : jj + num_nodes] = False
            # Column 0: which pack the node belongs to.
            pack_from_node_idx[node_idx, 0] = ii
            # Column 1: the node's position within that pack.
            pack_from_node_idx[node_idx, 1] = jj + torch.arange(num_nodes)
            jj += num_nodes
        pack_attn_mask.append(this_pack_attn_mask)
    pack_attn_mask = torch.stack(pack_attn_mask, dim=0)
    return pack_from_node_idx, pack_attn_mask
================================================
FILE: graphium/utils/safe_run.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from loguru import logger
import traceback as tb
class SafeRun:
    def __init__(self, name: str, raise_error: bool = True, verbose: int = 2) -> None:
        """
        Run some code with error handling and some printing, using the with statement.

        Example:
            In the example below, on the `2 + None`, an error will be caught and printed.
            ```
            with SafeRun(name="Addition that fails", raise_error=False):
                2 + None
            ```

        Parameters:
            name: Name of the code, used for printing
            raise_error: Whether to raise an error, or to catch it and print instead
            verbose: The level of verbosity
                0: Do not print anything
                1: Print only the traced stack when an error is caught and `raise_error` is False
                2: Print headers and footers at the start and exit of the with statement.
        """
        self.name = name
        self.raise_error = raise_error
        self.verbose = verbose

    def __enter__(self):
        """
        Print that the with-statement started, if `self.verbose >= 2`
        """
        if self.verbose >= 2:
            logger.info(f"\n------------ {self.name} STARTED ------------")

    def __exit__(self, type, value, traceback):
        """
        Handle the error. Raise it if `self.raise_error==True`, otherwise ignore it
        and print it if `self.verbose >= 1`. Also print that the with-statement is
        completed if `self.verbose >= 2`.
        """
        if traceback is not None:
            if self.raise_error:
                if self.verbose >= 1:
                    logger.error(f"------------ {self.name} ERROR: ------------")
                # Returning False re-raises the caught exception.
                return False
            else:
                if self.verbose >= 1:
                    logger.error(f"------------ {self.name} ERROR: ------------\nERROR skipped. Traceback:\n")
                    # Print the formatted traceback. The previous code wrapped
                    # this in `logger.trace(print(...))`, which logged only the
                    # `None` returned by `print`.
                    print("".join(tb.format_exception(None, value, traceback)))
                return True
        else:
            if self.verbose >= 2:
                # The f-prefix was missing here, so the literal text
                # "{self.name}" used to be logged instead of the actual name.
                logger.info(f"\n------------ {self.name} COMPLETED ------------\n\n")
            return True
================================================
FILE: graphium/utils/spaces.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
from copy import deepcopy
import torch
import torch.optim.lr_scheduler as sc
import torchmetrics.functional as TorchMetrics
import graphium.nn.base_layers as BaseLayers
import graphium.nn.ensemble_layers as EnsembleLayers
import graphium.nn.architectures as Architectures
import graphium.utils.custom_lr as CustomLR
import graphium.data.datamodule as Datamodules
import graphium.ipu.ipu_losses as IPULosses
import graphium.ipu.ipu_metrics as Metrics
import graphium.nn.pyg_layers as PygLayers
import graphium.nn.residual_connections as Residuals
import graphium.nn.encoders as Encoders
import graphium.trainer.losses as Losses
# Positional-encoding encoder classes, keyed by the name used in config files.
PE_ENCODERS_DICT = {
    "laplacian_pe": Encoders.laplace_pos_encoder.LapPENodeEncoder,
    "mlp": Encoders.mlp_encoder.MLPEncoder,
    "signnet": Encoders.signnet_pos_encoder.SignNetNodeEncoder,
    "gaussian_kernel": Encoders.gaussian_kernel_pos_encoder.GaussianKernelPosEncoder,
    "bessel_kernel": Encoders.bessel_pos_encoder.BesselSphericalPosEncoder,
}
# Fully-connected layer classes, keyed by the name used in config files.
FC_LAYERS_DICT = {
    "fc": BaseLayers.FCLayer,
}
# Ensemble fully-connected layer classes, keyed by config name.
ENSEMBLE_FC_LAYERS_DICT = {
    "ens-fc": EnsembleLayers.EnsembleFCLayer,
}
# PyG message-passing layer classes; keys carry a "pyg:" prefix to distinguish
# them from the plain fully-connected layers.
PYG_LAYERS_DICT = {
    "pyg:gcn": PygLayers.GCNConvPyg,
    "pyg:gin": PygLayers.GINConvPyg,
    "pyg:gine": PygLayers.GINEConvPyg,
    "pyg:gated-gcn": PygLayers.GatedGCNPyg,
    "pyg:pna-msgpass": PygLayers.PNAMessagePassingPyg,
    "pyg:gps": PygLayers.GPSLayerPyg,
    "pyg:dimenet": PygLayers.DimeNetPyg,
    "pyg:mpnnplus": PygLayers.MPNNPlusPyg,
}
# Union of the FC layers and the PyG layers (all selectable layer types).
LAYERS_DICT = deepcopy(FC_LAYERS_DICT)
LAYERS_DICT.update(deepcopy(PYG_LAYERS_DICT))
# Ensemble counterpart of LAYERS_DICT (currently only the ensemble FC layers).
ENSEMBLE_LAYERS_DICT = deepcopy(ENSEMBLE_FC_LAYERS_DICT)
# Residual-connection flavours from `graphium.nn.residual_connections`,
# keyed by the name used in config files.
RESIDUALS_DICT = {
    "none": Residuals.ResidualConnectionNone,
    "simple": Residuals.ResidualConnectionSimple,
    "weighted": Residuals.ResidualConnectionWeighted,
    "concat": Residuals.ResidualConnectionConcat,
    "densenet": Residuals.ResidualConnectionDenseNet,
    "random": Residuals.ResidualConnectionRandom,
}
# Loss classes, keyed by the name used in config files. Both "l1" and "mae"
# map to `torch.nn.L1Loss`; the `*_ipu` entries are the IPU-specific variants
# from `graphium.ipu.ipu_losses`.
# NOTE: a duplicate "bce" key (with the identical value) was removed; in a dict
# literal the later duplicate silently overwrote the first.
LOSS_DICT = {
    "bce": torch.nn.BCELoss,
    "bce_logits": torch.nn.BCEWithLogitsLoss,
    "mse": torch.nn.MSELoss,
    "l1": torch.nn.L1Loss,
    "mae": torch.nn.L1Loss,
    "hybrid_ce": Losses.HybridCELoss,
    "bce_ipu": IPULosses.BCELossIPU,
    "bce_logits_ipu": IPULosses.BCEWithLogitsLossIPU,
    "mse_ipu": IPULosses.MSELossIPU,
    "mae_ipu": IPULosses.L1LossIPU,
    "l1_ipu": IPULosses.L1LossIPU,
    "hybrid_ce_ipu": IPULosses.HybridCELossIPU,
}
# Learning-rate schedulers: the standard `torch.optim.lr_scheduler` classes
# plus the project's custom warm-up scheduler, keyed by class name.
SCHEDULER_DICT = {
    "CosineAnnealingLR": sc.CosineAnnealingLR,
    "CosineAnnealingWarmRestarts": sc.CosineAnnealingWarmRestarts,
    "CyclicLR": sc.CyclicLR,
    "ExponentialLR": sc.ExponentialLR,
    "LambdaLR": sc.LambdaLR,
    "MultiStepLR": sc.MultiStepLR,
    "ReduceLROnPlateau": sc.ReduceLROnPlateau,
    "StepLR": sc.StepLR,
    "ConstantLR": sc.ConstantLR,
    "WarmUpLinearLR": CustomLR.WarmUpLinearLR,
}
# Functional classification metrics from `torchmetrics.functional`, plus
# `*_ipu` variants from `graphium.ipu.ipu_metrics`, keyed by config name.
METRICS_CLASSIFICATION = {
    "accuracy": TorchMetrics.accuracy,
    "averageprecision": TorchMetrics.average_precision,
    "auroc": TorchMetrics.auroc,
    "confusionmatrix": TorchMetrics.confusion_matrix,
    "f1": TorchMetrics.f1_score,
    "fbeta": TorchMetrics.fbeta_score,
    "precisionrecallcurve": TorchMetrics.precision_recall_curve,
    "precision": TorchMetrics.precision,
    "recall": TorchMetrics.recall,
    "mcc": TorchMetrics.matthews_corrcoef,
    "auroc_ipu": Metrics.auroc_ipu,
    "accuracy_ipu": Metrics.accuracy_ipu,
    "average_precision_ipu": Metrics.average_precision_ipu,
    "f1_ipu": Metrics.f1_score_ipu,
    "fbeta_ipu": Metrics.fbeta_score_ipu,
    "precision_ipu": Metrics.precision_ipu,
    "recall_ipu": Metrics.recall_ipu,
}
# Functional regression metrics, following the same naming convention.
METRICS_REGRESSION = {
    "mae": TorchMetrics.mean_absolute_error,
    "mape": TorchMetrics.mean_absolute_percentage_error,
    "mse": TorchMetrics.mean_squared_error,
    "msle": TorchMetrics.mean_squared_log_error,
    "pearsonr": TorchMetrics.pearson_corrcoef,
    "spearmanr": TorchMetrics.spearman_corrcoef,
    "r2_score": TorchMetrics.r2_score,
    "cosine": TorchMetrics.cosine_similarity,
    "pearsonr_ipu": Metrics.pearson_ipu,
    "spearmanr_ipu": Metrics.spearman_ipu,
    "r2_score_ipu": Metrics.r2_score_ipu,
    "mae_ipu": Metrics.mean_absolute_error_ipu,
    "mse_ipu": Metrics.mean_squared_error_ipu,
}
# Union of the classification and regression metric registries.
METRICS_DICT = deepcopy(METRICS_CLASSIFICATION)
METRICS_DICT.update(METRICS_REGRESSION)
# Datamodule classes selectable by name from a config file.
DATAMODULE_DICT = {
    "GraphOGBDataModule": Datamodules.GraphOGBDataModule,
    "MultitaskFromSmilesDataModule": Datamodules.MultitaskFromSmilesDataModule,
    "ADMETBenchmarkDataModule": Datamodules.ADMETBenchmarkDataModule,
    "FakeDataModule": Datamodules.FakeDataModule,
}
# Named pretrained-model checkpoints shipped with the repository.
GRAPHIUM_PRETRAINED_MODELS_DICT = {
    "dummy-pretrained-model": "tests/dummy-pretrained-model.ckpt",  # dummy model used for testing purposes
}
# Head architectures available for finetuning, keyed by config name.
FINETUNING_HEADS_DICT = {
    "mlp": Architectures.FeedForwardNN,
    "gnn": Architectures.FeedForwardPyg,
    "task_head": Architectures.TaskHeads,
    "ens-mlp": Architectures.EnsembleFeedForwardNN,
}
================================================
FILE: graphium/utils/tensor.py
================================================
"""
--------------------------------------------------------------------------------
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file.
Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use.
Refer to the LICENSE file for the full terms and conditions.
--------------------------------------------------------------------------------
"""
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from typing import Iterable, List, Union, Any, Callable, Dict
from inspect import getfullargspec
from copy import copy, deepcopy
from loguru import logger
from rdkit.Chem import AllChem
import torch
from torch import Tensor
def save_im(im_dir, im_name: str, ext: List[str] = ["svg", "png"], dpi: int = 600) -> None:
    r"""
    Save the current matplotlib figure to disk in one or more formats.

    Parameters:
        im_dir: directory in which to save the image; created if it does not exist.
        im_name: base file name, without extension.
        ext: a single extension or a list of extensions to save as.
        dpi: resolution passed to ``plt.savefig``.
    """
    # Create the output directory if needed. `exist_ok=True` makes this
    # idempotent, and `os.path.join` below makes trailing-slash
    # normalization of `im_dir` unnecessary.
    os.makedirs(im_dir, exist_ok=True)
    if isinstance(ext, str):
        ext = [ext]
    full_name = os.path.join(im_dir, im_name)
    for this_ext in ext:
        plt.savefig(f"{full_name}.{this_ext}", dpi=dpi, bbox_inches="tight", pad_inches=0)
def is_dtype_torch_tensor(dtype: Union[np.dtype, torch.dtype]) -> bool:
    r"""
    Tell whether `dtype` denotes a torch dtype.

    Parameters:
        dtype: dtype
            The dtype of a value. E.g. np.int32, str, torch.float

    Returns:
        ``True`` when `dtype` is a `torch.dtype` instance, or is the
        `Tensor` class itself.
    """
    if isinstance(dtype, torch.dtype):
        return True
    return dtype == Tensor


def is_dtype_numpy_array(dtype: Union[np.dtype, torch.dtype]) -> bool:
    r"""
    Tell whether `dtype` denotes a numpy-compatible dtype.

    Parameters:
        dtype: dtype
            The dtype of a value. E.g. np.int32, str, torch.float

    Returns:
        ``True`` for the builtin numeric types and anything defined in the
        ``numpy`` module, unless `dtype` is a torch dtype.
    """
    if is_dtype_torch_tensor(dtype):
        return False
    is_builtin_number = dtype in (int, float, complex)
    is_from_numpy = getattr(dtype, "__module__", None) == "numpy"
    return is_builtin_number or is_from_numpy
def one_of_k_encoding(val: Any, classes: Iterable[Any]) -> List[int]:
    r"""Encode a value as a one-hot vector over a set of allowed classes.

    Parameters:
        val: the value to encode.
        classes: the allowed choices for ``val`` (must be sized, e.g. a
            list or 1D array).

    Returns:
        A list of ``len(classes) + 1`` booleans with exactly one ``True``:
        at the position of ``val`` in ``classes``, or in the final
        "unknown" slot when ``val`` matches none of the classes.
    """
    slots = list(classes)
    encoding = [False] * (len(slots) + 1)
    try:
        # `list.index` compares with `==` and stops at the first match.
        encoding[slots.index(val)] = True
    except ValueError:
        # Value not among the allowed classes: flag the trailing "unknown" bucket.
        encoding[-1] = True
    return encoding
def is_device_cuda(device: torch.device, ignore_errors: bool = False) -> bool:
    r"""Check whether the given device is a cuda device.

    Parameters:
        device: str, torch.device
            object to check for cuda
        ignore_errors: bool
            Whether to ignore the error if the device is not recognized.
            Otherwise, ``False`` is returned in case of errors.

    Returns:
        is_cuda: bool
    """
    if not ignore_errors:
        return torch.device(device).type == "cuda"
    try:
        return torch.device(device).type == "cuda"
    except Exception:
        # An unrecognized device is treated as "not cuda". Previously a bare
        # `except:` was used, which also swallowed KeyboardInterrupt/SystemExit.
        return False
def nan_mean(input: Tensor, *args, **kwargs) -> Tensor:
    r"""
    Mean of a tensor's elements, with NaNs excluded from both the sum and the count.

    Parameters:
        input: The input tensor.
        dim (int or tuple(int)): The dimension or dimensions to reduce.
        keepdim (bool): whether the output tensor has dim retained or not.
        dtype (torch.dtype, optional):
            The desired data type of returned tensor.
            If specified, the input tensor is casted to dtype before the operation is performed.
            This is useful for preventing data type overflows. Default: None.

    Returns:
        output: The resulting mean of the tensor
    """
    total = torch.nansum(input, *args, **kwargs)
    count = torch.sum(~input.isnan(), *args, **kwargs)
    return total / count
def nan_median(input: Tensor, **kwargs) -> Tensor:
    r"""
    Return the median of all elements, while ignoring the NaNs.
    Contrarily to `torch.nanmedian`, this function supports a list
    of dimensions, or `dim=None`, and does not return the index of the median

    Parameters:
        input: The input tensor.
        dim (int or tuple(int)): The dimension or dimensions to reduce.
        keepdim (bool): whether the output tensor has dim retained or not.
        dtype (torch.dtype, optional):
            The desired data type of returned tensor.
            If specified, the input tensor is casted to dtype before the operation is performed.
            This is useful for preventing data type overflows. Default: None.

    Returns:
        output: The resulting median of the tensor.
        Contrarily to `torch.median`, it does not return the index of the median
    """
    dim = kwargs.pop("dim", None)
    keepdim = kwargs.pop("keepdim", False)
    if isinstance(dim, Iterable) and not isinstance(dim, str):
        dim = list(dim)
        dim.sort()
        # Implement the median for a list of dimensions:
        # for each reduced dim `d`, append a singleton axis and swap it with
        # `d`, so the data of every reduced dim ends up in the trailing axes
        # (leaving a size-1 placeholder at each original position).
        for d in dim:
            input = input.unsqueeze(-1)
            input = input.transpose(d, -1)
        # Collapse the trailing `len(dim)` axes into one and reduce it with a
        # single `nanmedian` call. The size-1 placeholders left at the original
        # positions emulate `keepdim=True`.
        input = input.flatten(-len(dim))
        median, _ = torch.nanmedian(input, dim=-1, keepdim=False)
        if not keepdim:
            # Remove the placeholder axes, last-to-first so earlier indices stay valid.
            for d in dim[::-1]:
                median = median.squeeze(d)
    else:
        if dim is None:
            # Global median over the flattened tensor. The float round-trip is
            # presumably for dtype support in `nanmedian` — TODO confirm; the
            # result is cast back to the input dtype.
            median = torch.nanmedian(input.flatten().float()).to(input.dtype)
        else:
            median, _ = torch.nanmedian(input, dim=dim, keepdim=keepdim)
    return median
def nan_var(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor:
    r"""
    Return the variance of all elements, while ignoring the NaNs.
    If `unbiased` is True, Bessel's correction is applied (divide by N-1
    instead of N); otherwise the uncorrected sample variance is computed.

    Parameters:
        input: The input tensor.
        unbiased: whether to use Bessel's correction.
        dim (int or tuple(int)): The dimension or dimensions to reduce.
        keepdim (bool): whether the output tensor has dim retained or not.
        dtype (torch.dtype, optional):
            The desired data type of returned tensor.
            If specified, the input tensor is casted to dtype before the operation is performed.
            This is useful for preventing data type overflows. Default: None.

    Returns:
        output: The resulting variance of the tensor
    """
    inner_kwargs = deepcopy(kwargs)
    inner_kwargs.pop("keepdim", None)
    # Default to reducing over every dimension.
    reduce_dims = inner_kwargs.pop("dim", [ii for ii in range(input.ndim)])
    # Keep the reduced dims so the centered values broadcast against `input`.
    center = nan_mean(input, dim=reduce_dims, keepdim=True, **inner_kwargs)
    squared_dev = (input - center) * (input - center)
    var = nan_mean(squared_dev, **kwargs)
    if unbiased:
        # Rescale by N / (N - 1), with N the count of non-NaN entries.
        num = torch.sum(~input.isnan(), **kwargs)
        var = var * num / (num - 1)
    return var
def nan_std(input: Tensor, unbiased: bool = True, **kwargs) -> Tensor:
    r"""
    Return the standard deviation of all elements, while ignoring the NaNs.
    If `unbiased` is True, Bessel's correction is applied (divide by N-1
    instead of N); otherwise the uncorrected sample deviation is computed.

    Parameters:
        input: The input tensor.
        unbiased: whether to use Bessel's correction.
        dim (int or tuple(int)): The dimension or dimensions to reduce.
        keepdim (bool): whether the output tensor has dim retained or not.
        dtype (torch.dtype, optional):
            The desired data type of returned tensor.
            If specified, the input tensor is casted to dtype before the operation is performed.
            This is useful for preventing data type overflows. Default: None.

    Returns:
        output: The resulting standard deviation of the tensor
    """
    var = nan_var(input=input, unbiased=unbiased, **kwargs)
    # sqrt in float, then restore the variance's dtype.
    std = torch.sqrt(var.float())
    return std.to(var.dtype)
def nan_mad(input: Tensor, normal: bool = True, **kwargs) -> Tensor:
    r"""
    Return the median absolute deviation of all elements, while ignoring the NaNs.

    Parameters:
        input: The input tensor.
        normal: whether to multiply the result by 1.4826 to mimic the
            standard deviation for normal distributions.
        dim (int or tuple(int)): The dimension or dimensions to reduce.
        keepdim (bool): whether the output tensor has dim retained or not.
        dtype (torch.dtype, optional):
            The desired data type of returned tensor.
            If specified, the input tensor is casted to dtype before the operation is performed.
            This is useful for preventing data type overflows. Default: None.

    Returns:
        output: The resulting median absolute deviation of the tensor
    """
    inner_kwargs = deepcopy(kwargs)
    inner_kwargs.pop("keepdim", None)
    # Default to reducing over every dimension.
    reduce_dims = inner_kwargs.pop("dim", [ii for ii in range(input.ndim)])
    # Keep the reduced dims so the centered values broadcast against `input`.
    center = nan_median(input, dim=reduce_dims, keepdim=True, **inner_kwargs)
    abs_dev = (input - center).abs()
    mad = nan_median(abs_dev, **kwargs)
    if normal:
        mad = mad * 1.4826
    return mad
class ModuleWrap(torch.nn.Module):
    r"""
    Adapter that turns a plain function into a `torch.nn.Module`, binding
    optional `*args` and `**kwargs` at construction time.

    Parameters:
        func: function to wrap into a module
    """

    def __init__(self, func, *args, **kwargs) -> None:
        super().__init__()
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.__name__ = f"ModuleWrap({func.__name__})"

    def forward(self, *args, **kwargs):
        """
        Calls the function `self.func` with the arguments `self.func(*self.args, *args, **self.kwargs, **kwargs)`
        """
        return self.func(*self.args, *args, **self.kwargs, **kwargs)

    def __repr__(self):
        # e.g. "ModuleWrap(add)"
        return self.__name__
class ModuleListConcat(torch.nn.ModuleList):
    r"""
    A list of neural modules similar to `torch.nn.ModuleList`,
    but where every module receives the same input and the outputs are
    concatenated, instead of the modules being applied sequentially.

    Parameters:
        dim: The dimension for the concatenation
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, *args, **kwargs) -> Tensor:
        r"""
        Apply all layers on the `args` and `kwargs`, and concatenate
        their output alongside the dimension `self.dim`.
        """
        outputs = [module.forward(*args, **kwargs) for module in self]
        return torch.cat(outputs, dim=self.dim)
def parse_valid_args(param_dict, fn):
    r"""
    Filter a parameter dictionary down to the arguments accepted by a function.

    Parameters
    ----------
    fn: func
        The function whose signature is checked.
    param_dict: dict
        Dictionary of candidate arguments.

    Returns
    -------
    param_dict: dict
        A copy of ``param_dict`` containing only the keys that ``fn`` accepts.
        A warning is logged for every key that is dropped.
    """
    valid_params = copy(param_dict)
    for key in param_dict:
        if arg_in_func(fn=fn, arg=key):
            continue
        logger.warning(
            f"{key} is not an available argument for {fn.__name__}, and is ignored by default."
        )
        valid_params.pop(key)
    return valid_params
def arg_in_func(fn, arg):
    r"""
    Check if a function accepts the given argument name.

    Parameters
    ----------
    fn: func
        The function to check the argument.
    arg: str
        The name of the argument.

    Returns
    -------
    res: bool
        True if the function accepts the argument (as a positional argument,
        a keyword-only argument, or through ``**kwargs``), otherwise False.
    """
    spec = getfullargspec(fn)
    if spec.varkw is not None:
        # A **kwargs catch-all accepts any keyword.
        return True
    # Check both the positional parameters and the keyword-only parameters
    # (those declared after `*`); the latter were previously not recognized.
    return arg in spec.args or arg in spec.kwonlyargs
def tensor_fp16_to_fp32(tensor: Tensor) -> Tensor:
    r"""Promote a half-precision tensor to single precision.

    Parameters:
        tensor: any tensor; only ``float16`` tensors are converted.

    Returns:
        tensor: the same tensor, cast to fp32 when it was fp16.
    """
    if tensor.dtype != torch.float16:
        return tensor
    return tensor.to(dtype=torch.float32)


def dict_tensor_fp16_to_fp32(
    dict_tensor: Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]],
) -> Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]:
    r"""Recursively promote every fp16 tensor in a (possibly nested)
    dictionary of tensors to fp32.

    Parameters:
        dict_tensor: a tensor, or an arbitrarily nested dict of tensors.

    Returns:
        dict_tensor: the same structure with every fp16 tensor cast to fp32
        (dicts are updated in place).
    """
    if not isinstance(dict_tensor, dict):
        return tensor_fp16_to_fp32(dict_tensor)
    for key in dict_tensor:
        dict_tensor[key] = dict_tensor_fp16_to_fp32(dict_tensor[key])
    return dict_tensor
================================================
FILE: install_ipu.sh
================================================
#!/bin/bash
# --------------------------------------------------------------------------------
# Copyright (c) 2023 Graphcore Limited.
# Use of this software is subject to the terms and conditions outlined in the LICENSE file.
# Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
# warranties of any kind.
# Graphcore Limited is not liable for any damages arising from its use.
# Refer to the LICENSE file for the full terms and conditions.
# --------------------------------------------------------------------------------
# NOTE: the license header must use `#` comments (a Python-style `"""` block is
# not valid bash), and the shebang must be the very first line of the file.

# Default location for the virtual environment
default_venv_name=".graphium_ipu"

# Allow the user to specify the location of their virtual environment
# If not specified, use the default location
venv_name=${1:-$default_venv_name}

# Constants
sdk_compressed_file="poplar_sdk-ubuntu_20_04-3.3.0-208993bbb7.tar.gz"
sdk_wheel_file="poptorch-3.3.0+113432_960e9c294b_ubuntu_20_04-cp38-cp38-linux_x86_64.whl"
sdk_url="https://downloads.graphcore.ai/direct?package=poplar-poplar_sdk_ubuntu_20_04_3.3.0_208993bbb7-3.3.0&file=${sdk_compressed_file}"
sdk_path="${venv_name}/poplar_sdk-ubuntu_20_04-3.3.0+1403-208993bbb7"

# Check for Python3 and pip
if ! command -v python3 &>/dev/null; then
    echo "Python3 is required but it's not installed. Exiting."
    exit 1
fi
if ! command -v pip3 &>/dev/null; then
    echo "pip3 is required but it's not installed. Exiting."
    exit 1
fi

# Remove existing venv directory if it exists
if [[ -d "$venv_name" ]]; then
    echo "Removing existing virtual environment directory..."
    rm -rf "$venv_name"
fi

# Create the virtual environment
echo "Creating virtual environment..."
mkdir -p "$venv_name"
python3 -m venv "$venv_name"
source "$venv_name/bin/activate"

# Update pip to the latest version
echo "Upgrading pip..."
python3 -m pip install --upgrade pip

# Download the Poplar SDK (abort immediately on download failure)
echo "Downloading Poplar SDK..."
if ! wget -q -O "${venv_name}/${sdk_compressed_file}" "$sdk_url"; then
    echo "Failed to download Poplar SDK. Exiting."
    exit 1
fi

# Unzip the SDK file
echo "Extracting Poplar SDK..."
tar -xzf "${venv_name}/${sdk_compressed_file}" -C "$venv_name"

# Install the PopTorch wheel
echo "Installing PopTorch..."
python3 -m pip install "${sdk_path}/${sdk_wheel_file}"

# Enable Poplar SDK (including Poplar and PopART)
echo "Enabling Poplar SDK..."
source "${sdk_path}/enable"

# Install the IPU specific and Graphium requirements
echo "Installing IPU specific and Graphium requirements..."
python3 -m pip install -r requirements_ipu.txt

# Install Graphium in dev mode
echo "Installing Graphium in dev mode..."
python3 -m pip install --no-deps -e .

# This is a quick test to make sure poptorch is correctly installed
if python3 -c "import poptorch;print('poptorch installed correctly')" &> /dev/null; then
    echo "Installation completed successfully."
else
    echo "Installation was not successful. Please check the logs and try again."
    exit 1 # Exit with status code 1 to indicate failure
fi

# Download the datafiles (Total ~ 10Mb - nothing compared to the libraries)
echo "Downloading the sub-datasets consisting on the ToyMix dataset"
toymix_dir="expts/data/neurips2023/small-dataset/"
mkdir -p "$toymix_dir"
base_url="https://storage.valencelabs.com/graphium/datasets/neurips_2023/Small-dataset/"
files=("ZINC12k.csv.gz" "Tox21-7k-12-labels.csv.gz" "qm9.csv.gz" "qm9_random_splits.pt" "Tox21_random_splits.pt" "ZINC12k_random_splits.pt")
for file in "${files[@]}"; do
    if [ ! -f "${toymix_dir}${file}" ]; then
        echo "Downloading ${file}..."
        wget -P "${toymix_dir}" "${base_url}${file}"
    else
        echo "${file} already exists. Skipping..."
    fi
done
echo "Data has been successfully downloaded."
================================================
FILE: mkdocs.yml
================================================
site_name: "graphium"
site_description: "Graphium: Scaling molecular GNNs to infinity."
site_url: "https://github.com/datamol-io/graphium"
repo_url: "https://github.com/datamol-io/graphium"
repo_name: "datamol-io/graphium"
copyright: Copyright 2020 - 2023 datamol.io
remote_branch: "gh-pages"
use_directory_urls: false
docs_dir: "docs"
nav:
- Overview: index.md
- Baseline: baseline.md
- API:
- graphium.nn:
- graphium.nn: api/graphium.nn/graphium.nn.md
- graphium.nn.architectures: api/graphium.nn/architectures.md
- graphium.nn.encoders: api/graphium.nn/encoders.md
- graphium.nn.pyg_layers: api/graphium.nn/pyg_layers.md
- graphium.features: api/graphium.features.md
- graphium.trainer: api/graphium.trainer.md
- graphium.data: api/graphium.data.md
- graphium.utils: api/graphium.utils.md
- graphium.config: api/graphium.config.md
- graphium.ipu: api/graphium.ipu.md
- graphium.finetuning: api/graphium.finetuning.md
- Tutorials:
- feature_processing:
- Add new positional encoding: tutorials/feature_processing/add_new_positional_encoding.ipynb
- Convert CSV to Parquet files: tutorials/feature_processing/csv_to_parquet.ipynb
- Timing parallel processing: tutorials/feature_processing/timing_parallel.ipynb
- gnn_layers:
- Add new gnn layers: tutorials/gnn/add_new_gnn_layers.ipynb
- Making GNN networks: tutorials/gnn/making_gnn_networks.ipynb
- Using GNN layers: tutorials/gnn/using_gnn_layers.ipynb
- model_training:
- Simple Molecular Model: tutorials/model_training/simple-molecular-model.ipynb
- Training on IPU: tutorials/model_training/running-multitask-ipu.ipynb
- Design: design.md
- Datasets: datasets.md
- Pretrained Models: pretrained_models.md
- CLI:
- About: cli/reference.md
- Commands:
- graphium: cli/graphium.md
- graphium-train: cli/graphium-train.md
- Contribute: contribute.md
- License: license.md
theme:
name: material
# NOTE(hadim): to customize the material primary and secondary
# color check `docs/_assets/css/custom-graphium.css`.
features:
- navigation.tabs
- navigation.expand
favicon: images/logo-mini-dark.png
logo: images/logo-mini-white.svg
extra_css:
- _assets/css/custom.css
- _assets/css/custom-graphium.css
extra_javascript:
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
- _assets/js/google-analytics.js
markdown_extensions:
- admonition
- markdown_include.include
- pymdownx.emoji
- pymdownx.magiclink
- pymdownx.superfences
- pymdownx.tabbed
- pymdownx.tasklist
- pymdownx.details
- pymdownx.arithmatex:
generic: true
- toc:
permalink: true
watch:
- graphium/
plugins:
- search
- mkdocstrings:
handlers:
python:
setup_commands:
- import sys
- sys.path.append("docs")
- sys.path.append("graphium")
options:
new_path_syntax: yes
show_root_heading: yes
heading_level: 3
show_source: false
- mkdocs-jupyter:
execute: False
- mike:
version_selector: true
extra:
version:
# Multi versioning provider for mkdocs-material (used for the JS selector)
provider: mike
================================================
FILE: notebooks/compare-pretraining-finetuning-performance.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import fsspec\n",
"import yaml\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"from typing import Literal\n",
"from dataclasses import dataclass, field\n",
"from collections import defaultdict\n",
"from graphium.utils import fs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"ROOT_DIR = \"gs://graphium-private/pretrained-models/ToyMix/cas\"\n",
"\n",
"PT_TASKS = [\"qm9\", \"tox21\", \"zinc\"]\n",
"PT_FT_RELS = {\n",
" \"toymix_gcn_1\": {\"caco2\": \"finetuning_caco2_wang_gcn_1\", \"lipophilicity\": \"finetuning_lipophilicity_astrazeneca_gcn_1\"},\n",
" \"toymix_gcn_2\": {\"caco2\": \"finetuning_caco2_wang_gcn_2\", \"lipophilicity\": \"finetuning_lipophilicity_astrazeneca_gcn_2\"},\n",
"}\n",
"FT_METRICS = {\"caco2\": \"graph_caco2_wang/mae/test\", \"lipophilicity\": \"graph_lipophilicity_astrazeneca/r2_score/test\"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class FinetuningResult: \n",
" name: str\n",
" scores: dict[str, list[float]] = field(default_factory=lambda: defaultdict(list))\n",
"\n",
" def best(self, metric, minimize: bool = False):\n",
" return max(self.scores[metric]) if not minimize else min(self.scores[metric])\n",
"\n",
"\n",
"@dataclass\n",
"class PretrainingResult: \n",
" name: str\n",
" loss: dict[Literal[\"qm9\", \"zinc\", \"tox21\", \"all\"], float] = field(default_factory=dict)\n",
" ft_results: dict[str, FinetuningResult] = field(default_factory=dict)\n",
"\n",
" @property\n",
" def finetuning_tasks(self):\n",
" return sorted(list(self.ft_results.keys()))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cas.wognum/micromamba/envs/graphium/lib/python3.10/site-packages/google/auth/_default.py:78: UserWarning: Your application has authenticated using end user credentials from Google Cloud SDK without a quota project. You might receive a \"quota exceeded\" or \"API not enabled\" error. See the following page for troubleshooting: https://cloud.google.com/docs/authentication/adc-troubleshooting/user-creds. \n",
" warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING)\n",
"100%|██████████| 2/2 [00:20<00:00, 10.05s/it]\n"
]
}
],
"source": [
"globber, _ = fsspec.core.url_to_fs(ROOT_DIR)\n",
"\n",
"results = {}\n",
"\n",
"for pt_dir, ft_dirs in tqdm(PT_FT_RELS.items()):\n",
" \n",
" # Create a new results object\n",
" results[pt_dir] = PretrainingResult(name=pt_dir)\n",
"\n",
" # Parse the pre-training results\n",
" pt_results_path = fs.join(ROOT_DIR, pt_dir, \"results\", \"test_results.yaml\")\n",
" with fsspec.open(pt_results_path, \"r\") as f:\n",
" pt_results = yaml.safe_load(f)\n",
" results[pt_dir].loss = {k: pt_results[f\"graph_{k}/loss/test\"] for k in PT_TASKS}\n",
" results[pt_dir].loss[\"all\"] = pt_results[\"loss/test\"]\n",
"\n",
" # Parse the associated fine-tuning results\n",
" for ft_label, ft_dir in ft_dirs.items():\n",
"\n",
" # Create a new results object\n",
" ft_results = FinetuningResult(name=ft_label)\n",
" \n",
" # Find all results for all trials\n",
" ft_results_pattern = fs.join(ROOT_DIR, ft_dir, \"**\", \"test_results.yaml\")\n",
" paths = globber.glob(ft_results_pattern)\n",
"\n",
" # Save all scores\n",
" for path in paths: \n",
" with globber.open(path, \"r\") as f:\n",
" data = yaml.safe_load(f)\n",
" for k, v in data.items():\n",
" ft_results.scores[k].append(v)\n",
" \n",
" # Save the finetuning results to the pre-training results\n",
" results[pt_dir].ft_results[ft_label] = ft_results\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def draw_boxplot(results: dict[str, FinetuningResult], fine_tuning_task: str, metric_label: str = None, loss_label: str = \"all\", ax=None):\n",
" if ax is None: \n",
" _, ax = plt.subplots()\n",
" if metric_label is None:\n",
" metric_label = FT_METRICS[fine_tuning_task]\n",
" for pt_label, pt_results in results.items():\n",
" positions = [round(pt_results.loss[loss_label], 3)]\n",
" data = pt_results.ft_results[fine_tuning_task].scores[metric_label]\n",
" ax.boxplot(data, positions=positions)\n",
" ax.set_xlabel(f\"Pre-training loss on {loss_label}\")\n",
" ax.set_ylabel(metric_label) "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABT3klEQVR4nO3dfXzO9f////sxY4ZtcpazNeZsc5KclBhKqFRC75KKkPq+lXKWTsS7dEZnIokiSQmdUd4V8hZJEs2kNKxYpMlZzFnTtufvj36OT2sbx/HaseP13Ha7Xi67vNvreO11PI9jHbf363h0HMc8xhgjAAAAAAAAIIhC3F4AAAAAAAAASh6GUgAAAAAAAAg6hlIAAAAAAAAIOoZSAAAAAAAACDqGUgAAAAAAAAg6hlIAAAAAAAAIOoZSAAAAAAAACDqGUgAAAAAAAAi6ULcXUBDZ2dn69ddfFRERIY/H4/ZyADhkjNHRo0dVs2ZNhYQUn1k5jQKKh+LYKPoEFA/0CYCtfO1TkR5K/frrr4qOjnZ7GQACZPfu3apdu7bbywgYGgUUL8WpUfQJKF7oEwBbna1PRXooFRERIemvGxkZGenyagA4lZ6erujoaO9jurigUUDxUBwbRZ+A4oE+AbCVr30q0kOp0y/njIyMJFhAMVDcXqJNo4DipTg1ij4BxQt9AmCrs/WpeLzxGAAAAAAAAEUKQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEXajbCwCKkhMnTmjr1q25tp88eVKpqamqU6eOwsPDc10eFxencuXKBWOJAIohp+2R6A8Ad3DOBMBW9MkuDKUAP2zdulWtWrXy++cSExPVsmXLQlgRgJLAaXsk+gPAHZwzAbAVfbILQynAD3FxcUpMTMy1PTk5WX379tXcuXMVHx+f588BgFNO23P6ZwEg2DhnAmAr+mQXhlKAH8qVK3fG6Xh8fDzTcwABR3sAFDV0C4Ct6JNd+KBzAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQykAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAASd60OpPXv2qG/fvqpcubLKlSunCy64QImJiW4vCwDoEwBr0ScAtqJPAPwR6uaV//7770pISFCnTp20ZMkSVatWTT/99JMqVqzo5rIAgD4BsBZ9AmAr+gTAX64OpZ5++mlFR0dr9uzZ3m116tRxb0EA8P+jTwBsRZ8A2Io+AfCXq2/fW7x4sVq3bq0bbrhB1apVU4sWLTRz5sx898/IyFB6enqOLwAoDP72SaJRAIKDPgGwFX0C4C9Xh1I7duzQ9OnT1aBBAy1btkyDBw/W0KFD9cYbb+S5/4QJExQVFeX9io6ODvKKAZQU/vZJolEAgoM+AbAVfQLgL48xxrh15WXKlFHr1q21du1a77ahQ4dqw4YN+uqrr3Ltn5GRoYyMDO/36enpio6O1pEjRxQZGRmUNQN52bhxo1q1aqXExES1bNnS7eUUOenp6YqKirLqsexvnyQaheCjPcFhW6PoE4oyuhVY9AkIHPoUWL72ydVXStWoUUONGzfOsS0+Pl67du3Kc/+wsDBFRkbm+AKAwuBvnyQaBSA46BMAW9EnAP5y9YPOExIStG3bthzbtm/frpiYGJdWVHydOHFCW7duzbX95MmTSk1NVZ06dRQeHp7r8ri4OJUrVy4YSwSsQp/skV+/JBqGkok+ucPpuZREi1By0Cf30CgUVa4OpUaMGKF27dpp/Pjx6t27t9avX68ZM2ZoxowZbi6rWNq6da
tatWrl98/x0kWUVPTJHk77JdEwFE/0yR20CDg7+uQeGoWiytWh1IUXXqhFixZp9OjReuyxx1S3bl1NnjxZt9xyi5vLKpbi4uKUmJiYa3tycrL69u2ruXPnKj4+Ps+fA0oi+mSP/Pol0TCUTPTJHU7PpU7/LFAS0Cf30CgUVa4OpSTpmmuu0TXXXOP2Moq9cuXKnXH6HR8fz3Qc+Af6ZIez9UuiYSh56FPwcS4F+IY+uYNGoahy9YPOAQAAAAAAUDIxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNAxlAIAAAAAAEDQMZQCAAAAAABA0DGUAgAAAAAAQNA5GkrFxsbq4MGDubYfPnxYsbGxBV4UADhFnwDYij4BsBmNAuAGR0Op1NRUZWVl5dqekZGhPXv2FHhRAOAUfQJgK/oEwGY0CoAbQv3ZefHixd5/XrZsmaKiorzfZ2VlacWKFapTp07AFgcAvqJPAGxFnwDYjEYBcJNfQ6mePXtKkjwej/r375/jstKlS6tOnTqaOHFiwBYHAL6iTwBsRZ8A2IxGAXCTX0Op7OxsSVLdunW1YcMGValSpVAWBQD+ok8AbEWfANiMRgFwk19DqdN27tyZa9vhw4dVsWLFgq4HAAqEPgGwFX0CYDMaBcANjj7o/Omnn9bbb7/t/f6GG25QpUqVVKtWLX377bcBWxwA+Is+AbAVfQJgMxoFwA2OhlKvvPKKoqOjJUnLly/X//73Py1dulTdunXTfffdF9AFAoA/6BMAW9EnADajUQDc4Ojte2lpad5gffTRR+rdu7cuv/xy1alTR23atAnoAgHAH/QJgK3oEwCb0SgAbnD0SqlzzjlHu3fvliQtXbpUXbp0kSQZY5SVlRW41QGAn+gTAFvRJwA2o1EA3ODolVLXXXedbr75ZjVo0EAHDx5Ut27dJEmbNm1S/fr1A7pAAPAHfQJgK/oEwGY0CoAbHA2lJk2apDp16mj37t165plnVKFCBUl/veTzrrvuCugCAcAf9AmAregTAJvRKABucDSUKl26tEaNGpVr+/Dhwwu6HgAoEPoEwFb0CYDNaBQANzj6TClJevPNN9W+fXvVrFlTP//8syRp8uTJ+vDDDwO2OABwgj4BsBV9AmAzGgUg2BwNpaZPn66RI0eqW7duOnz4sPeD7ypWrKjJkycHcn0A4Bf6BMBW9AmAzWgUADc4Gkq9+OKLmjlzpsaMGaNSpUp5t7du3VrfffddwBYHAP6iTwBsRZ8A2IxGAXCDo6HUzp071aJFi1zbw8LCdPz48QIvCgCcok8AbEWfANiMRgFwg6OhVN26dbVp06Zc25csWaLGjRv7fJxx48bJ4/Hk+KpevbqTJQGAJPoEwF70CYDNaBQANzj663v33XefhgwZoj/++EPGGK1fv17z58/XhAkT9Oqrr/p1rCZNmuh///uf9/u/v1QUAPxFnwDYij4BsBmNAuAGR0OpgQMHKjMzU/fff79OnDihm2++WbVq1dILL7ygPn36+LeA0FAm5wAChj4BsBV9AmAzGgXADY6GUpJ0xx136I477tCBAweUnZ2tatWqOTpOSkqKatasqbCwMLVp00bjx49XbGxsnvtmZGQoIyPD+316erqj6wRQvLnRJ4lGATg7+gTAZjzHAxBsjj5T6rLLLtPhw4clSVWqVPHGKj09XZdddpnPx2nTpo3eeOMNLVu2TDNnztTevXvVrl07HTx4MM/9J0yYoKioKO9XdHS0k+UDKMbc6pNEowCcGX0CYDOe4wFwg6Oh1KpVq3Tq1Klc2//44w998c
UXPh+nW7du+te//qVmzZqpS5cu+vjjjyVJc+bMyXP/0aNH68iRI96v3bt3O1k+gGLMrT5JNArAmdEnADbjOR4AN/j19r3Nmzd7//mHH37Q3r17vd9nZWVp6dKlqlWrluPFlC9fXs2aNVNKSkqel4eFhSksLMzx8QEUX273SaJRAPJGnwDYzO1G0SegZPNrKHXBBRd4/6xnXi/hDA8P14svvuh4MRkZGUpOTlaHDh0cHwNAyUSfANiKPgGwGY0C4Ca/hlI7d+6UMUaxsbFav369qlat6r2sTJkyqlatml9/7nPUqFHq3r27zjvvPO3bt09PPPGE0tPT1b9/f3+WBQD0CYC16BMAm9EoAG7yayj1yiuvqGfPnsrOzg7Ilf/yyy+66aabdODAAVWtWlUXX3yx1q1bp5iYmIAcH0DJQZ8A2Io+AbAZjQLgJr+GUmlpabrmmmtUqlQpde/eXT169FCXLl0cvwd4wYIFjn4OAP6JPgGwFX0CYDMaBcBNfv31vdmzZ+u3337TO++8o4oVK+ree+9VlSpVdN111+n111/XgQMHCmudAHBG9AmAregTAJvRKABu8msoJUkej0cdOnTQM888o61bt2r9+vW6+OKLNXPmTNWqVUsdO3bUc889pz179hTGegEgX/QJgK3oEwCb0SgAbvF7KPVP8fHxuv/++/Xll19q9+7d6t+/v7744gvNnz8/EOsDAMfoEwBb0ScANqNRAILFr8+U+qcff/xRP/30kzp27Kjw8HBVrVpVgwYN0qBBgwK1PgBwhD4BsBV9AmAzGgUgmBy9UurgwYPq0qWLGjZsqKuuukppaWmSpNtvv12jRo0K6AIBwB/0CYCt6BMAm9EoAG5wNJQaMWKEQkNDtWvXLpUrV867/cYbb9SSJUsCtjgA8Bd9AmAr+gTAZjQKgBscvX3v008/1bJly1S7du0c2xs0aKCff/45IAsDACfoEwBb0ScANqNRANzg6JVSx48fzzE9P+3AgQMKCwsr8KIAwCn6BMBW9AmAzWgUADc4Gkp17NhRb7zxhvd7j8ej7OxsPfvss+rUqVPAFgcA/qJPAGxFnwDYjEYBcIOjt+89++yzuvTSS/XNN9/o1KlTuv/++7VlyxYdOnRIX375ZaDXCAA+o08AbEWfANiMRgFwg6NXSjVu3FibN2/WRRddpK5du+r48eO67rrrlJSUpHr16gV6jQDgM/oEwFb0CYDNaBQANzh6pZQkVa9eXY8++mgg1wIAAUGfANiKPgGwGY0CEGyOh1KSdOLECe3atUunTp3Ksf38888v0KIAoKDoEwBb0ScANqNRAILJ0VBq//79GjhwoJYsWZLn5VlZWQVaFAA4RZ8A2Io+AbAZjQLgBkefKTV8+HD9/vvvWrduncLDw7V06VLNmTNHDRo00OLFiwO9RgDwGX0CYCv6BMBmNAqAGxy9Uuqzzz7Thx9+qAsvvFAhISGKiYlR165dFRkZqQkTJujqq68O9DoBwCf0CYCt6BMAm9EoAG5w9Eqp48ePq1q1apKkSpUqaf/+/ZKkZs2aaePGjYFbHQD4iT4BsBV9AmAzGgXADY6GUo0aNdK2bdskSRdccIFeeeUV7dmzRy+//LJq1KgR0AUCgD/oEwBb0ScANqNRANzg6O17w4cPV1pamiTpkUce0RVXXKG33npLZcqU0euvvx7I9QGAX+gTAFvRJwA2o1EA3OBoKHXLLbd4/7lFixZKTU3V1q1bdd5556lKlSoBWxwA+Is+AbAVfQJgMxoFwA2OhlL/VK5cObVs2TIQhwKAgKJPAGxFnwDYjEYBCAZHQyljjN577z2tXLlS+/btU3Z2do7LFy5cGJDFAYC/6BMAW9EnADajUQDc4GgoNWzYMM2YMUOdOnXSueeeK4/HE+h1AYAj9AmAregTAJvRKABucDSUmjt3rhYuXKirrroq0OsBgAKhTwBsRZ8A2IxGAXBDiJMfioqKUmxsbKDXAgAFRp8A2Io+AbAZjQLgBk
dDqXHjxunRRx/VyZMnA70eACgQ+gTAVvQJgM1oFAA3OHr73g033KD58+erWrVqqlOnjkqXLp3j8o0bNwZkcQDgL/oEwFb0CYDNaBQANzgaSg0YMECJiYnq27cvH4IHwCr0CYCt6BMAm9EoAG5wNJT6+OOPtWzZMrVv3z7Q6wGAAqFPAGxFnwDYjEYBcIOjz5SKjo5WZGRkoNcCAAVGnwDYij4BsBmNAuAGR0OpiRMn6v7771dqamqAlwMABUOfANiKPgGwGY0C4AZHb9/r27evTpw4oXr16qlcuXK5PgTv0KFDAVkcAPiLPgGwFX0CYDMaBcANjoZSkyZN4oPvAFiJPgGwFX0CYDMaBcANfg2lPv30U3Xq1EkDBgwopOUAgDP0CYCt6BMAm9EoAG7y6zOlBg8erKpVq+rGG2/UvHnzdPjw4UJaFgD4hz4BsBV9AmAzGgXATX4NpXbs2KHVq1erWbNmmjx5sqpXr67OnTtrypQpfCAeAFfRJwC2ok8AbEajALjJ77++d/7552vs2LFav369duzYoRtuuEFLly5VfHy8mjdvrocffljffPNNYawVAM6IPgGwFX0CYDMaBcAtfg+l/q5mzZoaPHiwPvnkEx04cED/+c9/lJqaqiuvvFLjx48P1BoBwG/0CYCt6BMAm9EoAMHk6K/v5aV8+fK6/vrrdf311ys7O1sHDx4M1KEBoEDoEwBb0ScANqNRAAqbo6HUlClT8tzu8XhUtmxZNWjQQB06dPDrmBMmTNBDDz2kYcOGafLkyU6WBQCF0ieJRgEoOPoEwGY8xwPgBkdDqUmTJmn//v06ceKEzjnnHBljdPjwYZUrV04VKlTQvn37FBsbq5UrVyo6Ovqsx9uwYYNmzJih888/38lyAMAr0H2SaBSAwKBPAGzGczwAbnD0mVLjx4/XhRdeqJSUFB08eFCHDh3S9u3b1aZNG73wwgvatWuXqlevrhEjRpz1WMeOHdMtt9yimTNn6pxzznGyHADwCmSfJBoFIHDoEwCb8RwPgBscDaXGjh2rSZMmqV69et5t9evX13PPPafRo0erdu3aeuaZZ/Tll1+e9VhDhgzR1VdfrS5dupx134yMDKWnp+f4AoC/C2SfJBoFIHDoEwCb8RwPgBscvX0vLS1NmZmZubZnZmZq7969kv76qw1Hjx4943EWLFigjRs3asOGDT5d74QJE/Too4/6v2AAJUag+iTRKACBRZ8A2IzneADc4OiVUp06ddK///1vJSUlebclJSXpzjvv1GWXXSZJ+u6771S3bt18j7F7924NGzZMc+fOVdmyZX263tGjR+vIkSPer927dztZPoBiLBB9kmgUgMCjTwBsxnM8AG5wNJSaNWuWKlWqpFatWiksLExhYWFq3bq1KlWqpFmzZkmSKlSooIkTJ+Z7jMTERO3bt0+tWrVSaGioQkND9fnnn2vKlCkKDQ1VVlZWrp8JCwtTZGRkji8A+LtA9EmiUQACjz4BsBnP8QC4wdHb96pXr67ly5dr69at2r59u4wxiouLU6NGjbz7dOrU6YzH6Ny5s7777rsc2wYOHKi4uDg98MADKlWqlJOlASjhAtEniUYBCDz6BMBmPMcD4AZHQ6nT4uLiFBcX5+hnIyIi1LRp0xzbypcvr8qVK+faDgD+KkifJBoFoPDQJwA24zkegGByNJTKysrS66+/rhUrVmjfvn3Kzs7Ocflnn30WkMUBgL/oEwBb0ScANqNRANzgaCg1bNgwvf7667r66qvVtGlTeTyegCxm1apVATkOgJKrsPok0SgABUOfANiM53gA3OBoKLVgwQK98847uuqqqwK9HgAoEPoEwFb0CYDNaBQANzj663tlypRR/fr1A70WACgw+gTAVvQJgM1oFAA3OBpK3XvvvXrhhRdkjAn0egCgQOgTAFvRJwA2o1EA3ODo7Xtr1qzRypUrtWTJEjVp0kSlS5fOcfnChQsDsjgA8Bd9AmAr+gTAZjQKgBscDaUqVq
yoXr16BXotAFBg9AmAregTAJvRKABucDSUmj17dqDXAQABQZ8A2Io+AbAZjQLgBkefKQUAAAAAAAAUhKNXSknSe++9p3feeUe7du3SqVOncly2cePGAi8MAJyiTwBsRZ8A2IxGAQg2R6+UmjJligYOHKhq1aopKSlJF110kSpXrqwdO3aoW7dugV4jAPiMPgGwFX0CYDMaBcANjoZS06ZN04wZMzR16lSVKVNG999/v5YvX66hQ4fqyJEjgV4jAPiMPgGwFX0CYDMaBcANjoZSu3btUrt27SRJ4eHhOnr0qCSpX79+mj9/fuBWBwB+ok8AbEWfANiMRgFwg6OhVPXq1XXw4EFJUkxMjNatWydJ2rlzp4wxgVsdAPiJPgGwFX0CYDMaBcANjoZSl112mf773/9KkgYNGqQRI0aoa9euuvHGG9WrV6+ALhAA/EGfANiKPgGwGY0C4AZHf31vxowZys7OliQNHjxYlSpV0po1a9S9e3cNHjw4oAsEAH/QJwC2ok8AbEajALjB0VAqJCREISH/9yKr3r17q3fv3gFbFAA4RZ8A2Io+AbAZjQLgBkdDqYSEBF1yySW69NJLlZCQoPLlywd6XQDgCH0CYCv6BMBmNAqAGxx9ptQ111yjjRs36vrrr9c555yjtm3b6sEHH9TSpUt17NixQK8RAHxGnwDYij4BsBmNAuAGR0Op0aNHa+nSpfr999+1evVq9ejRQ5s2bdK1116rypUrB3qNAOAz+gTAVvQJgM1oFAA3OHr73mkpKSn69ttv9e2332rz5s2KjIxUhw4dArU2AHCMPgGwFX0CYDMaBSCYHA2lbrzxRq1evVrZ2dnq2LGjOnbsqNGjR+v8888P9PoAwC/0CYCt6BMAm9EoAG5wNJR69913VaVKFQ0YMECdOnVShw4dVKFChUCvDQD8Rp8A2Io+AbAZjQLgBkefKXXo0CG9+uqryszM1NixY1WlShW1adNGDzzwgJYsWRLoNQKAz+gTAFvRJwA2o1EA3OBoKFWxYkVde+21ev7555WYmKgtW7aocePGev7553XNNdcEeo0A4DP6BMBW9AmAzWgUADc4evveoUOH9Pnnn2vVqlVatWqVtmzZokqVKqlHjx7q1KlToNcIAD6jTwBsRZ8A2IxGAXCDo6FU1apVVaVKFXXo0EF33HGHLr30UjVt2jTQawMAv9EnALaiTwBsRqMAuMHRUOrbb7/1KVBffvmlWrdurbCwMCdXAwB+o08AbEWfANiMRgFwg6PPlPJ1Yt6tWzft2bPHyVUAgCP0CYCt6BMAm9EoAG5wNJTylTGmMA8PAI7RJwC2ok8AbEajAARSoQ6lAAAAAAAAgLwwlAIAAAAAAEDQMZQCAAAAAABA0BXqUMrj8RTm4QHAMfoEwFb0CYDNaBSAQOKDzgGUSPQJgK3oEwCb0SgAgRRamAc/evRoYR4eAByjTwBsRZ8A2IxGAQgkv18p9e233+qJJ57QtGnTdODAgRyXpaen67bbbgvY4gDAH/QJgK3oEwCb0SgAbvFrKPXpp5/qoosu0oIFC/T0008rPj5eK1eu9F5+8uRJzZkzJ+CLBICzoU8AbEWfANiMRgFwk19DqXHjxmnUqFH6/vvvlZqaqvvvv1/XXnutli5dWljrAwCf0CcAtqJPAGxGowC4ya/PlNqyZYvefPNNSX/91YX77rtPtWvX1vXXX6/58+froosuKpRFAsDZ0CcAtqJPAGxGowC4ya+hVFhYmA4fPpxj20033aSQkBD16dNHEydODOTaAMBn9AmAregTAJvRKABu8uvtexdccEGO9xefduONN+rVV1/V0KFD/bry6dOn6/zzz1dkZKQiIyPVtm1bLVmyxK9jAIBEnwDYiz4BsBmNAuAmv14pdeedd2r16tV5XnbTTTdJkmbMmOHz8WrXrq2nnnpK9evXlyTNmTNHPXr0UFJSkpo0aeLP0gCUcPQJgK3oEwCb0SgAbvJrKNWrVy/16tUr38tvuu
kmb7h80b179xzfP/nkk5o+fbrWrVtHsAD4hT4BsBV9AmAzGgXATX4Npf4pMTFRycnJ8ng8io+PV8uWLR0fKysrS++++66OHz+utm3bFmRZAECfAFiLPgGwGY0CEEyOhlL79u1Tnz59tGrVKlWsWFHGGB05ckSdOnXSggULVLVqVZ+P9d1336lt27b6448/VKFCBS1atEiNGzfOc9+MjAxlZGR4v09PT3eyfADFmFt9kmgUgDOjTwBsxnM8AG7w64POT7vnnnuUnp6uLVu26NChQ/r999/1/fffKz093e8PwmvUqJE2bdqkdevW6c4771T//v31ww8/5LnvhAkTFBUV5f2Kjo52snwAxZhbfZJoFIAzo08AbMZzPABucDSUWrp0qaZPn674+HjvtsaNG+ull17y+y8rlClTRvXr11fr1q01YcIENW/eXC+88EKe+44ePVpHjhzxfu3evdvJ8gEUY271SaJRAM6MPgGwGc/xALjB0dv3srOzVbp06VzbS5curezs7AItyBiT4+WbfxcWFqawsLACHR/wVUpKio4ePerTvsnJyTn+1xcRERFq0KCBo7Uhf271SaJRKDh/uiM5a89pNCj46BOKq8I+Z5JoVjDwHA/FTTDOq2hTwTkaSl122WUaNmyY5s+fr5o1a0qS9uzZoxEjRqhz584+H+ehhx5St27dFB0draNHj2rBggVatWqVli5d6mRZQMCkpKSoYcOGfv9c3759/dp/+/btRCzA6BOKKqfdkfxvz2k0KLjoE4qjYJ0zSTSrsNEoFCfBPK+iTQXjaCg1depU9ejRQ3Xq1FF0dLQ8Ho927dqlZs2aae7cuT4f57ffflO/fv2UlpamqKgonX/++Vq6dKm6du3qZFlAwJyeqM+dOzfHS5jzc/LkSaWmpqpOnToKDw8/6/7Jycnq27evX5N7+IY+oajytzuS/+05jQa5gz6hOCrscyaJZgULjUJxEozzKtoUGI6GUtHR0dq4caOWL1+urVu3yhijxo0bq0uXLn4dZ9asWU6uHggaf/4MbkJCQiGvBr6gTyjq/P3z27Sn6KBPKM44Zyr6aBSKI86r7OdoKHVa165dmXgDsBJ9AmAr+gTAZjQKQDA5+ut7Q4cO1ZQpU3Jtnzp1qoYPH17QNQGAY/QJgK3oEwCb0SgAbnA0lHr//ffzfFlbu3bt9N577xV4UQDgFH0CYCv6BMBmNAqAGxwNpQ4ePKioqKhc2yMjI3XgwIECLwoAnKJPAGxFnwDYjEYBcIOjoVT9+vXz/JOeS5YsUWxsbIEXBQBO0ScAtqJPAGxGowC4wdEHnY8cOVJ333239u/fr8suu0yStGLFCk2cOFGTJ08O5PoAwC/0CYCt6BMAm9EoAG5wNJS67bbblJGRoSeffFKPP/64JKlOnTqaPn26br311oAuEAD8QZ8A2Io+AbAZjQLgBkdDKUm68847deedd2r//v0KDw9XhQoVArkuAHCMPgGwFX0CYDMaBSDYHA+lTqtatWog1oEASklJ0dGjR33aNzk5Ocf/+iIiIkINGjRwtDYgmOhT0eJPu05z0jCJjsF99MluhX0uJdEh2I1G2Y1GoThxPJR677339M4772jXrl06depUjss2btxY4IXBmZSUFDVs2NDvn+vbt69f+2/fvp1IwVr0qehx2q7T/G2YRMfgDvpkv2CdS0l0CPahUfajUShuHA2lpkyZojFjxqh///768MMPNXDgQP3000/asGGDhgwZEug1wg+nJ+Zz585VfHz8Wfc/efKkUlNTVadOHYWHh591/+TkZPXt29fvVzMAwUKfiiZ/23Wavw2T6BjcQ5+KhsI+l5LoEOxEo4oGGoXixtFQatq0aZoxY4ZuuukmzZkzR/fff79iY2P18MMP69ChQ4FeIxyIj49Xy5Ytfdo3ISGhkFcDBA99Ktr8addpNAxFBX0qWjiXQklDo4oWGoXiIsTJD+3atUvt2rWTJIWHh3snqP369dP8+fMDtzoA8B
N9AmAr+gTAZjQKgBscDaWqV6+ugwcPSpJiYmK0bt06SdLOnTtljAnc6gDAT/QJgK3oEwCb0SgAbnA0lLrsssv03//+V5I0aNAgjRgxQl27dtWNN96oXr16BXSBAOAP+gTAVvQJgM1oFAA3OPpMqRkzZig7O1uSNHjwYFWqVElr1qxR9+7dNXjw4IAuEAD8QZ8A2Io+AbAZjQLgBkdDqZCQEIWE/N+LrHr37q3evXsHbFEA4BR9AmAr+gTAZjQKgBscvX1v9uzZevfdd3Ntf/fddzVnzpwCLwoAnKJPAGxFnwDYjEYBcIOjodRTTz2lKlWq5NperVo1jR8/vsCLAgCn6BMAW9EnADajUQDc4Ggo9fPPP6tu3bq5tsfExGjXrl0FXhQAOEWfANiKPgGwGY0C4AZHQ6lq1app8+bNubZ/++23qly5coEXBQBO0ScAtqJPAGxGowC4wdFQqk+fPho6dKhWrlyprKwsZWVl6bPPPtOwYcPUp0+fQK8RAHxGnwDYij4BsBmNAuAGR39974knntDPP/+szp07KzT0r0NkZ2fr1ltv5f3GAFxFnwDYij4BsBmNAuAGR0OpMmXK6O2339YTTzyhTZs2KTw8XM2aNVNMTEyg1wcAfqFPAGxFnwDYjEYBcIOjodRpDRo0UIMGDfK9PDIyUps2bVJsbGxBrgYA/EafANiKPgGwGY0CEEyOPlPKV8aYwjw8ADhGnwDYij4BsBmNAhBIhTqUAgAAAAAAAPLCUAoAAAAAAABBx1AKAAAAAAAAQVeoQymPx1OYhwcAx+gTAFvRJwA2o1EAAokPOgdQItEnALaiTwBsRqMABFKhDqWWLFmiWrVqFeZVAIAj9AmAregTAJvRKACBFOrkh7KysvT6669rxYoV2rdvn7Kzs3Nc/tlnn0mS2rdvX/AVAoAf6BMAW9EnADajUQDc4GgoNWzYML3++uu6+uqr1bRpU95XDMAa9AmAregTAJvRKABucDSUWrBggd555x1dddVVgV4PABQIfQJgK/oEwGY0CoAbHH2mVJkyZVS/fv1ArwUACow+AbAVfQJgMxoFwA2OhlL33nuvXnjhBf7yAgDr0CcAtqJPAGxGowC4wee371133XU5vv/ss8+0ZMkSNWnSRKVLl85x2cKFCwOzOgDwAX0CYCv6BMBmNAqA23weSkVFReX4vlevXgFfDAA4QZ8A2Io+AbAZjQLgNp+HUrNnzw74lU+YMEELFy7U1q1bFR4ernbt2unpp59Wo0aNAn5dAIov+gTAVvQJgM0C3Sj6BMBfjj5T6rR9+/bpiy++0Jo1a7Rv3z6/f/7zzz/XkCFDtG7dOi1fvlyZmZm6/PLLdfz48YIsCwDoEwBr0ScANitIo+gTAH/5/Eqpv0tPT9eQIUO0YMECZWVlSZJKlSqlG2+8US+99FKul4HmZ+nSpTm+nz17tqpVq6bExER17NjRydIAlHD0CYCt6BMAmwWiUfQJgL8cvVLq9ttv19dff62PPvpIhw8f1pEjR/TRRx/pm2++0R133OF4MUeOHJEkVapUyfExAJRs9AmAregTAJsVRqPoE4CzcfRKqY8//ljLli1T+/btvduuuOIKzZw5U1deeaWjhRhjNHLkSLVv315NmzbNc5+MjAxlZGR4v09PT3d0XQCKL7f6JNEoAGdGnwDYLNCNok8AfOHolVKVK1fO8+WbUVFROueccxwt5O6779bmzZs1f/78fPeZMGGCoqKivF/R0dGOrgtA8eVWnyQaBeDM6BMAmwW6UfQJgC8cDaXGjh2rkSNHKi0tzbtt7969uu+++/Sf//zH7+Pdc889Wrx4sVauXKnatWvnu9/o0aN15MgR79fu3budLB9AMeZWnyQaBeDM6BMAmwWyUfQJgK8cvX1v+vTp+vHHHxUTE6PzzjtPkrRr1y6FhYVp//79euWVV7z7bty4Md/jGGN0zz33aNGiRVq1apXq1q17xusNCwtTWFiYkyUDKCHc6pNEowCcGX
0CYLNANIo+AfCXo6FUz549A3LlQ4YM0bx58/Thhx8qIiJCe/fulfTXS0TDw8MDch0AShb6BMBW9AmAzQLRKPoEwF+OhlKPPPJIQK58+vTpkqRLL700x/bZs2drwIABAbkOACULfQJgK/oEwGaBaBR9AuAvR0OpQDHGuHn1AJAv+gTAVvQJgK3oEwB/ORpKZWVladKkSXrnnXe0a9cunTp1Ksflhw4dCsjiAMBf9AmAregTAJvRKABucPTX9x599FE9//zz6t27t44cOaKRI0fquuuuU0hIiMaNGxfgJQKA7+gTAFvRJwA2o1EA3OBoKPXWW29p5syZGjVqlEJDQ3XTTTfp1Vdf1cMPP6x169YFeo0A4DP6BMBW9AmAzWgUADc4Gkrt3btXzZo1kyRVqFBBR44ckSRdc801+vjjjwO3OgDwE30CYCv6BMBmNAqAGxwNpWrXrq20tDRJUv369fXpp59KkjZs2KCwsLDArQ4A/ESfANiKPgGwGY0C4AZHQ6levXppxYoVkqRhw4bpP//5jxo0aKBbb71Vt912W0AXCAD+oE8AbEWfANiMRgFwg6O/vvfUU095//n6669XdHS0vvzyS9WvX1/XXnttwBYHAP6iTwBsRZ8A2IxGAXCD30OpP//8U//v//0//ec//1FsbKwkqU2bNmrTpk3AFwcA/qBPAGxFnwDYjEYBcIvfb98rXbq0Fi1aVBhrAYACoU8AbEWfANiMRgFwi+PPlPrggw8CvBQAKDj6BMBW9AmAzWgUADc4+kyp+vXr6/HHH9fatWvVqlUrlS9fPsflQ4cODcjiAMBf9AmAregTAJvRKABucDSUevXVV1WxYkUlJiYqMTExx2Uej4dgAXANfQJgK/oEwGY0CoAbHA2ldu7cGeh1AEBA0CcAtqJPAGxGowC4wdFQauTIkXlu93g8Klu2rOrXr68ePXqoUqVKBVocAPiLPgGwFX0CYDMaBcANjoZSSUlJ2rhxo7KystSoUSMZY5SSkqJSpUopLi5O06ZN07333qs1a9aocePGgV4zAOSLPgGwFX0CYDMaBcANjv76Xo8ePdSlSxf9+uuvSkxM1MaNG7Vnzx517dpVN910k/bs2aOOHTtqxIgRgV4vAJwRfQJgK/oEwGY0CoAbHA2lnn32WT3++OOKjIz0bouMjNS4ceP0zDPPqFy5cnr44YdzfUAeABQ2+gTAVvQJgM1oFAA3OBpKHTlyRPv27cu1ff/+/UpPT5ckVaxYUadOnSrY6gDAT/QJgK3oEwCb0SgAbnD89r3bbrtNixYt0i+//KI9e/Zo0aJFGjRokHr27ClJWr9+vRo2bBjItQLAWdEnALaiTwBsRqMAuMHRB52/8sorGjFihPr06aPMzMy/DhQaqv79+2vSpEmSpLi4OL366quBWykA+IA+AbAVfQJgMxoFwA2OhlIVKlTQzJkzNWnSJO3YsUPGGNWrV08VKlTw7nPBBRcEao0A4DP6BMBW9AmAzWgUADc4GkqdVqFCBZ1//vmBWgsABAx9AmAr+gTAZjQKQDA5+kwpAAAAAAAAoCAYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoXB1KrV69Wt27d1fNmjXl8Xj0wQcfuLkcAMiBRgGwFX0CYCv6BMAfrg6ljh8/rubNm2vq1KluLgMA8kSjANiKPgGwFX0C4I9QN6+8W7du6tatm5tLAIB80SgAtqJPAGxFnwD4g8+UAgAAAAAAQNC5+kopf2VkZCgjI8P7fXp6uourAYCcaBQAW9EnALaiT0DJVqReKTVhwgRFRUV5v6Kjo91eEgB40SgAtqJPAG
xFn4CSrUgNpUaPHq0jR454v3bv3u32kgDAi0YBsBV9AmAr+gSUbEXq7XthYWEKCwtzexkAkCcaBcBW9AmAregTULK5OpQ6duyYfvzxR+/3O3fu1KZNm1SpUiWdd955Lq4MAGgUAHvRJwC2ok8A/OHqUOqbb75Rp06dvN+PHDlSktS/f3+9/vrrLq0KAP5CowDYij4BsBV9AuAPV4dSl156qYwxbi4BAPJFowDYij4BsBV9AuCPIvVB5wAAAAAAACgeGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOhC3V4AYCNP5h9qUT1E4Ye3S78GfnYbfni7WlQPkSfzj4AfG0DRVNjd+TsaBCBQgtEumgXAX7Sp6GAoBeSh7LFd2vjvCtLqf0urA3/8eEkb/11Bycd2SWoX+CsAUOQUdnf+jgYBCJRgtItmAfAXbSo6GEoBefijwnlq+coxvfXWW4qPiwv48ZO3btUtt9yiWVedF/BjAyiaCrs7f0eDAARKMNpFswD4izYVHQylgDyY0LJK2putkxUbSjUvCPjxT+7NVtLebJnQsgE/NoCiqbC783c0CECgBKNdNAuAv2hT0cEHnQMAAAAAACDoeKVUMcMHdAMoiviQbwC24MNxAdiMRqG4YShVzPAB3QCKIj7kG4At+HBcADajUShuGEoVM3xAN4CiiA/5BmALPhwXgM1oFIobhlLFDB/QDaAo4kO+AdiCD8cFYDMaheKGDzoHAAAAAABA0PFKKSAPJ06ckCRt3LjRp/1Pnjyp1NRU1alTR+Hh4WfdPzk5uUDrA1D8+Nsdyf/2nEaDAARKYZ8zSTQLgP+CcV5FmwKDoVQxwzAlMLZu3SpJuuOOOwr1eiIiIgr1+EBR4eTEQSpeT26C1Z2/o0FAbgxZ/BPMdtEsgEb5ijYVHQylihmGKYHRs2dPSVJcXJzKlSvn3Z6cnKy+ffv6fby5c+cqPj4+x7aIiAg1aNCgQOsEigsGMvl3R3LeHinv/kg0CMgPT2T8E4xzJolmAafRKN8UxnkVz+kKh8cYY9xehFPp6emKiorSkSNHFBkZ6fZyrHDgwAF98MEHDFMKyYkTJ7z/R/B3Z/svEHnFEP+nuD6Wi+vtKgz5tUtiICM5b49EfwKhOD6Wi+NtCpRAn0tJJfN8inOm4CiOj+XieJsCiUYVHH0KDl8fywylSggeeLBZcX0sF9fbFWz59UuiYQiO4vhYLo63qbAxHIaNiuNjuTjepmCgUbANQykARUZxfSwX19sFlDTF8bFcHG8TUBIVx8dycbxNQEnk62M5JIhrAgAAAAAAACQxlAIAAAAAAIALGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoGEoBAAAAAAAg6BhKAQAAAAAAIOgYSgEAAAAAACDoQt1eQEEYYyRJ6enpLq8EQEGcfgyffkwXFzQKKB6KY6PoE1A80CcAtvK1T0V6KHX06FFJUnR0tMsrARAIR48eVVRUlNvLCBgaBRQvxalR9AkoXugTAFudrU8eU4TH6tnZ2fr1118VEREhj8fj9nKKpPT0dEVHR2v37t2KjIx0ezlFFvdjwRhjdPToUdWsWVMhIcXnXcU0qvCV9MdeSb/9wVIcG0WfAofHoX+4vwKLPuFseMz5jvsqsHztU5F+pVRISIhq167t9jKKhcjISB54AcD96Fxx+a97f0ejgqekP/ZK+u0PhuLWKPoUeDwO/cP9FTj0Cb7gMec77qvA8aVPxWOcDgAAAAAAgCKFoRQAAAAAAACCjqFUCRcWFqZHHnlEYWFhbi+lSON+BNxR0h97Jf32Azbgcegf7i8guHjM+Y77yh1F+oPOAQAAAAAAUDTxSikAAAAAAAAEHUMpAAAAAAAABB1DKQAAAAAAAAQdQ6kibtq0aapbt67Kli2rVq1a6Ysvvsh334ULF6
pr166qWrWqIiMj1bZtWy1btizXPq1bt1bFihVVvnx5XXDBBXrzzTdzHWvPnj3q27evKleurHLlyumCCy5QYmJiwG9fsPhzPw4YMEAejyfXV5MmTfLcf8GCBfJ4POrZs2eO7ePGjct1jOrVqwfyZgHW8+exl5aWpptvvlmNGjVSSEiIhg8fnmufP//8U4899pjq1aunsmXLqnnz5lq6dGmOfTIzMzV27FjVrVtX4eHhio2N1WOPPabs7OxA37yz8uf2r1q1Ks/2bN26Ncd+hw8f1pAhQ1SjRg2VLVtW8fHx+uSTT7yXr169Wt27d1fNmjXl8Xj0wQcfFNbNA4oEfx6Ha9asUUJCgipXrqzw8HDFxcVp0qRJOfbZsmWL/vWvf6lOnTryeDyaPHlyruMU5cdhYZwzna1bnDOhJAv0872/y+95ik3nSv7w576SpIyMDI0ZM0YxMTEKCwtTvXr19Nprr3kv9+W8sij33BYMpYqwt99+W8OHD9eYMWOUlJSkDh06qFu3btq1a1ee+69evVpdu3bVJ598osTERHXq1Endu3dXUlKSd59KlSppzJgx+uqrr7R582YNHDhQAwcOzBGz33//XQkJCSpdurSWLFmiH374QRMnTlTFihUL+yYXCn/vxxdeeEFpaWner927d6tSpUq64YYbcu37888/a9SoUerQoUOex2rSpEmOY3333XcBvW2Azfx97GVkZKhq1aoaM2aMmjdvnuc+Y8eO1SuvvKIXX3xRP/zwgwYPHqxevXrl6NzTTz+tl19+WVOnTlVycrKeeeYZPfvss3rxxRcL5Xbmx9/bf9q2bdtydKNBgwbey06dOqWuXbsqNTVV7733nrZt26aZM2eqVq1a3n2OHz+u5s2ba+rUqYV224Ciwt/HYfny5XX33Xdr9erVSk5O1tixYzV27FjNmDHDu8+JEycUGxurp556Kt/BSVF9HBbGOZMv3ZI4Z0LJVBjP90470/MUW86V/OHkvKp3795asWKFZs2apW3btmn+/PmKi4vzXu7LeWVR7blVDIqsiy66yAwePDjHtri4OPPggw/6fIzGjRubRx999Iz7tGjRwowdO9b7/QMPPGDat2/v32ItVtD7cdGiRcbj8ZjU1NQc2zMzM01CQoJ59dVXTf/+/U2PHj1yXP7II4+Y5s2bF2TpQJFWkMfeJZdcYoYNG5Zre40aNczUqVNzbOvRo4e55ZZbvN9fffXV5rbbbsuxz3XXXWf69u3rx+oLzt/bv3LlSiPJ/P777/kec/r06SY2NtacOnXKpzVIMosWLfJ1yUCxE4hzqV69euXbj5iYGDNp0qQz/nxRehwWxjmTL93inAklVWE93zvb8xRbzpX84e99tWTJEhMVFWUOHjyY7zF9Oa/8u6LUc5vwSqki6tSpU0pMTNTll1+eY/vll1+utWvX+nSM7OxsHT16VJUqVcrzcmOMVqxYoW3btqljx47e7YsXL1br1q11ww03qFq1amrRooVmzpzp/Ma4KBD346xZs9SlSxfFxMTk2P7YY4+patWqGjRoUL4/m5KSopo1a6pu3brq06ePduzY4f+NAIqgQDz28pKRkaGyZcvm2BYeHq41a9Z4v2/fvr1WrFih7du3S5K+/fZbrVmzRldddZXj6/VXQW5/ixYtVKNGDXXu3FkrV67McdnixYvVtm1bDRkyROeee66aNm2q8ePHKysrK+C3ASjqAtGhpKQkrV27VpdccklhLNEqhXXO5Gu3OGdCSVOYz/fO9jzFhnMlfzi5r04/p33mmWdUq1YtNWzYUKNGjdLJkye9+/hyXomCC3V7AXDmwIEDysrK0rnnnptj+7nnnqu9e/f6dIyJEyfq+PHj6t27d47tR44cUa1atZSRkaFSpUpp2rRp6tq1q/fyHTt2aPr06Ro5cqQeeughrV+/XkOHDlVYWJhuvfXWgt+4ICro/ZiWlqYlS5Zo3rx5ObZ/+eWXmjVrljZt2pTvz7Zp00ZvvPGGGjZsqN
9++01PPPGE2rVrpy1btqhy5cqObg9QVASiYXm54oor9Pzzz6tjx46qV6+eVqxYoQ8//DDHk5sHHnhAR44cUVxcnEqVKqWsrCw9+eSTuummmxxfr7+c3P4aNWpoxowZatWqlTIyMvTmm2+qc+fOWrVqlfc/HOzYsUOfffaZbrnlFn3yySdKSUnRkCFDlJmZqYcffrjQbxdQlBSkQ7Vr19b+/fuVmZmpcePG6fbbby/MpVqhsM6ZfOkW50woiQrr+Z4vz1NsOFfyh5P7aseOHVqzZo3Kli2rRYsW6cCBA7rrrrt06NAh7+dK+XJeiYJjKFXEeTyeHN8bY3Jty8v8+fM1btw4ffjhh6pWrVqOyyIiIrRp0yYdO3ZMK1as0MiRIxUbG6tLL71U0l8T99atW2v8+PGS/vqv9lu2bNH06dOL3FDqNKf34+uvv66KFSvm+HDAo0ePqm/fvpo5c6aqVKmS789269bN+8/NmjVT27ZtVa9ePc2ZM0cjR470/0YARZDTx15+XnjhBd1xxx2Ki4uTx+NRvXr1NHDgQM2ePdu7z9tvv625c+dq3rx5atKkiTZt2qThw4erZs2a6t+/v+PrdsKf29+oUSM1atTI+33btm21e/duPffcc96hVHZ2tqpVq6YZM2aoVKlSatWqlX799Vc9++yzDKWAfDjp0BdffKFjx45p3bp1evDBB1W/fn1rn6wFWiDPmSTfusU5E0qyQD7f8/V5ik3nSv7w577Kzs6Wx+PRW2+9paioKEnS888/r+uvv14vvfSSwsPDfTqvRMExlCqiqlSpolKlSuWa/O7bty/XhPif3n77bQ0aNEjvvvuuunTpkuvykJAQ1a9fX5J0wQUXKDk5WRMmTPAOpWrUqKHGjRvn+Jn4+Hi9//77BbhF7ijI/WiM0WuvvaZ+/fqpTJky3u0//fSTUlNT1b17d++203+pIjQ0VNu2bVO9evVyHa98+fJq1qyZUlJSCnKTgCKhII+9M6latao++OAD/fHHHzp48KBq1qypBx98UHXr1vXuc9999+nBBx9Unz59JP31BOfnn3/WhAkTgnaiFajbf/HFF2vu3Lne72vUqKHSpUurVKlS3m3x8fHau3evTp06laNVQElXkMfh6aY0a9ZMv/32m8aNG1fsh1KFcc4kOesW50woCQrj+Z6vz1NsOFfyh5P7qkaNGqpVq5Z3ICX91R5jjH755Rc1aNDAp/NKFByfKVVElSlTRq1atdLy5ctzbF++fLnatWuX78/Nnz9fAwYM0Lx583T11Vf7dF3GGGVkZHi/T0hI0LZt23Lss3379lyfqVQUOL0fJenzzz/Xjz/+mOu92HFxcfruu++0adMm79e1116rTp06adOmTYqOjs7zeBkZGUpOTlaNGjUKdqOAIqAgjz1flC1bVrVq1VJmZqbef/999ejRw3vZiRMnFBKS8//+SpUqFdQ/cxyo25+UlJSjGQkJCfrxxx9z3Jbt27erRo0aDKSAfwjU4/Cf50nFVWGcM0nOusU5E0qCwni+5+vzFBvOlfzh5L5KSEjQr7/+qmPHjnm3bd++XSEhIapdu3aOfc90XokAcOPT1REYCxYsMKVLlzazZs0yP/zwgxk+fLgpX7689y+aPPjgg6Zfv37e/efNm2dCQ0PNSy+9ZNLS0rxfhw8f9u4zfvx48+mnn5qffvrJJCcnm4kTJ5rQ0FAzc+ZM7z7r1683oaGh5sknnzQpKSnmrbfeMuXKlTNz584N3o0PIH/vx9P69u1r2rRp49N15PVXLe69916zatUqs2PHDrNu3TpzzTXXmIiIiFx/xQ8orpw89pKSkkxSUpJp1aqVufnmm01SUpLZsmWL9/J169aZ999/3/z0009m9erV5rLLLjN169bN8Rfr+vfvb2rVqmU++ugjs3PnTrNw4UJTpUoVc//99wfldp/m7+2fNGmSWbRokdm+fbv5/vvvzYMPPmgkmffff9+7z65du0yFChXM3XffbbZt22Y++u
gjU61aNfPEE0949zl69Kj3fpRknn/+eZOUlGR+/vnn4N14wBL+Pg6nTp1qFi9ebLZv3262b99uXnvtNRMZGWnGjBnj3ScjI8P7GKtRo4YZNWqUSUpKMikpKd59iurjsDDOmXzpFudMKKkK4/neP+X1PMWWcyV/+HtfHT161NSuXdtcf/31ZsuWLebzzz83DRo0MLfffrt3H1/OK4tqz23CUKqIe+mll0xMTIwpU6aMadmypfn888+9l/Xv399ccskl3u8vueQSIynXV//+/b37jBkzxtSvX9+ULVvWnHPOOaZt27ZmwYIFua73v//9r2natKkJCwszcXFxZsaMGYV5MwudP/ejMcYcPnzYhIeH+3y784r9jTfeaGrUqGFKly5tatasaa677rocT66BksDfx15eDYuJifFevmrVKhMfH2/CwsJM5cqVTb9+/cyePXtyHCM9Pd0MGzbMnHfeeaZs2bImNjbWjBkzxmRkZBTmTc2TP7f/6aefNvXq1fP2uX379ubjjz/Odcy1a9eaNm3amLCwMBMbG2uefPJJk5mZ6b185cqVZ/3/AqAk8edxOGXKFNOkSRNTrlw5ExkZaVq0aGGmTZtmsrKyvPvs3Lkzz8fY349TlB+HhXHOdLZucc6EkizQz/f+Ka/nKTadK/nD3z4lJyebLl26mPDwcFO7dm0zcuRIc+LECe/lvpxXFuWe28JjjDGF+EIsAAAAAAAAIBc+UwoAAAAAAABBx1AKAAAAAAAAQcdQCgAAAAAAAEHHUAoAAAAAAABBx1AKAAAAAAAAQcdQCgAAAAAAAEHHUAoAAAAAAABBx1AKAAAAAAAAQcdQCrlceumlGj58uM/7p6amyuPxaNOmTYW2JklatWqVPB6PDh8+XKjXA8Ae9AhAUUO37FDSbi9QEHTLLh6PRx988IGk4N3XbmIoVUgGDBggj8cjj8ej0qVLKzY2VqNGjdLx48cL5bp69uwZsOMtXLhQjz/+uM/7R0dHKy0tTU2bNg3YGlBwK1asULt27RQREaEaNWrogQceUGZmptvLggvoEXwR6N/daRMmTNCFF16oiIgIVatWTT179tS2bdty7LNw4UJdccUVqlKlSrE/8YJv6BYKql27dkpLS1NUVJTbS0EJQbcAZxhKFaIrr7xSaWlp2rFjh5544glNmzZNo0aNynPfP//8s9DX4+t1VKpUSRERET4ft1SpUqpevbpCQ0OdLg0BtnnzZl111VW68sorlZSUpAULFmjx4sV68MEH3V4aXEKP4JbPP/9cQ4YM0bp167R8+XJlZmbq8ssvz3GSfvz4cSUkJOipp55ycaWwDd1CQZQpU0bVq1eXx+NxeykoQegW4IBBoejfv7/p0aNHjm233367qV69ujHGmEceecQ0b97czJo1y9StW9d4PB6TnZ1tDh8+bO644w5TtWpVExERYTp16mQ2bdqU7/U88sgjRlKOr5UrV5qdO3caSebtt982l1xyiQkLCzOvvfaaOXDggOnTp4+pVauWCQ8PN02bNjXz5s3LccxLLrnEDBs2zPt9TEyMefLJJ83AgQNNhQoVTHR0tHnllVe8l5++rqSkJGOMMStXrjSSzP/+9z/TqlUrEx4ebtq2bWu2bt2a43oef/xxU7VqVVOhQgUzaNAg88ADD5jmzZvne1tPH/f333/3bnvvvfdM48aNTZkyZUxMTIx57rnncvzMSy+9ZOrXr2/CwsJMtWrVzL/+9S/vZe+++65p2rSpKVu2rKlUqZLp3LmzOXbsWL7Xv2rVKnPhhReaMmXKmOrVq5sHHnjA/Pnnnznut3vuucfcd9995pxzzjHnnnuueeSRR/I9njHGZGZmmhEjRpioqChTqVIlc99995lbb701x787l1xyibn77rvNsGHDTMWKFU21atXMK6+8Yo4dO2YGDBhgKlSoYGJjY80nn3zi/ZnRo0eb1q1b57iuRYsWmbJly5r09PQzrgnFDz2iR2frUX6/O2OM2bx5s+nUqZ
N3bXfccYc5evSo934oXbq0Wb16tfdYzz33nKlcubL59ddf87yuffv2GUnm888/z3XZP39/KLnoFt3y5Tzqn787SSYmJibP2zt79mwTFRVlli5dauLi4kz58uXNFVdckatVs2bN8t4n1atXN0OGDDnjGoDT6Bbd8qVb69evN126dDGVK1c2kZGRpmPHjiYxMTHHPpLMokWLjDEl49yIoVQhyStK99xzj6lcubIx5q+YnP4/w40bN5pvv/3WZGdnm4SEBNO9e3ezYcMGs337dnPvvfeaypUrm4MHD+Z5PUePHjW9e/c2V155pUlLSzNpaWkmIyPD+y9vnTp1zPvvv2927Nhh9uzZY3755Rfz7LPPmqSkJPPTTz+ZKVOmmFKlSpl169Z5j5lXlCpVqmReeuklk5KSYiZMmGBCQkJMcnKyMSb/KLVp08asWrXKbNmyxXTo0MG0a9fOe8y5c+easmXLmtdee81s27bNPProoyYyMtKvKH3zzTcmJCTEPPbYY2bbtm1m9uzZJjw83MyePdsYY8yGDRtMqVKlzLx580xqaqrZuHGjeeGFF4wxxvz6668mNDTUPP/882bnzp1m8+bN5qWXXvI+0fqnX375xZQrV87cddddJjk52SxatMhUqVIlR3QuueQSExkZacaNG2e2b99u5syZYzwej/n000/zvU1PP/20iYqKMu+995754YcfzKBBg0xERESuoVRERIR5/PHHzfbt283jjz9uQkJCTLdu3cyMGTPM9u3bzZ133mkqV65sjh8/bowxZuTIkaZ9+/Y5rmvp0qU5nmii5KBH9OhsPcrvd3f8+HFTs2ZNc91115nvvvvOrFixwtStW9f079/f+7P33XefiYmJMYcPHzabNm0yYWFhZuHChfnedykpKUaS+e6773JdVhJOvOAbukW3fDmPOv07S0tLMz/++KOpX7++6devX563d/bs2aZ06dKmS5cuZsOGDSYxMdHEx8ebm2++2Xu8adOmmbJly5rJkyebbdu2mfXr15tJkyble/3A39EtuuVLt1asWGHefPNN88MPP3if/5177rk5XjjAUAoB8c8off3116Zy5cqmd+/expi/olS6dGmzb98+7z4rVqwwkZGR5o8//shxrHr16uWYTJ/tuoz5v395J0+efNa1XnXVVebee+/1fp9XlPr27ev9Pjs721SrVs1Mnz49x3XlNSk/7eOPPzaSzMmTJ40xxrRp0ybXf3lKSEjwK0o333yz6dq1a4597rvvPtO4cWNjjDHvv/++iYyMzPOVQYmJiUaSSU1Nzff6/u6hhx4yjRo1MtnZ2d5tL730kqlQoYLJysoyxvx1v/1zEHThhReaBx54IN/j1qhRwzz11FPe7//8809Tu3btXEOpvx83MzPTlC9f3nvSZcxfJ2WSzFdffWWMMWbZsmUmJCTEzJs3z2RmZppffvnFtG/f3kjK9V9GUPzRI3pkzNl7lNfvbsaMGeacc87J8V8RP/74YxMSEmL27t1rjDEmIyPDtGjRwvTu3ds0adLE3H777fleR3Z2tunevXuutZ1WEk684Bu6RbeMOXu3TsvOzja9evUyrVq1MidOnMjz9s6ePdtIMj/++GOONZx77rne72vWrGnGjBnj020C/olu0S1jfO/WaZmZmSYiIsL897//9W4raUMpPlOqEH300UeqUKGCypYtq7Zt26pjx4568cUXvZfHxMSoatWq3u8TExN17NgxVa5cWRUqVPB+7dy5Uz/99JN27dqVY/v48ePPuobWrVvn+D4rK0tPPvmkzj//fO/1fPrpp9q1a9cZj3P++ed7/9nj8ah69erat2+fzz9To0YNSfL+zLZt23TRRRfl2P+f359NcnKyEhIScmxLSEhQSkqKsrKy1LVrV8XExCg2Nlb9+vXTW2+9pRMnTkiSmjdvrs6dO6tZs2a64YYbNHPmTP3+++9nvK62bdvm+FyChIQEHTt2TL/88kuet/n07c7vfjpy5IjS0tLUtm
1b77bQ0NBcv7N/HrdUqVKqXLmymjVr5t127rnnSvq/+/fyyy/Xs88+q8GDByssLEwNGzbU1Vdf7f15lDz0iB6dqUdnuq7mzZurfPnyOa4rOzvb+2HlZcqU0dy5c/X+++/r5MmTmjx5cr7Hu/vuu7V582bNnz/fr3WgZKJbdMvXbj300EP66quv9MEHHyg8PDzf/cqVK6d69erlefx9+/bp119/VefOnc96fUB+6BbdOlu39u3bp8GDB6thw4aKiopSVFSUjh07dtbfR3HGJ5MVok6dOmn69OkqXbq0atasqdKlS+e4/O8n+ZKUnZ2tGjVqaNWqVbmOVbFiRVWsWDHHXySqVKnSWdfwz+uYOHGiJk2apMmTJ6tZs2YqX768hg8frlOnTp3xOP9cu8fjUXZ2ts8/c/rB/Pef+ecHTxpjzni8fzLGnPEYERER2rhxo1atWqVPP/1UDz/8sMaNG6cNGzaoYsWKWr58udauXatPP/1UL774osaMGaOvv/5adevW9eu6/r7dyf3ki7yOe7b7d+TIkRoxYoTS0tJ0zjnnKDU1VaNHj87z9qH4o0f0yEmP8rquvx/vtLVr10qSDh06pEOHDuX6XUvSPffco8WLF2v16tWqXbu2X+tAyUS36JYv99PcuXM1adIkrVq16qxtyev4p9dxpmEW4Cu6RbfOdj8NGDBA+/fv1+TJkxUTE6OwsDC1bdv2rL+P4oxXShWi8uXLq379+oqJicn1L2teWrZsqb179yo0NFT169fP8VWlSpVc209HqUyZMsrKyvJpTV988YV69Oihvn37qnnz5oqNjVVKSkqBbqcTjRo10vr163Ns++abb/w6RuPGjbVmzZoc29auXauGDRt6Xw0UGhqqLl266JlnntHmzZuVmpqqzz77TNJfwUhISNCjjz6qpKQklSlTRosWLcr3utauXZsjemvXrlVERIRq1arl17pPi4qKUo0aNbRu3TrvtszMTCUmJjo6Xl48Ho9q1qyp8PBwzZ8/X9HR0WrZsmXAjo+igx7ljx79Ja/fXePGjbVp06Ycfynvyy+/VEhIiBo2bChJ+umnnzRixAjNnDlTF198sW699dYcJ2PGGN19991auHChPvvsMwbj8Bndyh/d+stXX32l22+/Xa+88oouvvhix8eR/noyW6dOHa1YsaJAx0HJRrfyR7f+8sUXX2jo0KG66qqr1KRJE4WFhenAgQOOj1ccMJSySJcuXdS2bVv17NlTy5YtU2pqqtauXauxY8ee8QFbp04dbd68Wdu2bdOBAwfO+Kc/69ev750QJycn69///rf27t1bGDfnjO655x7NmjVLc+bMUUpKip544glt3rzZrz/be++992rFihV6/PHHtX37ds2ZM0dTp071/tnVjz76SFOmTNGmTZv0888/64033lB2drYaNWqkr7/+WuPHj9c333yjXbt2aeHChdq/f7/i4+PzvK677rpLu3fv1j333KOtW7fqww8/1COPPKKRI0cqJMT5w2jYsGF66qmntGjRIm3dulV33XWXDh8+7Ph4f/fss8/qu+++05YtW/T444/rqaee0pQpU3j7HnxCj0pej/L63d1yyy0qW7as+vfvr++//14rV67UPffco379+uncc89VVlaW+vXrp8svv1wDBw7U7Nmz9f3332vixIne4w4ZMkRz587VvHnzFBERob1792rv3r06efKkd59Dhw5p06ZN+uGHHyT99RL/TZs2ufLvA4ouulWyurV371716tVLffr00RVXXOFty/79+x0dT5LGjRuniRMnasqUKUpJSdHGjRtzvPUKCDS6VbK6Jf31+3jzzTeVnJysr7/+WrfcckuJf6UmQymLeDweffLJJ+rYsaNuu+02NWzYUH369FFqaqr3M4Pycscdd6hRo0Zq3bq1qlatqi+//DLfff/zn/+oZcuWuuKKK3TppZeqevXq6tmzZyHcmjO75ZZbNHr0aI0aNUotW7bUzp07NWDAAJUtW9bnY7
Rs2VLvvPOOFixYoKZNm+rhhx/WY489pgEDBkj66yWvCxcu1GWXXab4+Hi9/PLLmj9/vpo0aaLIyEitXr1aV111lRo2bKixY8dq4sSJ6tatW57XVatWLX3yySdav369mjdvrsGDB2vQoEEaO3Zsge6He++9V7feeqsGDBigtm3bKiIiQr169SrQMU9bsmSJOnTooNatW+vjjz/Whx9+6MrvGkUTPSp5Pcrrd1euXDktW7ZMhw4d0oUXXqjrr79enTt31tSpUyVJTz75pFJTUzVjxgxJUvXq1fXqq69q7Nix3rcbTJ8+XUeOHNGll16qGjVqeL/efvtt73UvXrxYLVq08H72XZ8+fdSiRQu9/PLLBbpNKFnoVsnq1tatW/Xbb79pzpw5Odpy4YUXOj5m//79NXnyZE2bNk1NmjTRNddc48orSlBy0K2S1S1Jeu211/T777+rRYsW6tevn4YOHapq1aoV6JhFncf4+0ZOoJB07dpV1atX15tvvun2Ulw1YMAAHT58WB988IHbSwFKLHoEoKihWwCKGroFiQ86h0tOnDihl19+WVdccYVKlSql+fPn63//+5+WL1/u9tIAlDD0CEBRQ7cAFDV0C/lhKAVXnH6p6hNPPKGMjAw1atRI77//vrp06eL20gCUMPQIQFFDtwAUNXQL+eHtewAAAAAAAAg6PugcAAAAAAAAQcdQCgAAAAAAAEHHUAoAAAAAAABBx1AKAAAAAAAAQcdQCgAAAAAAAEHHUAoAAAAAAABBx1AKAAAAAAAAQcdQCgAAAAAAAEHHUAoAAAAAAABB9/8BTFHE/p4+VYsAAAAASUVORK5CYII=",
"text/plain": [
"
\n",
"    NOTE: This notebook is still a work in progress. While the fine-tuning logic is unfinished, the notebook does demonstrate how one could use the data module to easily loop over each of the datasets in the benchmarking group and get the prescribed train-test split. Once the fine-tuning logic is finalized, we will complete this notebook and officially provide it as a tutorial within Graphium. \n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4d5af838",
"metadata": {},
"outputs": [],
"source": [
"# First, let's read the yaml configuration file\n",
"with open(\"../expts/configs/config_tdc_admet_demo.yaml\", \"r\") as file:\n",
" config = yaml.load(file, Loader=yaml.FullLoader)"
]
},
{
"cell_type": "markdown",
"id": "7b2047a3",
"metadata": {},
"source": [
"## Get all TDC benchmark names"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "faa5d4b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"22"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"benchmarks = retrieve_benchmark_names(\"admet_group\")\n",
"len(benchmarks)"
]
},
{
"cell_type": "markdown",
"id": "f80de284",
"metadata": {},
"source": [
"While there is a total of 22 benchmarks, let's just use two for practicality's sake: one regression and one classification task! "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3ac3c111",
"metadata": {},
"outputs": [],
"source": [
"benchmarks = [\"caco2_wang\", \"hia_hou\"]"
]
},
{
"cell_type": "markdown",
"id": "3643aa30",
"metadata": {},
"source": [
"## Initialize all training components per task\n",
"**NOTE**: Since we do not yet have fine-tuning logic, for now this simply creates a new model from scratch. Ultimately, we will want to use fine-tuning code to evaluate how well the pre-trained model transfers to downstream tasks. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9538abfb",
"metadata": {},
"outputs": [],
"source": [
"def training_testing_loop(cfg):\n",
" \"\"\"\n",
" Simple loop to train a model from scratch and test it. \n",
" \"\"\"\n",
" \n",
" # Initialize object from config\n",
" cfg, accelerator_type = load_accelerator(cfg)\n",
" datamodule = load_datamodule(cfg, accelerator_type)\n",
" model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims)\n",
" metrics = load_metrics(cfg)\n",
" \n",
" # Prepare data\n",
" datamodule.prepare_data()\n",
" \n",
" # Initialize the predictor\n",
" predictor = load_predictor(\n",
" cfg,\n",
" model_class,\n",
" model_kwargs,\n",
" metrics,\n",
" datamodule.get_task_levels(),\n",
" accelerator_type,\n",
" datamodule.featurization,\n",
" datamodule.task_norms\n",
" )\n",
" \n",
" # Initialize the trainer\n",
" date_time_suffix = datetime.now().strftime(\"%d.%m.%Y_%H.%M.%S\")\n",
" trainer = load_trainer(cfg, \"tdc-admet\", accelerator_type, date_time_suffix)\n",
" \n",
" # Train\n",
" predictor.set_max_nodes_edges_per_graph(datamodule, stages=[\"train\", \"val\"])\n",
" trainer.fit(model=predictor, datamodule=datamodule)\n",
" \n",
" # Test\n",
" predictor.set_max_nodes_edges_per_graph(datamodule, stages=[\"test\"])\n",
" results = trainer.test(model=predictor, datamodule=datamodule)\n",
" \n",
" return results\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1ee4586e",
"metadata": {},
"outputs": [],
"source": [
"def filter_cfg_based_on_benchmark_name(config, names: Union[List[str], str]):\n",
" \"\"\"\n",
" Filter a base config for the full TDC ADMET benchmarking group to only \n",
" have settings related to a subset of the endpoints\n",
" \"\"\"\n",
" \n",
" if config[\"datamodule\"][\"module_type\"] != \"ADMETBenchmarkDataModule\":\n",
" raise ValueError(\"You can only use this method for the `ADMETBenchmarkDataModule`\")\n",
" \n",
" if isinstance(names, str):\n",
" names = [names]\n",
" \n",
" def _filter(d):\n",
" return {k: v for k, v in d.items() if k in names}\n",
" \n",
" cfg = deepcopy(config)\n",
" \n",
" # Update the datamodule arguments\n",
" cfg[\"datamodule\"][\"args\"][\"tdc_benchmark_names\"] = names\n",
" \n",
" # Filter the relevant config sections\n",
" cfg[\"architecture\"][\"task_heads\"] = _filter(cfg[\"architecture\"][\"task_heads\"])\n",
" cfg[\"predictor\"][\"metrics_on_progress_bar\"] = _filter(cfg[\"predictor\"][\"metrics_on_progress_bar\"])\n",
" cfg[\"predictor\"][\"loss_fun\"] = _filter(cfg[\"predictor\"][\"loss_fun\"])\n",
" cfg[\"metrics\"] = _filter(cfg[\"metrics\"])\n",
" \n",
" return cfg"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "96bc606a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-07-13 14:02:22.438\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m2425\u001b[0m - \u001b[1mPreparing the TDC ADMET Benchmark Group splits for each of the 1 benchmarks.\u001b[0m\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcwognum\u001b[0m (\u001b[33mvalence-ood\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Path logs/tdc-admet-demo/wandb/ wasn't writable, using system temp directory.\n",
"wandb: WARNING Path logs/tdc-admet-demo/wandb/ wasn't writable, using system temp directory\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.15.5"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in /tmp/wandb/run-20230713_140223-iablpj4z"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run tdc_admet_demo_13.07.2023_14.02.22 to Weights & Biases (docs) "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at https://wandb.ai/valence-ood/tdc_admet_demo"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at https://wandb.ai/valence-ood/tdc_admet_demo/runs/iablpj4z"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"\u001b[32m2023-07-13 14:02:28.065\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1168\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = training set\n",
"\tnum_graphs_total = 634\n",
"\tnum_nodes_total = 18351\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 2\n",
"\tstd_num_nodes_per_graph = 11.441048173903344\n",
"\tmean_num_nodes_per_graph = 28.944794952681388\n",
"\tnum_edges_total = 39466\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 2\n",
"\tstd_num_edges_per_graph = 25.32877270037731\n",
"\tmean_num_edges_per_graph = 62.24921135646688\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.065\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1169\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = validation set\n",
"\tnum_graphs_total = 91\n",
"\tnum_nodes_total = 2791\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 13\n",
"\tstd_num_nodes_per_graph = 11.461406761225364\n",
"\tmean_num_nodes_per_graph = 30.67032967032967\n",
"\tnum_edges_total = 6034\n",
"\tmax_num_edges_per_graph = 148\n",
"\tmin_num_edges_per_graph = 28\n",
"\tstd_num_edges_per_graph = 25.619228830769817\n",
"\tmean_num_edges_per_graph = 66.3076923076923\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.066\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1186\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = test set\n",
"\tnum_graphs_total = 181\n",
"\tnum_nodes_total = 5503\n",
"\tmax_num_nodes_per_graph = 68\n",
"\tmin_num_nodes_per_graph = 8\n",
"\tstd_num_nodes_per_graph = 9.244248028261115\n",
"\tmean_num_nodes_per_graph = 30.403314917127073\n",
"\tnum_edges_total = 11736\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 16\n",
"\tstd_num_edges_per_graph = 19.961337435561646\n",
"\tmean_num_edges_per_graph = 64.83977900552486\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.066\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36mget_max_num_nodes_datamodule\u001b[0m:\u001b[36m556\u001b[0m - \u001b[1mMax num nodes being calcuated train\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.067\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36mget_max_num_nodes_datamodule\u001b[0m:\u001b[36m565\u001b[0m - \u001b[1mMax num nodes being calcuated val\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.067\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36mprepare_data\u001b[0m:\u001b[36m943\u001b[0m - \u001b[1mData is already prepared. Skipping the preparation\u001b[0m\n",
"You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
"\u001b[32m2023-07-13 14:02:28.069\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1168\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = training set\n",
"\tnum_graphs_total = 634\n",
"\tnum_nodes_total = 18351\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 2\n",
"\tstd_num_nodes_per_graph = 11.441048173903344\n",
"\tmean_num_nodes_per_graph = 28.944794952681388\n",
"\tnum_edges_total = 39466\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 2\n",
"\tstd_num_edges_per_graph = 25.32877270037731\n",
"\tmean_num_edges_per_graph = 62.24921135646688\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:28.069\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1169\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = validation set\n",
"\tnum_graphs_total = 91\n",
"\tnum_nodes_total = 2791\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 13\n",
"\tstd_num_nodes_per_graph = 11.461406761225364\n",
"\tmean_num_nodes_per_graph = 30.67032967032967\n",
"\tnum_edges_total = 6034\n",
"\tmax_num_edges_per_graph = 148\n",
"\tmin_num_edges_per_graph = 28\n",
"\tstd_num_edges_per_graph = 25.619228830769817\n",
"\tmean_num_edges_per_graph = 66.3076923076923\n",
"-------------------\n",
"\u001b[0m\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------------------\n",
"0 | model | FullGraphMultiTaskNetwork | 111 K \n",
"----------------------------------------------------\n",
"111 K Trainable params\n",
"0 Non-trainable params\n",
"111 K Total params\n",
"0.448 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a99952d7025f418bae86c250887d7b60",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n",
"\u001b[32m2023-07-13 14:02:49.408\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1168\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = training set\n",
"\tnum_graphs_total = 634\n",
"\tnum_nodes_total = 18351\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 2\n",
"\tstd_num_nodes_per_graph = 11.441048173903344\n",
"\tmean_num_nodes_per_graph = 28.944794952681388\n",
"\tnum_edges_total = 39466\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 2\n",
"\tstd_num_edges_per_graph = 25.32877270037731\n",
"\tmean_num_edges_per_graph = 62.24921135646688\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:49.409\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1169\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = validation set\n",
"\tnum_graphs_total = 91\n",
"\tnum_nodes_total = 2791\n",
"\tmax_num_nodes_per_graph = 67\n",
"\tmin_num_nodes_per_graph = 13\n",
"\tstd_num_nodes_per_graph = 11.461406761225364\n",
"\tmean_num_nodes_per_graph = 30.67032967032967\n",
"\tnum_edges_total = 6034\n",
"\tmax_num_edges_per_graph = 148\n",
"\tmin_num_edges_per_graph = 28\n",
"\tstd_num_edges_per_graph = 25.619228830769817\n",
"\tmean_num_edges_per_graph = 66.3076923076923\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:49.410\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1186\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = test set\n",
"\tnum_graphs_total = 181\n",
"\tnum_nodes_total = 5503\n",
"\tmax_num_nodes_per_graph = 68\n",
"\tmin_num_nodes_per_graph = 8\n",
"\tstd_num_nodes_per_graph = 9.244248028261115\n",
"\tmean_num_nodes_per_graph = 30.403314917127073\n",
"\tnum_edges_total = 11736\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 16\n",
"\tstd_num_edges_per_graph = 19.961337435561646\n",
"\tmean_num_edges_per_graph = 64.83977900552486\n",
"-------------------\n",
"\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:49.410\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36mget_max_num_nodes_datamodule\u001b[0m:\u001b[36m574\u001b[0m - \u001b[1mMax num nodes being calcuated test\u001b[0m\n",
"\u001b[32m2023-07-13 14:02:49.411\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36mprepare_data\u001b[0m:\u001b[36m943\u001b[0m - \u001b[1mData is already prepared. Skipping the preparation\u001b[0m\n",
"You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
"\u001b[32m2023-07-13 14:02:49.412\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mgraphium.data.datamodule\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m1186\u001b[0m - \u001b[1m-------------------\n",
"MultitaskDataset\n",
"\tabout = test set\n",
"\tnum_graphs_total = 181\n",
"\tnum_nodes_total = 5503\n",
"\tmax_num_nodes_per_graph = 68\n",
"\tmin_num_nodes_per_graph = 8\n",
"\tstd_num_nodes_per_graph = 9.244248028261115\n",
"\tmean_num_nodes_per_graph = 30.403314917127073\n",
"\tnum_edges_total = 11736\n",
"\tmax_num_edges_per_graph = 144\n",
"\tmin_num_edges_per_graph = 16\n",
"\tstd_num_edges_per_graph = 19.961337435561646\n",
"\tmean_num_edges_per_graph = 64.83977900552486\n",
"-------------------\n",
"\u001b[0m\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e92ded2754004662a47e2fb8cb8112be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"